Ver código fonte

refactored trainer module. added headless weight loading option

Dimitri Korsch 6 anos atrás
pai
commit
b8403319a6

+ 7 - 4
cvfinetune/finetuner/base.py

@@ -129,8 +129,8 @@ class _ModelMixin(abc.ABC):
 		else:
 			if args.load:
 				self.weights = args.load
-				logging.info("Loading already fine-tuned weights from \"{}\"".format(self.weights))
-				loader = partial(self.model.load_for_inference, weights=self.weights)
+				msg = "Loading already fine-tuned weights from \"{}\""
+				loader_func = self.model.load_for_inference
 			else:
 				self.weights = join(
 					self.data_info.BASE_DIR,
@@ -138,8 +138,11 @@ class _ModelMixin(abc.ABC):
 					self.model_info.folder,
 					self.model_info.weights
 				)
-				logging.info("Loading pre-trained weights \"{}\"".format(self.weights))
-				loader = partial(self.model.load_for_finetune, weights=self.weights)
+				msg = "Loading pre-trained weights \"{}\""
+				loader_func = self.model.load_for_finetune
+
+			logging.info(msg.format(self.weights))
+			loader = partial(loader_func, weights=self.weights, headless=args.headless)
 
 		if hasattr(self.clf, "output_size"):
 			feat_size = self.clf.output_size

+ 1 - 0
cvfinetune/parser.py

@@ -34,6 +34,7 @@ def default_factory(extra_list=[]):
 				help_text="type of pre-classification pooling"),
 
 			Arg("--load", type=str, help="ignore weights and load already fine-tuned model"),
+			Arg("--headless", action="store_true", help="ignores classifier layer during loading"),
 
 			Arg("--n_jobs", "-j", type=int, default=0,
 				help="number of loading processes. If 0, then images are loaded in the same process"),

+ 3 - 0
cvfinetune/training/trainer/__init__.py

@@ -0,0 +1,3 @@
+from .base import default_intervals, Trainer
+from .sacred import SacredTrainer
+from .alpha_pooling import AlphaPoolingTrainer

+ 41 - 0
cvfinetune/training/trainer/alpha_pooling.py

@@ -0,0 +1,41 @@
from chainer.training import extensions
from chainer_addons.training.extensions import SwitchTrainables

from .base import _is_adam, default_intervals
from .sacred import SacredReport, SacredTrainer
+
def observe_alpha(trainer):
	"""Return the current alpha value of the model's pooling layer as a float.

	Used with ``extensions.observe_value`` so that alpha shows up in the
	printed and plotted training reports.
	"""
	main_optimizer = trainer.updater.get_optimizer("main")
	pool = main_optimizer.target.model.pool
	return float(pool.alpha.array)
+
class AlphaPoolingTrainer(SacredTrainer):
	"""Trainer for models with a trainable alpha-pooling layer.

	Optionally alternates training of CNN and FC layers and adds the
	current alpha value of the pooling layer to the training reports.
	"""

	@property
	def model(self):
		# the model is wrapped by the classifier target of the main optimizer
		return self.updater.get_optimizer("main").target.model

	def __init__(self, opts, updater, *args, **kwargs):
		super(AlphaPoolingTrainer, self).__init__(opts=opts, updater=updater, *args, **kwargs)
		# keep a reference: reportables() needs the optimizer settings.
		# (previously reportables() referenced the bare name "opts", which
		# is not in scope there and raised a NameError)
		self.opts = opts
		### Alternating training of CNN and FC layers (only for alpha-pooling) ###
		if opts.switch_epochs:
			self.extend(SwitchTrainables(
				opts.switch_epochs,
				model=self.model,
				pooling=self.model.pool))

	def reportables(self):
		print_values, plot_values = super(AlphaPoolingTrainer, self).reportables()

		# scale the step size of the alpha parameter by kappa;
		# Adam keeps its step size in hyperparam.alpha, other
		# optimizers in hyperparam.lr
		alpha_update_rule = self.model.pool.alpha.update_rule
		if _is_adam(self.opts):
			alpha_update_rule.hyperparam.alpha *= self.opts.kappa
		else:
			alpha_update_rule.hyperparam.lr *= self.opts.kappa

		self.extend(
			extensions.observe_value("alpha", observe_alpha),
			trigger=default_intervals.print)

		print_values.append("alpha")
		plot_values["alpha"] = ["alpha"]

		return print_values, plot_values

