Răsfoiți Sursa

updated requirements and moved sacred trainer logic to chainer_addons

Dimitri Korsch 6 ani în urmă
părinte
comite
3b252c567a
3 a modificat fișierele cu 8 adăugiri și 16 ștergeri
  1. 1 1
      cvfinetune/finetuner/base.py
  2. 6 14
      cvfinetune/training/trainer/sacred.py
  3. 1 1
      requirements.txt

+ 1 - 1
cvfinetune/finetuner/base.py

@@ -300,7 +300,7 @@ class _TrainerMixin(abc.ABC):
 				"model_{}.npz".format(suffix)), self.model)
 
 		try:
-			trainer.run(init_eval=opts.init_eval or opts.only_eval)
+			trainer.run(opts.init_eval or opts.only_eval)
 		except (KeyboardInterrupt, BdbQuit) as e:
 			raise e
 		except Exception as e:

+ 6 - 14
cvfinetune/training/trainer/sacred.py

@@ -1,16 +1,8 @@
-from cvfinetune.training.trainer.base import default_intervals, Trainer
-from chainer_addons.training.extensions import SacredReport
+from cvfinetune.training.trainer.base import Trainer, default_intervals
+from chainer_addons.training.sacred import SacredTrainerMixin
 
-class SacredTrainer(Trainer):
-	def __init__(self, opts, sacred_params, intervals=default_intervals, *args, **kwargs):
-		super(SacredTrainer, self).__init__(opts=opts, intervals=intervals, *args, **kwargs)
-		self.sacred_reporter = SacredReport(opts, sacred_params=sacred_params, trigger=intervals.log)
-		self.extend(self.sacred_reporter)
+class SacredTrainer(SacredTrainerMixin, Trainer):
 
-		def _run(*args, **kwargs):
-			return super(SacredTrainer, self).run(*args, **kwargs)
-
-		self.sacred_reporter.ex.main(_run)
-
-	def run(self, *args, **kwargs):
-		return self.sacred_reporter.ex.run(*args, **kwargs)
+	def __init__(self, intervals=default_intervals, *args, **kwargs):
+		super(SacredTrainer, self).__init__(
+			intervals=intervals, sacred_trigger=intervals.log, *args, **kwargs)

+ 1 - 1
requirements.txt

@@ -5,7 +5,7 @@ matplotlib~=3.0
 imageio~=2.3.0
 PyYAML~=5.1
 simplejson~=3.14
-sacred~=0.7.4
+sacred~=0.7
 
 chainer>=4.2.0,<7.0
 cupy-cuda100>=4.2.0,<7.0