|
|
@@ -306,15 +306,22 @@ class _TrainerMixin(abc.ABC):
|
|
|
|
|
|
self.evaluator.default_name = default_name
|
|
|
|
|
|
- def run(self, trainer_cls, opts, *args, **kwargs):
|
|
|
-
|
|
|
- trainer = trainer_cls(
|
|
|
+ def _new_trainer(self, trainer_cls, *args, **kwargs):
|
|
|
+ return trainer_cls(
|
|
|
opts=opts,
|
|
|
updater=self.updater,
|
|
|
evaluator=self.evaluator,
|
|
|
*args, **kwargs
|
|
|
)
|
|
|
|
|
|
+ def evaluate(self, trainer_cls, opts, *args, *kwargs):
|
|
|
+ trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
|
|
|
+ return self.evaluator.evaluate(trainer)
|
|
|
+
|
|
|
+ def run(self, trainer_cls, opts, *args, **kwargs):
|
|
|
+
|
|
|
+ trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
|
|
|
+
|
|
|
self.save_meta_info(opts, folder=Path(trainer.out, "meta"))
|
|
|
|
|
|
logging.info("Snapshotting is {}abled".format("dis" if opts.no_snapshot else "en"))
|