Просмотр исходного кода

passing opts argument to reportables method of the trainer

Dimitri Korsch 5 лет назад
Родитель
Сommit
4c4a5ff74d

+ 1 - 1
cvfinetune/training/trainer/alpha_pooling.py

@@ -22,7 +22,7 @@ class AlphaPoolingTrainer(SacredTrainer):
 				model=self.model,
 				pooling=self.model.pool))
 
-	def reportables(self):
+	def reportables(self, opts):
 		print_values, plot_values = super(AlphaPoolingTrainer, self).reportables()
 		alpha_update_rule = self.model.pool.alpha.update_rule
 		if _is_adam(opts):

+ 2 - 2
cvfinetune/training/trainer/base.py

@@ -103,7 +103,7 @@ class Trainer(T):
 		self.setup_snapshots(opts, clf.model, intervals.snapshot)
 
 		### Reports and Plots ###
-		print_values, plot_values = self.reportables()
+		print_values, plot_values = self.reportables(opts)
 		self.extend(extensions.PrintReport(print_values), trigger=intervals.print)
 		for name, values in plot_values.items():
 			ext = extensions.PlotReport(values, 'epoch', file_name='{}.png'.format(name))
@@ -128,7 +128,7 @@ class Trainer(T):
 
 		return f"{self.evaluator.default_name}/{name}"
 
-	def reportables(self):
+	def reportables(self, opts):
 
 		print_values = [
 			"elapsed_time",

+ 1 - 1
examples/fve_example/core/trainer.py

@@ -3,7 +3,7 @@ from cvfinetune.training.trainer import Trainer
 
 class PartsTrainer(Trainer):
 
-	def reportables(self):
+	def reportables(self, opts):
 		print_vals, plot_vals = super(PartsTrainer, self).reportables()