
cleaned the trainer class

Dimitri Korsch, 4 years ago
parent commit: b04e852052
2 changed files with 33 additions and 46 deletions
  1. +0 -1    cvfinetune/finetuner/base.py
  2. +33 -45  cvfinetune/training/trainer/base.py

+ 0 - 1
cvfinetune/finetuner/base.py

@@ -312,7 +312,6 @@ class _TrainerMixin(abc.ABC):
 			opts=opts,
 			updater=self.updater,
 			evaluator=self.evaluator,
-			weights=self.weights,
 			*args, **kwargs
 		)
 

+ 33 - 45
cvfinetune/training/trainer/base.py

@@ -29,45 +29,30 @@ def _is_adam(opts):
 	return opts.optimizer == OptimizerType.ADAM.name.lower()
 
 class Trainer(T):
-	_default_base_model = "model"
 
-	def __init__(self, opts, updater, evaluator=None, weights=None, intervals=default_intervals, no_observe=False):
-
-		self._only_eval = opts.only_eval
-		self.offset = 0
-
-		if weights is None or weights == "auto":
-			self.base_model = self._default_base_model
-		else:
-			self.base_model, _, _ = basename(weights).rpartition(".")
-
-		optimizer = updater.get_optimizer("main")
-		# adam has some specific attributes, so we need to check this
-		is_adam = _is_adam(opts)
-		clf = optimizer.target
-		model = clf.model
-
-		if no_observe:
-			outdir = opts.output
-		else:
-			outdir = self.output_directory(opts)
-			logging.info("Training outputs are saved under \"{}\"".format(outdir))
+	def __init__(self, opts, updater,
+		evaluator: extensions.Evaluator = None,
+		intervals: attr_dict = default_intervals,
+		no_observe: bool = False):
 
 		super(Trainer, self).__init__(
 			updater=updater,
 			stop_trigger=(opts.epochs, 'epoch'),
-			out=outdir
+			out=opts.output
 		)
+		logging.info("Training outputs are saved under \"{}\"".format(self.out))
+
+		self._only_eval = opts.only_eval
+		self.offset = 0
 
 		self.setup_evaluator(evaluator, intervals.eval)
 
-		self.setup_warm_up(model,
-			epochs=opts.warm_up,
+		self.setup_warm_up(epochs=opts.warm_up,
 			after_warm_up_lr=opts.learning_rate,
 			warm_up_lr=opts.learning_rate
 		)
 
-		self.setup_lr_schedule(optimizer,
+		self.setup_lr_schedule(
 			lr=opts.learning_rate,
 			lr_target=opts.lr_target,
 			lr_shift_trigger=(opts.lr_shift, "epochs"),
@@ -86,11 +71,12 @@ class Trainer(T):
 		self.extend(extensions.LogReport(trigger=intervals.log))
 
 		### Snapshotting ###
-		self.setup_snapshots(opts, clf.model, intervals.snapshot)
+		self.setup_snapshots(opts, self.model, intervals.snapshot)
 
 		### Reports and Plots ###
 		print_values, plot_values = self.reportables(opts)
-		self.extend(extensions.PrintReport(print_values), trigger=intervals.print)
+		self.extend(extensions.PrintReport(print_values),
+			trigger=intervals.print)
 		for name, values in plot_values.items():
 			ext = extensions.PlotReport(values, 'epoch', file_name='{}.png'.format(name))
 			self.extend(ext)
@@ -99,7 +85,19 @@ class Trainer(T):
 		if not opts.no_progress:
 			self.extend(extensions.ProgressBar(update_interval=1))
 
-	def setup_lr_schedule(self, optimizer: chainer.Optimizer,
+	@property
+	def optimizer(self):
+		return self.updater.get_optimizer("main")
+
+	@property
+	def clf(self):
+		return self.optimizer.target
+
+	@property
+	def model(self):
+		return self.clf.model
+
+	def setup_lr_schedule(self,
 		lr: float,
 		lr_target: float,
 		lr_shift_trigger: Tuple[int, str],
@@ -107,8 +105,8 @@ class Trainer(T):
 
 		epochs: int,
 		cosine_schedule: int,
-		attr: str,
-	):
+		attr: str):
+
 		if cosine_schedule is not None and cosine_schedule > 0:
 			lr_shift_ext = CosineAnnealingLearningRate(
 				attr=attr,
@@ -124,12 +122,12 @@ class Trainer(T):
 			lr_shift_trigger = None
 
 		else:
-			lr_shift_ext = lr_shift(optimizer,
+			lr_shift_ext = lr_shift(self.optimizer,
 				init=lr, rate=lr_decrease_rate, target=lr_target)
 
 		self.extend(lr_shift_ext, trigger=lr_shift_trigger)
 
-	def setup_warm_up(self, model, epochs: int, after_warm_up_lr: float, warm_up_lr: float):
+	def setup_warm_up(self, epochs: int, after_warm_up_lr: float, warm_up_lr: float):
 
 		if epochs == 0:
 			return
@@ -138,7 +136,7 @@ class Trainer(T):
 
 		logging.info(f"Warm-up of {epochs} epochs enabled!")
 
-		self.extend(WarmUp(epochs, model,
+		self.extend(WarmUp(epochs, self.model,
 			initial_lr=after_warm_up_lr,
 			warm_up_lr=warm_up_lr
 			)
@@ -146,7 +144,7 @@ class Trainer(T):
 
 
 	def setup_evaluator(self,
-		evaluator: chainer.training.extensions.Evaluator,
+		evaluator: extensions.Evaluator,
 		trigger: Tuple[int, str]):
 
 		self.evaluator = evaluator
@@ -193,16 +191,6 @@ class Trainer(T):
 		return print_values, plot_values
 
 
-	def output_directory(self, opts):
-
-		result = opts.output
-
-		if self.base_model != self._default_base_model:
-			result = join(result, self.base_model)
-
-		result = join(result, datetime.now().strftime("%Y-%m-%d-%H.%M.%S.%f"))
-		return result
-
 	def run(self, init_eval=True):
 		if init_eval:
 			logging.info("Evaluating initial model ...")