|
|
@@ -1,16 +1,8 @@
|
|
|
-from cvfinetune.training.trainer.base import default_intervals, Trainer
|
|
|
-from chainer_addons.training.extensions import SacredReport
|
|
|
+from cvfinetune.training.trainer.base import Trainer, default_intervals
|
|
|
+from chainer_addons.training.sacred import SacredTrainerMixin
|
|
|
|
|
|
-class SacredTrainer(Trainer):
|
|
|
- def __init__(self, opts, sacred_params, intervals=default_intervals, *args, **kwargs):
|
|
|
- super(SacredTrainer, self).__init__(opts=opts, intervals=intervals, *args, **kwargs)
|
|
|
- self.sacred_reporter = SacredReport(opts, sacred_params=sacred_params, trigger=intervals.log)
|
|
|
- self.extend(self.sacred_reporter)
|
|
|
+class SacredTrainer(SacredTrainerMixin, Trainer):
|
|
|
|
|
|
- def _run(*args, **kwargs):
|
|
|
- return super(SacredTrainer, self).run(*args, **kwargs)
|
|
|
-
|
|
|
- self.sacred_reporter.ex.main(_run)
|
|
|
-
|
|
|
- def run(self, *args, **kwargs):
|
|
|
- return self.sacred_reporter.ex.run(*args, **kwargs)
|
|
|
+ def __init__(self, intervals=default_intervals, *args, **kwargs):
|
|
|
+ super(SacredTrainer, self).__init__(
|
|
|
+ intervals=intervals, sacred_trigger=intervals.log, *args, **kwargs)
|