|
|
@@ -1,7 +1,16 @@
|
|
|
-from .base import default_intervals, Trainer
|
|
|
+from cvfinetune.training.trainer.base import default_intervals, Trainer
|
|
|
from chainer_addons.training.extensions import SacredReport
|
|
|
|
|
|
class SacredTrainer(Trainer):
|
|
|
- def __init__(self, ex, intervals=default_intervals, *args, **kwargs):
|
|
|
- super(SacredTrainer, self).__init__(intervals=intervals, *args, **kwargs)
|
|
|
- self.extend(SacredReport(ex=ex, trigger=intervals.log))
|
|
|
+ def __init__(self, opts, sacred_params, intervals=default_intervals, *args, **kwargs):
|
|
|
+ super(SacredTrainer, self).__init__(opts=opts, intervals=intervals, *args, **kwargs)
|
|
|
+ self.sacred_reporter = SacredReport(opts, sacred_params=sacred_params, trigger=intervals.log)
|
|
|
+ self.extend(self.sacred_reporter)
|
|
|
+
|
|
|
+ def _run(*args, **kwargs):
|
|
|
+ return super(SacredTrainer, self).run(*args, **kwargs)
|
|
|
+
|
|
|
+ self.sacred_reporter.ex.main(_run)
|
|
|
+
|
|
|
+ def run(self, *args, **kwargs):
|
|
|
+ return self.sacred_reporter.ex.run(*args, **kwargs)
|