Răsfoiți Sursa

updated the SacredTrainer for the new SacredReporter

Dimitri Korsch 6 ani în urmă
părinte
comite
31b56a08a5
2 a modificat fișierele cu 14 adăugiri și 5 ștergeri
  1. 1 1
      cvfinetune/__init__.py
  2. 13 4
      cvfinetune/training/trainer/sacred.py

+ 1 - 1
cvfinetune/__init__.py

@@ -1 +1 @@
-__version__ = "0.2.5"
+__version__ = "0.2.6"

+ 13 - 4
cvfinetune/training/trainer/sacred.py

@@ -1,7 +1,16 @@
-from .base import default_intervals, Trainer
+from cvfinetune.training.trainer.base import default_intervals, Trainer
 from chainer_addons.training.extensions import SacredReport
 
 class SacredTrainer(Trainer):
-	def __init__(self, ex, intervals=default_intervals, *args, **kwargs):
-		super(SacredTrainer, self).__init__(intervals=intervals, *args, **kwargs)
-		self.extend(SacredReport(ex=ex, trigger=intervals.log))
+	def __init__(self, opts, sacred_params, intervals=default_intervals, *args, **kwargs):
+		super(SacredTrainer, self).__init__(opts=opts, intervals=intervals, *args, **kwargs)
+		self.sacred_reporter = SacredReport(opts, sacred_params=sacred_params, trigger=intervals.log)
+		self.extend(self.sacred_reporter)
+
+		def _run(*args, **kwargs):
+			return super(SacredTrainer, self).run(*args, **kwargs)
+
+		self.sacred_reporter.ex.main(_run)
+
+	def run(self, *args, **kwargs):
+		return self.sacred_reporter.ex.run(*args, **kwargs)