+ 4 - 42
cvfinetune/training/trainer.py → cvfinetune/training/trainer/base.py

@@ -7,16 +7,11 @@ from chainer.training import extensions, Trainer as T
 from chainer.training import trigger as trigger_module
 from chainer_addons.training import lr_shift
 from chainer_addons.training.optimizer import OptimizerType
-from chainer_addons.training.extensions import SacredReport
 from chainer_addons.training.extensions.learning_rate import CosineAnnealingLearningRate
 from chainer_addons.training.extensions import AlternateTrainable, SwitchTrainables, WarmUp
 
 from cvdatasets.utils import attr_dict
 
-def debug_hook(trainer):
-	pass
-	# print(trainer.updater.get_optimizer("main").target.model.fc6.W.data.mean(), file=open("debug.out", "a"))
-
 default_intervals = attr_dict(
 	print =		(1,  'epoch'),
 	log =		(1,  'epoch'),
@@ -24,9 +19,10 @@ default_intervals = attr_dict(
 	snapshot =	(10, 'epoch'),
 )
 
-def observe_alpha(trainer):
-	model = trainer.updater.get_optimizer("main").target.model
-	return float(model.pool.alpha.array)
+def debug_hook(trainer):
+	pass
+	# print(trainer.updater.get_optimizer("main").target.model.fc6.W.data.mean(), file=open("debug.out", "a"))
+
 
 def _is_adam(opts):
 	return opts.optimizer == OptimizerType.ADAM.name.lower()
@@ -177,37 +173,3 @@ class Trainer(T):
 			return
 		return super(Trainer, self).run()
 
-class SacredTrainer(Trainer):
-	def __init__(self, ex, intervals=default_intervals, *args, **kwargs):
-		super(SacredTrainer, self).__init__(intervals=intervals, *args, **kwargs)
-		self.extend(SacredReport(ex=ex, trigger=intervals.log))
-
-class AlphaPoolingTrainer(SacredTrainer):
-
-	@property
-	def model(self):
-		return self.updater.get_optimizer("main").target.model
-
-	def __init__(self, opts, updater, *args, **kwargs):
-		super(AlphaPoolingTrainer, self).__init__(opts=opts, updater=updater, *args, **kwargs)
-		### Alternating training of CNN and FC layers (only for alpha-pooling) ###
-		if opts.switch_epochs:
-			self.extend(SwitchTrainables(
-				opts.switch_epochs,
-				model=self.model,
-				pooling=self.model.pool))
-
-	def reportables(self):
-		print_values, plot_values = super(AlphaPoolingTrainer, self).reportables()
-		alpha_update_rule = self.model.pool.alpha.update_rule
-		if _is_adam(opts):
-			# in case of Adam optimizer
-			alpha_update_rule.hyperparam.alpha *= opts.kappa
-		else:
-			alpha_update_rule.hyperparam.lr *= opts.kappa
-
-		self.extend(extensions.observe_value("alpha", observe_alpha), trigger=intervals.print)
-		print_values.append("alpha")
-		plot_values["alpha"]= ["alpha"]
-
-		return print_values, plot_values

+ 7 - 0
cvfinetune/training/trainer/sacred.py

@@ -0,0 +1,7 @@
+from .base import default_intervals, Trainer
+from chainer_addons.training.extensions import SacredReport
+
class SacredTrainer(Trainer):
	"""Trainer that forwards logged metrics to a sacred experiment.

	``ex`` is the sacred Experiment instance; reports are sent at every
	log interval.
	"""

	def __init__(self, ex, intervals=default_intervals, *args, **kwargs):
		super(SacredTrainer, self).__init__(intervals=intervals, *args, **kwargs)
		reporter = SacredReport(ex=ex, trigger=intervals.log)
		self.extend(reporter)