|
|
@@ -29,45 +29,30 @@ def _is_adam(opts):
|
|
|
return opts.optimizer == OptimizerType.ADAM.name.lower()
|
|
|
|
|
|
class Trainer(T):
|
|
|
- _default_base_model = "model"
|
|
|
|
|
|
- def __init__(self, opts, updater, evaluator=None, weights=None, intervals=default_intervals, no_observe=False):
|
|
|
-
|
|
|
- self._only_eval = opts.only_eval
|
|
|
- self.offset = 0
|
|
|
-
|
|
|
- if weights is None or weights == "auto":
|
|
|
- self.base_model = self._default_base_model
|
|
|
- else:
|
|
|
- self.base_model, _, _ = basename(weights).rpartition(".")
|
|
|
-
|
|
|
- optimizer = updater.get_optimizer("main")
|
|
|
- # adam has some specific attributes, so we need to check this
|
|
|
- is_adam = _is_adam(opts)
|
|
|
- clf = optimizer.target
|
|
|
- model = clf.model
|
|
|
-
|
|
|
- if no_observe:
|
|
|
- outdir = opts.output
|
|
|
- else:
|
|
|
- outdir = self.output_directory(opts)
|
|
|
- logging.info("Training outputs are saved under \"{}\"".format(outdir))
|
|
|
+ def __init__(self, opts, updater,
|
|
|
+ evaluator: extensions.Evaluator = None,
|
|
|
+ intervals: attr_dict = default_intervals,
|
|
|
+ no_observe: bool = False):
|
|
|
|
|
|
super(Trainer, self).__init__(
|
|
|
updater=updater,
|
|
|
stop_trigger=(opts.epochs, 'epoch'),
|
|
|
- out=outdir
|
|
|
+ out=opts.output
|
|
|
)
|
|
|
+ logging.info("Training outputs are saved under \"{}\"".format(self.out))
|
|
|
+
|
|
|
+ self._only_eval = opts.only_eval
|
|
|
+ self.offset = 0
|
|
|
|
|
|
self.setup_evaluator(evaluator, intervals.eval)
|
|
|
|
|
|
- self.setup_warm_up(model,
|
|
|
- epochs=opts.warm_up,
|
|
|
+ self.setup_warm_up(epochs=opts.warm_up,
|
|
|
after_warm_up_lr=opts.learning_rate,
|
|
|
warm_up_lr=opts.learning_rate
|
|
|
)
|
|
|
|
|
|
- self.setup_lr_schedule(optimizer,
|
|
|
+ self.setup_lr_schedule(
|
|
|
lr=opts.learning_rate,
|
|
|
lr_target=opts.lr_target,
|
|
|
lr_shift_trigger=(opts.lr_shift, "epochs"),
|
|
|
@@ -86,11 +71,12 @@ class Trainer(T):
|
|
|
self.extend(extensions.LogReport(trigger=intervals.log))
|
|
|
|
|
|
### Snapshotting ###
|
|
|
- self.setup_snapshots(opts, clf.model, intervals.snapshot)
|
|
|
+ self.setup_snapshots(opts, self.model, intervals.snapshot)
|
|
|
|
|
|
### Reports and Plots ###
|
|
|
print_values, plot_values = self.reportables(opts)
|
|
|
- self.extend(extensions.PrintReport(print_values), trigger=intervals.print)
|
|
|
+ self.extend(extensions.PrintReport(print_values),
|
|
|
+ trigger=intervals.print)
|
|
|
for name, values in plot_values.items():
|
|
|
ext = extensions.PlotReport(values, 'epoch', file_name='{}.png'.format(name))
|
|
|
self.extend(ext)
|
|
|
@@ -99,7 +85,19 @@ class Trainer(T):
|
|
|
if not opts.no_progress:
|
|
|
self.extend(extensions.ProgressBar(update_interval=1))
|
|
|
|
|
|
- def setup_lr_schedule(self, optimizer: chainer.Optimizer,
|
|
|
+ @property
|
|
|
+ def optimizer(self):
|
|
|
+ return self.updater.get_optimizer("main")
|
|
|
+
|
|
|
+ @property
|
|
|
+ def clf(self):
|
|
|
+ return self.optimizer.target
|
|
|
+
|
|
|
+ @property
|
|
|
+ def model(self):
|
|
|
+ return self.clf.model
|
|
|
+
|
|
|
+ def setup_lr_schedule(self,
|
|
|
lr: float,
|
|
|
lr_target: float,
|
|
|
lr_shift_trigger: Tuple[int, str],
|
|
|
@@ -107,8 +105,8 @@ class Trainer(T):
|
|
|
|
|
|
epochs: int,
|
|
|
cosine_schedule: int,
|
|
|
- attr: str,
|
|
|
- ):
|
|
|
+ attr: str):
|
|
|
+
|
|
|
if cosine_schedule is not None and cosine_schedule > 0:
|
|
|
lr_shift_ext = CosineAnnealingLearningRate(
|
|
|
attr=attr,
|
|
|
@@ -124,12 +122,12 @@ class Trainer(T):
|
|
|
lr_shift_trigger = None
|
|
|
|
|
|
else:
|
|
|
- lr_shift_ext = lr_shift(optimizer,
|
|
|
+ lr_shift_ext = lr_shift(self.optimizer,
|
|
|
init=lr, rate=lr_decrease_rate, target=lr_target)
|
|
|
|
|
|
self.extend(lr_shift_ext, trigger=lr_shift_trigger)
|
|
|
|
|
|
- def setup_warm_up(self, model, epochs: int, after_warm_up_lr: float, warm_up_lr: float):
|
|
|
+ def setup_warm_up(self, epochs: int, after_warm_up_lr: float, warm_up_lr: float):
|
|
|
|
|
|
if epochs == 0:
|
|
|
return
|
|
|
@@ -138,7 +136,7 @@ class Trainer(T):
|
|
|
|
|
|
logging.info(f"Warm-up of {epochs} epochs enabled!")
|
|
|
|
|
|
- self.extend(WarmUp(epochs, model,
|
|
|
+ self.extend(WarmUp(epochs, self.model,
|
|
|
initial_lr=after_warm_up_lr,
|
|
|
warm_up_lr=warm_up_lr
|
|
|
)
|
|
|
@@ -146,7 +144,7 @@ class Trainer(T):
|
|
|
|
|
|
|
|
|
def setup_evaluator(self,
|
|
|
- evaluator: chainer.training.extensions.Evaluator,
|
|
|
+ evaluator: extensions.Evaluator,
|
|
|
trigger: Tuple[int, str]):
|
|
|
|
|
|
self.evaluator = evaluator
|
|
|
@@ -193,16 +191,6 @@ class Trainer(T):
|
|
|
return print_values, plot_values
|
|
|
|
|
|
|
|
|
- def output_directory(self, opts):
|
|
|
-
|
|
|
- result = opts.output
|
|
|
-
|
|
|
- if self.base_model != self._default_base_model:
|
|
|
- result = join(result, self.base_model)
|
|
|
-
|
|
|
- result = join(result, datetime.now().strftime("%Y-%m-%d-%H.%M.%S.%f"))
|
|
|
- return result
|
|
|
-
|
|
|
def run(self, init_eval=True):
|
|
|
if init_eval:
|
|
|
logging.info("Evaluating initial model ...")
|