@@ -1,6 +1,7 @@
 import logging
 from os.path import join, basename
 from datetime import datetime
+from typing import Tuple
 
 import chainer
 from chainer.training import extensions, Trainer as T
@@ -33,6 +34,8 @@ class Trainer(T):
 
 	def __init__(self, opts, updater, evaluator=None, weights=None, intervals=default_intervals, no_observe=False):
 		self._only_eval = opts.only_eval
+		self.offset = 0
+
 		if weights is None or weights == "auto":
 			self.base_model = self._default_base_model
 		else:
@@ -56,42 +59,25 @@ class Trainer(T):
 			out=outdir
 		)
 
-		### Evaluator ###
-		self.evaluator = evaluator
-		if evaluator is not None:
-			self.extend(evaluator, trigger=intervals.eval)
-
-		### Warm up ###
-		lr_offset = 0
-		if opts.warm_up:
-			assert opts.warm_up > 0, "Warm-up argument must be positive!"
-			lr_offset = opts.warm_up
+		self.setup_evaluator(evaluator, intervals.eval)
 
-			warm_up_lr = opts.learning_rate
-			logging.info("Warm-up of {} epochs enabled!".format(opts.warm_up))
-			self.extend(WarmUp(
-				opts.warm_up, model,
-				opts.learning_rate, warm_up_lr))
+		self.setup_warm_up(model,
+			epochs=opts.warm_up,
+			after_warm_up_lr=opts.learning_rate,
+			warm_up_lr=opts.learning_rate
+		)
 
+		self.setup_lr_schedule(optimizer,
+			lr=opts.learning_rate,
+			lr_target=opts.lr_target,
+			lr_shift_trigger=(opts.lr_shift, "epoch"),
+			lr_decrease_rate=opts.lr_decrease_rate,
 
-		### LR shift ###
-		if opts.cosine_schedule is not None and opts.cosine_schedule > 0:
-			lr_shift_ext = CosineAnnealingLearningRate(
-				attr="alpha" if is_adam else "lr",
-				lr=opts.learning_rate,
-				target=opts.lr_target,
-				epochs=opts.epochs,
-				offset=lr_offset,
-				stages=opts.cosine_schedule
-			)
-			new_epochs = lr_shift_ext._epochs
-			self.stop_trigger = trigger_module.get_trigger((new_epochs, "epoch"))
-			self.extend(lr_shift_ext)
-		else:
-			lr_shift_ext = lr_shift(optimizer,
-				init=opts.learning_rate,
-				rate=opts.lr_decrease_rate, target=opts.lr_target)
-			self.extend(lr_shift_ext, trigger=(opts.lr_shift, 'epoch'))
+			# needed for cosine annealing
+			epochs=opts.epochs,
+			cosine_schedule=opts.cosine_schedule,
+			attr="alpha" if _is_adam(opts) else "lr",
+		)
 
 		### Code below is only for "main" Trainers ###
 		if no_observe: return
@@ -113,6 +99,61 @@ class Trainer(T):
 		if not opts.no_progress:
 			self.extend(extensions.ProgressBar(update_interval=1))
 
+	def setup_lr_schedule(self, optimizer: chainer.Optimizer,
+		lr: float,
+		lr_target: float,
+		lr_shift_trigger: Tuple[int, str],
+		lr_decrease_rate: float,
+
+		epochs: int,
+		cosine_schedule: int,
+		attr: str,
+		):
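+		# Use cosine annealing if a cosine_schedule is given, otherwise a step-wise LR shift on lr_shift_trigger.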
+		if cosine_schedule is not None and cosine_schedule > 0:
+			lr_shift_ext = CosineAnnealingLearningRate(
+				attr=attr,
+				lr=lr,
+				target=lr_target,
+				epochs=epochs,
+				offset=self.offset,
+				stages=cosine_schedule
+			)
+			new_epochs = lr_shift_ext._epochs
+			self.stop_trigger = trigger_module.get_trigger((new_epochs, "epoch"))
+			logging.info(f"Changed number of training epochs from {epochs} to {new_epochs}")
+			lr_shift_trigger = None
+
+		else:
+			lr_shift_ext = lr_shift(optimizer,
+				init=lr, rate=lr_decrease_rate, target=lr_target)
+
+		self.extend(lr_shift_ext, trigger=lr_shift_trigger)
+
+	def setup_warm_up(self, model, epochs: int, after_warm_up_lr: float, warm_up_lr: float):
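+		# Enable a warm-up phase of the given number of epochs and store its length in self.offset for the LR schedule.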
+
+		if not epochs:
+			return
+		assert epochs > 0, "Warm-up argument must be positive!"
+		self.offset = epochs
+
+		logging.info(f"Warm-up of {epochs} epochs enabled!")
+
+		self.extend(WarmUp(epochs, model,
+			initial_lr=after_warm_up_lr,
+			warm_up_lr=warm_up_lr
+			)
+		)
+
+
+	def setup_evaluator(self, evaluator: chainer.training.Evaluator, trigger):
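+		# Register the evaluator extension with the given trigger; does nothing if no evaluator is provided.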
+		self.evaluator = evaluator
+		if evaluator is None:
+			return
+		self.extend(evaluator, trigger=trigger)
+
 	def setup_snapshots(self, opts, obj, trigger):
 
 		if opts.no_snapshot: