
refactored base trainer class: moved some init code to separate methods

Dimitri Korsch, 4 years ago
commit e545dfeddd
1 changed file with 71 additions and 33 deletions

cvfinetune/training/trainer/base.py  +71 −33
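The refactor keeps the two learning-rate schedules from before: an optional cosine annealing schedule and a step-wise shift towards a target value, both offset by an optional warm-up phase. Below is a minimal sketch (not part of the commit) of how the step-wise branch could behave, assuming lr_shift multiplies the learning rate by rate every shift epochs and clips it at target, and that the warm-up simply delays the schedule; the function name and all values are hypothetical:

def lr_at_epoch(epoch, init=1e-3, rate=0.1, target=1e-6, shift=20, warm_up=2):
    # during warm-up the WarmUp extension controls the learning rate
    if epoch < warm_up:
        return init
    # one decay step every `shift` epochs after the warm-up offset
    steps = (epoch - warm_up) // shift
    return max(init * rate ** steps, target)

for epoch in (0, 2, 22, 42, 62):
    print(epoch, lr_at_epoch(epoch))
    # prints 1e-3, 1e-3, 1e-4, 1e-5, 1e-6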

@@ -1,6 +1,7 @@
 import logging
 from os.path import join, basename
 from datetime import datetime
+from typing import Tuple
 
 import chainer
 from chainer.training import extensions, Trainer as T
@@ -33,6 +34,8 @@ class Trainer(T):
 	def __init__(self, opts, updater, evaluator=None, weights=None, intervals=default_intervals, no_observe=False):
 
 		self._only_eval = opts.only_eval
+		self.offset = 0
+
 		if weights is None or weights == "auto":
 			self.base_model = self._default_base_model
 		else:
@@ -56,42 +59,25 @@ class Trainer(T):
 			out=outdir
 		)
 
-		### Evaluator ###
-		self.evaluator = evaluator
-		if evaluator is not None:
-			self.extend(evaluator, trigger=intervals.eval)
-
-		### Warm up ###
-		lr_offset = 0
-		if opts.warm_up:
-			assert opts.warm_up > 0, "Warm-up argument must be positive!"
-			lr_offset = opts.warm_up
+		self.setup_evaluator(evaluator, intervals.eval)
 
-			warm_up_lr = opts.learning_rate
-			logging.info("Warm-up of {} epochs enabled!".format(opts.warm_up))
-			self.extend(WarmUp(
-				opts.warm_up, model,
-				opts.learning_rate, warm_up_lr))
+		self.setup_warm_up(model,
+			epochs=opts.warm_up,
+			after_warm_up_lr=opts.learning_rate,
+			warm_up_lr=opts.learning_rate
+		)
 
+		self.setup_lr_schedule(optimizer,
+			lr=opts.learning_rate,
+			lr_target=opts.lr_target,
+			lr_shift_trigger=(opts.lr_shift, "epoch"),
+			lr_decrease_rate=opts.lr_decrease_rate,
 
-		### LR shift ###
-		if opts.cosine_schedule is not None and opts.cosine_schedule > 0:
-			lr_shift_ext = CosineAnnealingLearningRate(
-				attr="alpha" if is_adam else "lr",
-				lr=opts.learning_rate,
-				target=opts.lr_target,
-				epochs=opts.epochs,
-				offset=lr_offset,
-				stages=opts.cosine_schedule
-			)
-			new_epochs = lr_shift_ext._epochs
-			self.stop_trigger = trigger_module.get_trigger((new_epochs, "epoch"))
-			self.extend(lr_shift_ext)
-		else:
-			lr_shift_ext = lr_shift(optimizer,
-				init=opts.learning_rate,
-				rate=opts.lr_decrease_rate, target=opts.lr_target)
-			self.extend(lr_shift_ext, trigger=(opts.lr_shift, 'epoch'))
+			# needed for cosine annealing
+			epochs=opts.epochs,
+			cosine_schedule=opts.cosine_schedule,
+			attr="alpha" if _is_adam(opts) else "lr",
+		)
 
 		### Code below is only for "main" Trainers ###
 		if no_observe: return
@@ -113,6 +99,58 @@ class Trainer(T):
 		if not opts.no_progress:
 			self.extend(extensions.ProgressBar(update_interval=1))
 
+	def setup_lr_schedule(self, optimizer: chainer.Optimizer,
+		lr: float,
+		lr_target: float,
+		lr_shift_trigger: Tuple[int, str],
+		lr_decrease_rate: float,
+
+		epochs: int,
+		cosine_schedule: int,
+		attr: str,
+	):
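+		# Two mutually exclusive schedules: cosine annealing (which may
+		# extend the training length) or a step-wise shift towards lr_target.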
+		if cosine_schedule is not None and cosine_schedule > 0:
+			lr_shift_ext = CosineAnnealingLearningRate(
+				attr=attr,
+				lr=lr,
+				target=lr_target,
+				epochs=epochs,
+				offset=self.offset,
+				stages=cosine_schedule
+			)
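+			# cosine annealing may adjust the total epoch count to fit its
+			# stages; adopt it as the trainer's new stop trigger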
+			new_epochs = lr_shift_ext._epochs
+			self.stop_trigger = trigger_module.get_trigger((new_epochs, "epoch"))
+			logging.info(f"Changed number of training epochs from {epochs} to {new_epochs}")
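+			# the cosine extension registers with its own default trigger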
+			lr_shift_trigger = None
+
+		else:
+			lr_shift_ext = lr_shift(optimizer,
+				init=lr, rate=lr_decrease_rate, target=lr_target)
+
+		self.extend(lr_shift_ext, trigger=lr_shift_trigger)
+
+	def setup_warm_up(self, model, epochs: int, after_warm_up_lr: float, warm_up_lr: float):
+
+		if not epochs:  # warm-up disabled (None or 0)
+			return
+		assert epochs > 0, "Warm-up argument must be positive!"
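+		# remember the warm-up length; the LR schedule starts after this offset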
+		self.offset = epochs
+
+		logging.info(f"Warm-up of {epochs} epochs enabled!")
+
+		self.extend(WarmUp(epochs, model,
+			initial_lr=after_warm_up_lr,
+			warm_up_lr=warm_up_lr
+			)
+		)
+
+	def setup_evaluator(self, evaluator: chainer.training.Evaluator, trigger):
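+		# keep a handle on the evaluator and register it at the given interval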
+		self.evaluator = evaluator
+		if evaluator is None:
+			return
+		self.extend(evaluator, trigger=trigger)
+
 	def setup_snapshots(self, opts, obj, trigger):
 
 		if opts.no_snapshot: