
added a PartsTrainer to the FVE Example

Dimitri Korsch 6 years ago
commit be1c931055

+ 0 - 1
examples/fve_example/core/classifier.py

@@ -95,7 +95,6 @@ class FVEClassifier(SeparateModelClassifier):
 		)
 		return glob_loss, glob_pred
 
-
 	@property
 	def feat_size(self):
 		return self.model.meta.feature_size

+ 2 - 0
examples/fve_example/core/dataset.py

@@ -9,6 +9,7 @@ from cvdatasets.dataset import IteratorMixin
 
 from finetune.dataset import _base_mixin
 
+
 class _parts_mixin(ABC):
 
 	def get_example(self, i):
@@ -18,6 +19,7 @@ class _parts_mixin(ABC):
 
 		return parts, im_obj.label + self.label_shift
 
+
 class PartsDataset(_base_mixin,
 	# augmentation and preprocessing
 	AugmentationMixin, PreprocessMixin,

+ 31 - 0
examples/fve_example/core/trainer.py

@@ -0,0 +1,31 @@
+from finetune.training.trainer import Trainer
+
+
+class PartsTrainer(Trainer):
+
+	def reportables(self):
+		print_vals, plot_vals = super(PartsTrainer, self).reportables()
+
+
+		print_vals.extend([
+			"main/glob_accu", self.eval_name("main/glob_accu"),
+			"main/part_accu", self.eval_name("main/part_accu"),
+
+			"main/logL", self.eval_name("main/logL"),
+			# "main/glob_loss", self.eval_name("main/glob_loss"),
+			# "main/part_loss", self.eval_name("main/part_loss"),
+		])
+
+		plot_vals["logL"] = ["main/logL", self.eval_name("main/logL")]
+
+		plot_vals["accuracy"].extend([
+			"main/part_accu", self.eval_name("main/part_accu"),
+			"main/glob_accu", self.eval_name("main/glob_accu"),
+		])
+
+		plot_vals["loss"].extend([
+			"main/part_loss", self.eval_name("main/part_loss"),
+			"main/glob_loss", self.eval_name("main/glob_loss"),
+		])
+
+		return print_vals, plot_vals
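
For readers skimming the new file: PartsTrainer.reportables() only extends the key lists that the base Trainer feeds into its PrintReport/PlotReport extensions, and eval_name() (added to the base class further down) prefixes each key with the evaluator's report name, falling back to the bare name when no evaluator is set. A rough sketch of the resulting keys, assuming the evaluator reports under "val", as the "val/main/accuracy" format string in Trainer.run() suggests:

# Sketch only: the extra report keys a PartsTrainer adds, assuming the
# evaluator's default_name is "val". Not an API call, just the key layout.
extra_print_vals = [
	"main/glob_accu", "val/main/glob_accu",   # global-branch accuracy
	"main/part_accu", "val/main/part_accu",   # part-branch accuracy
	"main/logL",      "val/main/logL",        # FVE log-likelihood
]
extra_plot_vals = {
	"logL":     ["main/logL", "val/main/logL"],            # new plot
	"accuracy": ["main/part_accu", "val/main/part_accu",   # extends the base
	             "main/glob_accu", "val/main/glob_accu"],  # accuracy plot
	"loss":     ["main/part_loss", "val/main/part_loss",   # extends the base
	             "main/glob_loss", "val/main/glob_loss"],  # loss plot
}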

+ 2 - 6
examples/fve_example/main.py

@@ -12,13 +12,9 @@ import logging
 from chainer.training.updaters import StandardUpdater
 
 from finetune.finetuner import DefaultFinetuner
-from finetune.training.trainer import Trainer
-from finetune.dataset import BaseDataset
-from finetune.classifier import Classifier
-
 
 from utils import parser
-from core import classifier, dataset
+from core import classifier, dataset, trainer
 
 def main(args):
 	if args.debug:
@@ -46,7 +42,7 @@ def main(args):
 	)
 
 
-	tuner.run(trainer_cls=Trainer, opts=args)
+	tuner.run(trainer_cls=trainer.PartsTrainer, opts=args)
 
 
 main(parser.parse_args())

+ 21 - 43
finetune/training/trainer.py

@@ -57,6 +57,7 @@ class Trainer(T):
 		)
 
 		### Evaluator ###
+		self.evaluator = evaluator
 		if evaluator is not None:
 			self.extend(evaluator, trigger=intervals.eval)
 
@@ -99,7 +100,7 @@ class Trainer(T):
 		self.setup_snapshots(opts, clf.model, intervals.snapshot)
 
 		### Reports and Plots ###
-		print_values, plot_values = self.reportables(opts, model, evaluator)
+		print_values, plot_values = self.reportables()
 		self.extend(extensions.PrintReport(print_values), trigger=intervals.print)
 		for name, values in plot_values.items():
 			ext = extensions.PlotReport(values, 'epoch', file_name='{}.png'.format(name))
@@ -118,27 +119,30 @@ class Trainer(T):
 			self.extend(extensions.snapshot_object(obj, dump_fmt), trigger=trigger)
 			logging.info("Snapshot format: \"{}\"".format(dump_fmt))
 
+	def eval_name(self, name):
+		if self.evaluator is None:
+			return name
 
-	def reportables(self, opts, model, evaluator):
-		eval_name = lambda name: f"{evaluator.default_name}/{name}"
+		return f"{self.evaluator.default_name}/{name}"
 
+	def reportables(self):
 
 		print_values = [
 			"elapsed_time",
 			"epoch",
 			# "lr",
 
-			"main/accuracy", eval_name("main/accuracy"),
-			"main/loss", eval_name("main/loss"),
+			"main/accuracy", self.eval_name("main/accuracy"),
+			"main/loss", self.eval_name("main/loss"),
 
 		]
 
 		plot_values = {
 			"accuracy": [
-				"main/accuracy",  eval_name("main/accuracy"),
+				"main/accuracy",  self.eval_name("main/accuracy"),
 			],
 			"loss": [
-				"main/loss", eval_name("main/loss"),
+				"main/loss", self.eval_name("main/loss"),
 			],
 		}
 
@@ -150,34 +154,6 @@ class Trainer(T):
 		# 		]
 		# 	})
 
-		# if opts.use_parts:
-		# 	print_values.extend(["main/logL", eval_name("main/logL")])
-		# 	plot_values.update({
-		# 		"logL": [
-		# 			"main/logL", eval_name("main/logL"),
-		# 		]
-		# 	})
-
-		# 	if not opts.no_global:
-		# 		print_values.extend([
-		# 			"main/glob_accu", eval_name("main/glob_accu"),
-		# 			# "main/glob_loss", eval_name("main/glob_loss"),
-
-		# 			"main/part_accu", eval_name("main/part_accu"),
-		# 			# "main/part_loss", eval_name("main/part_loss"),
-		# 		])
-
-		# 		plot_values["accuracy"].extend([
-		# 			"main/part_accu", eval_name("main/part_accu"),
-		# 			"main/glob_accu", eval_name("main/glob_accu"),
-		# 		])
-
-		# 		plot_values["loss"].extend([
-		# 			"main/part_loss", eval_name("main/part_loss"),
-		# 			"main/glob_loss", eval_name("main/glob_loss"),
-		# 		])
-
-
 		return print_values, plot_values
 
 
@@ -194,8 +170,7 @@ class Trainer(T):
 	def run(self, init_eval=True):
 		if init_eval:
 			logging.info("Evaluating initial model ...")
-			evaluator = self.get_extension("val")
-			init_perf = evaluator(self)
+			init_perf = self.evaluator(self)
 			logging.info("Initial accuracy: {val/main/accuracy:.3%} initial loss: {val/main/loss:.3f}".format(
 				**{key: float(value) for key, value in init_perf.items()}
 			))
@@ -210,19 +185,22 @@ class SacredTrainer(Trainer):
 
 class AlphaPoolingTrainer(SacredTrainer):
 
+	@property
+	def model(self):
+		return self.updater.get_optimizer("main").target.model
+
 	def __init__(self, opts, updater, *args, **kwargs):
 		super(AlphaPoolingTrainer, self).__init__(opts=opts, updater=updater, *args, **kwargs)
-		model = updater.get_optimizer("main").target.model
 		### Alternating training of CNN and FC layers (only for alpha-pooling) ###
 		if opts.switch_epochs:
 			self.extend(SwitchTrainables(
 				opts.switch_epochs,
-				model=model,
-				pooling=model.pool))
+				model=self.model,
+				pooling=self.model.pool))
 
-	def reportables(self, opts, model, evaluator):
-		print_values, plot_values = super(AlphaPoolingTrainer, self).reportables(opts, model, evaluator)
-		alpha_update_rule = model.pool.alpha.update_rule
+	def reportables(self):
+		print_values, plot_values = super(AlphaPoolingTrainer, self).reportables()
+		alpha_update_rule = self.model.pool.alpha.update_rule
 		if _is_adam(opts):
 			# in case of Adam optimizer
 			alpha_update_rule.hyperparam.alpha *= opts.kappa
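
One note on the last hunk: reportables() no longer receives opts, yet the unchanged lines above still call _is_adam(opts) and read opts.kappa, which would raise a NameError with the new signature. A minimal sketch of one way to keep this working, assuming the trainer keeps a reference to opts on the instance (the rest of the file and _is_adam are not shown in this hunk, so this is an assumption rather than the commit's actual fix):

# Sketch only, assuming opts is stored on the instance; _is_adam and
# SwitchTrainables are the helpers already used in this file.
class AlphaPoolingTrainer(SacredTrainer):

	@property
	def model(self):
		return self.updater.get_optimizer("main").target.model

	def __init__(self, opts, updater, *args, **kwargs):
		super(AlphaPoolingTrainer, self).__init__(opts=opts, updater=updater, *args, **kwargs)
		self.opts = opts  # assumed: keep a reference for reportables()
		if opts.switch_epochs:
			self.extend(SwitchTrainables(
				opts.switch_epochs,
				model=self.model,
				pooling=self.model.pool))

	def reportables(self):
		print_values, plot_values = super(AlphaPoolingTrainer, self).reportables()
		alpha_update_rule = self.model.pool.alpha.update_rule
		if _is_adam(self.opts):  # was: _is_adam(opts)
			alpha_update_rule.hyperparam.alpha *= self.opts.kappa
		# remainder of the method as in the original file (not shown in the hunk)
		return print_values, plot_values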