
fixed AlphaPoolingTrainer class; some minor changes in Finetuner initialization

Dimitri Korsch 3 years ago
parent
commit
45c0415baf

+ 9 - 9
cvfinetune/parser/base.py

@@ -20,6 +20,15 @@ def default_factory(extra_list=[]):
 
 class FineTuneParser(GPUParser):
 
+	def __init__(self, *args, model_modules=None, **kwargs):
+		super(FineTuneParser, self).__init__(*args, **kwargs)
+		self._file_handler = None
+
+		add_dataset_args(self)
+		add_model_args(self, model_modules=model_modules)
+		add_training_args(self)
+
+
 	def _logging_config(self, simple=False):
 		if not self.has_logging: return
 		fmt = '{levelname:s} - [{asctime:s}] {filename:s}:{lineno:d} [{funcName:s}]: {message:s}'
@@ -49,15 +58,6 @@ class FineTuneParser(GPUParser):
 			warnings.warn("Could not flush logs to file: {}".format(e))
 
 
-	def __init__(self, *args, **kwargs):
-		super(FineTuneParser, self).__init__(*args, **kwargs)
-		self._file_handler = None
-
-		add_dataset_args(self)
-		add_model_args(self)
-		add_training_args(self)
-
-
 class HostnameFilter(logging.Filter):
 
 	def filter(self, record):
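
The constructor now sits at the top of the class and accepts an optional model_modules list, which it forwards to add_model_args. A minimal usage sketch, assuming FineTuneParser needs no further constructor arguments beyond what GPUParser provides (only the model_modules keyword is confirmed by this commit):

	from cvfinetune.parser.base import FineTuneParser

	# Offer only models registered under "cvmodelz" as --model_type choices.
	# Any other required CLI flags still come from add_dataset_args,
	# add_model_args and add_training_args.
	parser = FineTuneParser(model_modules=["cvmodelz"])
	opts = parser.parse_args()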

+ 7 - 2
cvfinetune/parser/model_args.py

@@ -1,4 +1,5 @@
 import abc
+import typing as T
 
 from chainer_addons.links import PoolingType
 from chainer_addons.models import PrepareType
@@ -8,9 +9,13 @@ from cvfinetune.parser.utils import parser_extender
 from cvmodelz.models import ModelFactory
 
 @parser_extender
-def add_model_args(parser: BaseParser) -> None:
+def add_model_args(parser: BaseParser, *,
+	model_modules: T.Optional[T.List[str]] = None) -> None:
 
-	choices = ModelFactory.get_models(["chainercv2", "cvmodelz"])
+	if model_modules is None:
+		model_modules = ["chainercv2", "cvmodelz"]
+
+	choices = ModelFactory.get_models(model_modules)
 	_args = [
 		Arg("--model_type", "-mt",
 			required=True,
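
add_model_args keeps its old behaviour when the new keyword is omitted: the None default is replaced inside the function by the previously hard-coded list. The keyword-only, None-defaulted signature avoids sharing one mutable default list between calls; a small self-contained sketch of the idiom (the function name here is illustrative, not part of the library):

	import typing as T

	def pick_modules(*, model_modules: T.Optional[T.List[str]] = None) -> T.List[str]:
		# A list literal as default would be shared across calls;
		# None plus an in-body fallback avoids that.
		if model_modules is None:
			model_modules = ["chainercv2", "cvmodelz"]
		return model_modules

	print(pick_modules())                            # ['chainercv2', 'cvmodelz']
	print(pick_modules(model_modules=["cvmodelz"]))  # ['cvmodelz']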

+ 2 - 2
cvfinetune/parser/utils.py

@@ -23,11 +23,11 @@ def get_info_file():
 def parser_extender(extender):
 
 	@wraps(extender)
-	def inner(parser):
+	def inner(parser, *args, **kwargs):
 		assert isinstance(parser, BaseParser), \
 			"Parser should be an BaseParser instance!"
 
-		extender(parser)
+		extender(parser, *args, **kwargs)
 
 		return parser
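
Passing *args and **kwargs through inner is what lets extenders such as add_model_args accept extra options while still being wrapped by parser_extender. A self-contained sketch of the pass-through pattern, with argparse standing in for the library's BaseParser check (purely illustrative):

	import argparse
	from functools import wraps

	def parser_extender(extender):
		@wraps(extender)
		def inner(parser, *args, **kwargs):
			# Stand-in for the BaseParser assertion in cvfinetune.
			assert isinstance(parser, argparse.ArgumentParser)
			# Forward everything, e.g. add_model_args(parser, model_modules=[...]).
			extender(parser, *args, **kwargs)
			return parser
		return inner

	@parser_extender
	def add_demo_args(parser, *, flag: str = "--demo"):
		parser.add_argument(flag, action="store_true")

	p = add_demo_args(argparse.ArgumentParser(), flag="--verbose")
	print(p.parse_args(["--verbose"]))  # Namespace(verbose=True)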
 

+ 7 - 8
cvfinetune/training/trainer/alpha_pooling.py

@@ -1,20 +1,19 @@
-from .sacred import SacredTrainer
-from .base import _is_adam, default_intervals
-
 from chainer.training import extensions
 
+from cvfinetune.training.trainer import base
+
 def observe_alpha(trainer):
 	model = trainer.updater.get_optimizer("main").target.model
 	return float(model.pool.alpha.array)
 
-class AlphaPoolingTrainer(SacredTrainer):
+class AlphaPoolingTrainer(base.Trainer):
 
 	@property
 	def model(self):
 		return self.updater.get_optimizer("main").target.model
 
 	def __init__(self, opts, updater, *args, **kwargs):
-		super(AlphaPoolingTrainer, self).__init__(opts=opts, updater=updater, *args, **kwargs)
+		super().__init__(opts=opts, updater=updater, *args, **kwargs)
 		### Alternating training of CNN and FC layers (only for alpha-pooling) ###
 		if opts.switch_epochs:
 			self.extend(SwitchTrainables(
@@ -23,9 +22,9 @@ class AlphaPoolingTrainer(SacredTrainer):
 				pooling=self.model.pool))
 
 	def reportables(self, opts):
-		print_values, plot_values = super(AlphaPoolingTrainer, self).reportables()
+		print_values, plot_values = super().reportables()
 		alpha_update_rule = self.model.pool.alpha.update_rule
-		if _is_adam(opts):
+		if base._is_adam(opts):
 			# in case of Adam optimizer
 			alpha_update_rule.hyperparam.alpha *= opts.kappa
 		else:
@@ -33,7 +32,7 @@ class AlphaPoolingTrainer(SacredTrainer):
 
 		self.extend(
 			extensions.observe_value("alpha", observe_alpha),
-			trigger=default_intervals.print)
+			trigger=base.default_intervals.print)
 
 		print_values.append("alpha")
 		plot_values["alpha"]= ["alpha"]