@@ -0,0 +1,236 @@
+import logging
+from os.path import join, basename
+from datetime import datetime
+
+import chainer
+from chainer.training import extensions, Trainer as T
+from chainer_addons.training import lr_shift
+from chainer_addons.training.optimizer import OptimizerType
+from chainer_addons.training.extensions import SacredReport
+from chainer_addons.training.extensions.learning_rate import CosineAnnealingLearningRate
+from chainer_addons.training.extensions import AlternateTrainable, SwitchTrainables, WarmUp
+
+from cvdatasets.utils import attr_dict
+
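+# `opts` is assumed to be an argparse-like namespace; the attributes accessed
+# in this module are: only_eval, optimizer, epochs, output, warm_up,
+# learning_rate, cosine_schedule, lr_target, lr_decrease_rate, lr_shift,
+# no_progress, no_snapshot, switch_epochs and kappa.
+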
+def debug_hook(trainer):
+	pass
+	# print(trainer.updater.get_optimizer("main").target.model.fc6.W.data.mean(), file=open("debug.out", "a"))
+
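+# default triggers (interval, unit) for console output, logging,
+# evaluation and model snapshots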
+default_intervals = attr_dict(
+	print=(1, 'epoch'),
+	log=(1, 'epoch'),
+	eval=(1, 'epoch'),
+	snapshot=(10, 'epoch'),
+)
+
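+# reads the current alpha value of the model's alpha-pooling layer;
+# used below with extensions.observe_value to report it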
+def observe_alpha(trainer):
+	model = trainer.updater.get_optimizer("main").target.model
+	return float(model.pool.alpha.array)
+
+def _is_adam(opts):
+	return opts.optimizer == OptimizerType.ADAM.name.lower()
+
+class Trainer(T):
+	_default_base_model = "model"
+
+	def __init__(self, opts, updater, evaluator=None, weights=None, intervals=default_intervals, no_observe=False):
+
+		self._only_eval = opts.only_eval
+		if weights is None or weights == "auto":
+			self.base_model = self._default_base_model
+		else:
+			self.base_model, _, _ = basename(weights).rpartition(".")
+
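+		# the optimizer's target is the classifier; its `.model` attribute
+		# holds the backbone CNN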
+		optimizer = updater.get_optimizer("main")
+		# Adam names its learning-rate attribute differently, so we need to check for it
+		is_adam = _is_adam(opts)
+		clf = optimizer.target
+		model = clf.model
+
+		outdir = self.output_directory(opts)
+		logging.info("Training outputs are saved under \"{}\"".format(outdir))
+
+		super(Trainer, self).__init__(
+			updater=updater,
+			stop_trigger=(opts.epochs, 'epoch'),
+			out=outdir
+		)
+
+		### Evaluator ###
+		if evaluator is not None:
+			self.extend(evaluator, trigger=intervals.eval)
+
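+		# NOTE: warm-up currently runs with the same rate as the main training
+		# (warm_up_lr == opts.learning_rate); lr_offset delays the cosine
+		# schedule below by the number of warm-up epochs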
+		### Warm up ###
+		lr_offset = 0
+		if opts.warm_up:
+			assert opts.warm_up > 0, "Warm-up argument must be positive!"
+			lr_offset = opts.warm_up
+
+			warm_up_lr = opts.learning_rate
+			logging.info("Warm-up of {} epochs enabled!".format(opts.warm_up))
+			self.extend(WarmUp(
+				opts.warm_up, model,
+				opts.learning_rate, warm_up_lr))
+
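+		# Chainer's Adam keeps its learning rate in the `alpha` hyperparameter;
+		# all other optimizers use `lr`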
+		### LR shift ###
+		if opts.cosine_schedule:
+			lr_shift_ext = CosineAnnealingLearningRate(
+				attr="alpha" if is_adam else "lr",
+				lr=opts.learning_rate,
+				target=opts.lr_target,
+				epochs=opts.epochs,
+				offset=lr_offset
+			)
+			self.extend(lr_shift_ext)
+		else:
+			lr_shift_ext = lr_shift(optimizer,
+				init=opts.learning_rate,
+				rate=opts.lr_decrease_rate, target=opts.lr_target)
+			self.extend(lr_shift_ext, trigger=(opts.lr_shift, 'epoch'))
+
+		### Code below is only for "main" Trainers ###
+		if no_observe:
+			return
+
+		self.extend(extensions.observe_lr(), trigger=intervals.log)
+		self.extend(extensions.LogReport(trigger=intervals.log))
+
+		### Snapshotting ###
+		self.setup_snapshots(opts, clf.model, intervals.snapshot)
+
+		### Reports and Plots ###
+		print_values, plot_values = self.reportables(opts, model, evaluator)
+		self.extend(extensions.PrintReport(print_values), trigger=intervals.print)
+		for name, values in plot_values.items():
+			ext = extensions.PlotReport(values, 'epoch', file_name='{}.png'.format(name))
+			self.extend(ext)
+
+		### Progress bar ###
+		if not opts.no_progress:
+			self.extend(extensions.ProgressBar(update_interval=1))
+
+	def setup_snapshots(self, opts, obj, trigger):
+
+		if opts.no_snapshot:
+			logging.warning("Model snapshots are disabled!")
+		else:
+			dump_fmt = "ft_model_epoch{0.updater.epoch:03d}.npz"
+			self.extend(extensions.snapshot_object(obj, dump_fmt), trigger=trigger)
+			logging.info("Snapshot format: \"{}\"".format(dump_fmt))
+
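+	# entries reported by the evaluator are prefixed with its default_name;
+	# eval_name() below builds exactly this prefix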
+	def reportables(self, opts, model, evaluator):
+		eval_name = lambda name: f"{evaluator.default_name}/{name}"
+
+		print_values = [
+			"elapsed_time",
+			"epoch",
+			# "lr",
+
+			"main/accuracy", eval_name("main/accuracy"),
+			"main/loss", eval_name("main/loss"),
+		]
+
+		plot_values = {
+			"accuracy": [
+				"main/accuracy", eval_name("main/accuracy"),
+			],
+			"loss": [
+				"main/loss", eval_name("main/loss"),
+			],
+		}
+
+		# if opts.triplet_loss:
+		# 	print_values.extend(["main/t_loss", eval_name("main/t_loss")])
+		# 	plot_values.update({
+		# 		"t_loss": [
+		# 			"main/t_loss", eval_name("main/t_loss"),
+		# 		]
+		# 	})
+
+		# if opts.use_parts:
+		# 	print_values.extend(["main/logL", eval_name("main/logL")])
+		# 	plot_values.update({
+		# 		"logL": [
+		# 			"main/logL", eval_name("main/logL"),
+		# 		]
+		# 	})
+
+		# if not opts.no_global:
+		# 	print_values.extend([
+		# 		"main/glob_accu", eval_name("main/glob_accu"),
+		# 		# "main/glob_loss", eval_name("main/glob_loss"),
+
+		# 		"main/part_accu", eval_name("main/part_accu"),
+		# 		# "main/part_loss", eval_name("main/part_loss"),
+		# 	])
+
+		# 	plot_values["accuracy"].extend([
+		# 		"main/part_accu", eval_name("main/part_accu"),
+		# 		"main/glob_accu", eval_name("main/glob_accu"),
+		# 	])
+
+		# 	plot_values["loss"].extend([
+		# 		"main/part_loss", eval_name("main/part_loss"),
+		# 		"main/glob_loss", eval_name("main/glob_loss"),
+		# 	])
+
+		return print_values, plot_values
+
+	def output_directory(self, opts):
+
+		result = opts.output
+
+		if self.base_model != self._default_base_model:
+			result = join(result, self.base_model)
+
+		result = join(result, datetime.now().strftime("%Y-%m-%d-%H.%M.%S"))
+		return result
+
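+	# NOTE: the initial evaluation assumes the evaluator extension was
+	# registered under the name "val" (its default_name), since the report
+	# keys "val/main/accuracy" and "val/main/loss" are hard-coded below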
+	def run(self, init_eval=True):
+		if init_eval:
+			logging.info("Evaluating initial model ...")
+			evaluator = self.get_extension("val")
+			init_perf = evaluator(self)
+			logging.info("Initial accuracy: {val/main/accuracy:.3%} initial loss: {val/main/loss:.3f}".format(
+				**{key: float(value) for key, value in init_perf.items()}
+			))
+
+		if self._only_eval:
+			return
+		return super(Trainer, self).run()
+
+class SacredTrainer(Trainer):
+	def __init__(self, ex, *args, intervals=default_intervals, **kwargs):
+		super(SacredTrainer, self).__init__(*args, intervals=intervals, **kwargs)
+		self.extend(SacredReport(ex=ex, trigger=intervals.log))
+
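+# trainer for models with an alpha-pooling layer; the model is expected to
+# expose it as `model.pool` with a learnable `alpha` parameter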
+class AlphaPoolingTrainer(SacredTrainer):
+
+	def __init__(self, opts, updater, *args, **kwargs):
+		super(AlphaPoolingTrainer, self).__init__(*args, opts=opts, updater=updater, **kwargs)
+		model = updater.get_optimizer("main").target.model
+		### Alternating training of CNN and FC layers (only for alpha-pooling) ###
+		if opts.switch_epochs:
+			self.extend(SwitchTrainables(
+				opts.switch_epochs,
+				model=model,
+				pooling=model.pool))
+
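+	# kappa scales the learning rate of the alpha parameter relative to the
+	# rest of the model (for Adam the rate is stored in `alpha`, see above)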
+	def reportables(self, opts, model, evaluator):
+		print_values, plot_values = super(AlphaPoolingTrainer, self).reportables(opts, model, evaluator)
+		alpha_update_rule = model.pool.alpha.update_rule
+		if _is_adam(opts):
+			# Adam keeps its learning rate in the "alpha" hyperparameter
+			alpha_update_rule.hyperparam.alpha *= opts.kappa
+		else:
+			alpha_update_rule.hyperparam.lr *= opts.kappa
+
+		# `intervals` is not passed to reportables(), so the defaults are used
+		self.extend(extensions.observe_value("alpha", observe_alpha), trigger=default_intervals.print)
+		print_values.append("alpha")
+		plot_values["alpha"] = ["alpha"]
+
+		return print_values, plot_values
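+
+# Minimal usage sketch (hypothetical setup code; assumes an updater, an
+# evaluator and argparse-parsed opts were created elsewhere):
+#
+#	trainer = Trainer(opts, updater, evaluator=evaluator, weights="auto")
+#	trainer.run()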