base.py

import logging

from datetime import datetime
from os.path import basename
from os.path import join
from typing import Tuple

import chainer
from chainer.training import Trainer as T
from chainer.training import extensions
from chainer.training import trigger as trigger_module

from chainer_addons.training import lr_shift
from chainer_addons.training.extensions import AlternateTrainable
from chainer_addons.training.extensions import SwitchTrainables
from chainer_addons.training.extensions import WarmUp
from chainer_addons.training.extensions.learning_rate import CosineAnnealingLearningRate
from chainer_addons.training.optimizer import OptimizerType

from cvdatasets.utils import attr_dict

default_intervals = attr_dict(
    print=(1, 'epoch'),
    log=(1, 'epoch'),
    eval=(1, 'epoch'),
    snapshot=(10, 'epoch'),
)
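
# Note (added for clarity): `attr_dict` from cvdatasets behaves like a dict
# with attribute access, e.g. `default_intervals.snapshot == (10, 'epoch')`;
# that is how `intervals.eval`, `intervals.log` etc. are read below.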

def debug_hook(trainer):
    pass
    # print(trainer.updater.get_optimizer("main").target.model.fc6.W.data.mean(), file=open("debug.out", "a"))

def _is_adam(opts):
    # Adam exposes its learning rate under the hyperparameter name "alpha"
    # instead of "lr"; setup_lr_schedule needs to know which one to adjust.
    return opts.optimizer == OptimizerType.ADAM.name.lower()

class Trainer(T):

    def __init__(self, opts,
                 updater,
                 evaluator: extensions.Evaluator = None,
                 intervals: attr_dict = default_intervals,
                 no_observe: bool = False,
                 **kwargs):

        super(Trainer, self).__init__(
            updater=updater,
            stop_trigger=(opts.epochs, 'epoch'),
            out=opts.output,
            **kwargs
        )
        logging.info("Training outputs are saved under \"{}\"".format(self.out))

        self._only_eval = opts.only_eval
        self.offset = 0

        self.setup_evaluator(evaluator, intervals.eval)

        self.setup_warm_up(epochs=opts.warm_up,
            after_warm_up_lr=opts.learning_rate,
            warm_up_lr=opts.learning_rate
        )

        self.setup_lr_schedule(
            lr=opts.learning_rate,
            lr_target=opts.lr_target,
            lr_shift_trigger=(opts.lr_shift, "epoch"),
            lr_decrease_rate=opts.lr_decrease_rate,
            # needed for cosine annealing
            epochs=opts.epochs,
            cosine_schedule=opts.cosine_schedule,
            attr="alpha" if _is_adam(opts) else "lr",
        )

        ### Code below is only for "main" Trainers ###
        if no_observe:
            return

        ### Snapshotting ###
        self.setup_snapshots(opts, self.model, intervals.snapshot)
        self.setup_reporter(opts, intervals.log, intervals.print)
        self.setup_progress_bar(opts)

    def setup_reporter(self, opts, log_trigger, print_trigger):
        self.extend(extensions.observe_lr(), trigger=log_trigger)
        self.extend(extensions.LogReport(trigger=log_trigger))

        ### Reports and Plots ###
        print_values, plot_values = self.reportables(opts)
        self.extend(extensions.PrintReport(print_values), trigger=print_trigger)
        for name, values in plot_values.items():
            ext = extensions.PlotReport(values, 'epoch', file_name='{}.png'.format(name))
            self.extend(ext)

    def setup_progress_bar(self, opts):
        if not opts.no_progress:
            self.extend(extensions.ProgressBar(update_interval=1))

        elif self.evaluator is not None:
            self.evaluator._progress_bar = False

    @property
    def optimizer(self):
        return self.updater.get_optimizer("main")

    @property
    def clf(self):
        return self.optimizer.target

    @property
    def model(self):
        return self.clf.model

    def setup_lr_schedule(self,
                          lr: float,
                          lr_target: float,
                          lr_shift_trigger: Tuple[int, str],
                          lr_decrease_rate: float,
                          epochs: int,
                          cosine_schedule: int,
                          attr: str):

        if cosine_schedule is not None and cosine_schedule > 0:
            lr_shift_ext = CosineAnnealingLearningRate(
                attr=attr,
                lr=lr,
                target=lr_target,
                epochs=epochs,
                offset=self.offset,
                stages=cosine_schedule
            )
            # the annealing stages may stretch the training to a different
            # number of epochs, so the stop trigger has to be adjusted
            new_epochs = lr_shift_ext._epochs
            self.stop_trigger = trigger_module.get_trigger((new_epochs, "epoch"))
            logging.info(f"Changed number of training epochs from {epochs} to {new_epochs}")
            lr_shift_trigger = None

        else:
            lr_shift_ext = lr_shift(self.optimizer,
                init=lr, rate=lr_decrease_rate, target=lr_target)

        self.extend(lr_shift_ext, trigger=lr_shift_trigger)

    def setup_warm_up(self, epochs: int, after_warm_up_lr: float, warm_up_lr: float):
        if epochs is None or epochs == 0:
            return

        assert epochs > 0, "Warm-up argument must be positive!"
        self.offset = epochs
        logging.info(f"Warm-up of {epochs} epochs enabled!")

        self.extend(WarmUp(epochs, self.model,
            initial_lr=after_warm_up_lr,
            warm_up_lr=warm_up_lr
        ))

    def setup_evaluator(self,
                        evaluator: extensions.Evaluator,
                        trigger: Tuple[int, str]):
        self.evaluator = evaluator
        if evaluator is None:
            return

        self.extend(evaluator, trigger=trigger)

    def setup_snapshots(self, opts, obj, trigger):
        if opts.no_snapshot:
            logging.warning("Model snapshots are disabled!")
        else:
            dump_fmt = "ft_model_epoch{0.updater.epoch:03d}.npz"
            self.extend(extensions.snapshot_object(obj, dump_fmt), trigger=trigger)
            logging.info("Snapshot format: \"{}\"".format(dump_fmt))

    def eval_name(self, name):
        if self.evaluator is None:
            return name

        return f"{self.evaluator.default_name}/{name}"

    def reportables(self, opts):
        print_values = [
            "elapsed_time",
            "epoch",
            # "lr",
            "main/accuracy", self.eval_name("main/accuracy"),
            "main/loss", self.eval_name("main/loss"),
        ]

        plot_values = {
            "accuracy": [
                "main/accuracy", self.eval_name("main/accuracy"),
            ],
            "loss": [
                "main/loss", self.eval_name("main/loss"),
            ],
        }

        return print_values, plot_values

    def run(self, init_eval=True):
        # guard against a missing evaluator; the original call would
        # crash with a TypeError when self.evaluator is None
        if init_eval and self.evaluator is not None:
            logging.info("Evaluating initial model ...")
            init_perf = self.evaluator(self)
            values = {key: float(value) for key, value in init_perf.items()}

            msg = []
            if "val/main/accuracy" in values:
                msg.append("Initial accuracy: {val/main/accuracy:.3%}".format(**values))
            if "val/main/loss" in values:
                msg.append("Initial loss: {val/main/loss:.3f}".format(**values))
            logging.info(" ".join(msg))

        if self._only_eval:
            return

        return super(Trainer, self).run()
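
# A minimal usage sketch (added for illustration, not part of the original
# module): it wires a toy chainer model into a StandardUpdater and drives the
# Trainer defined above. All toy names (X, y, clf, opts, ...) are assumptions;
# the SimpleNamespace carries only the option attributes this file reads.
# `no_observe=True` skips snapshotting and reporting, since the plain
# Classifier used here lacks the `.model` attribute the `model` property
# expects; `no_snapshot` / `no_progress` are only read when no_observe=False.
if __name__ == '__main__':
    from types import SimpleNamespace

    import numpy as np
    from chainer import links
    from chainer import optimizers
    from chainer.datasets import TupleDataset
    from chainer.iterators import SerialIterator
    from chainer.training import StandardUpdater

    # tiny random 3-class problem
    X = np.random.rand(64, 4).astype(np.float32)
    y = np.random.randint(0, 3, size=64).astype(np.int32)
    it = SerialIterator(TupleDataset(X, y), batch_size=8)

    clf = links.Classifier(links.Linear(4, 3))
    opt = optimizers.MomentumSGD(lr=0.01)
    opt.setup(clf)

    opts = SimpleNamespace(
        epochs=2, output=".out", only_eval=False,
        warm_up=None, learning_rate=0.01, lr_target=1e-6,
        lr_shift=1, lr_decrease_rate=0.1, cosine_schedule=None,
        optimizer="momentum", no_snapshot=True, no_progress=True,
    )

    trainer = Trainer(opts, StandardUpdater(it, opt), no_observe=True)
    trainer.run(init_eval=False)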