trainer.py
import logging

from os.path import join, basename
from datetime import datetime

import chainer
from chainer.training import extensions, Trainer as T

from chainer_addons.training import lr_shift
from chainer_addons.training.optimizer import OptimizerType
from chainer_addons.training.extensions import SacredReport
from chainer_addons.training.extensions.learning_rate import CosineAnnealingLearningRate
from chainer_addons.training.extensions import AlternateTrainable, SwitchTrainables, WarmUp

from cvdatasets.utils import attr_dict


def debug_hook(trainer):
    pass
    # print(trainer.updater.get_optimizer("main").target.model.fc6.W.data.mean(), file=open("debug.out", "a"))


default_intervals = attr_dict(
    print=(1, 'epoch'),
    log=(1, 'epoch'),
    eval=(1, 'epoch'),
    snapshot=(10, 'epoch'),
)


def observe_alpha(trainer):
    # reports the current value of the alpha parameter of the pooling layer
    model = trainer.updater.get_optimizer("main").target.model
    return float(model.pool.alpha.array)


def _is_adam(opts):
    return opts.optimizer == OptimizerType.ADAM.name.lower()


class Trainer(T):
    _default_base_model = "model"

    def __init__(self, opts, updater, evaluator=None, weights=None, intervals=default_intervals, no_observe=False):
        self._only_eval = opts.only_eval
        if weights is None or weights == "auto":
            self.base_model = self._default_base_model
        else:
            self.base_model, _, _ = basename(weights).rpartition(".")

        optimizer = updater.get_optimizer("main")
        # Adam stores its learning rate in the "alpha" hyperparameter, so we need to check this
        is_adam = _is_adam(opts)
        clf = optimizer.target
        model = clf.model

        outdir = self.output_directory(opts)
        logging.info("Training outputs are saved under \"{}\"".format(outdir))

        super(Trainer, self).__init__(
            updater=updater,
            stop_trigger=(opts.epochs, 'epoch'),
            out=outdir
        )

        ### Evaluator ###
        if evaluator is not None:
            self.extend(evaluator, trigger=intervals.eval)

        ### Warm-up ###
        lr_offset = 0
        if opts.warm_up:
            assert opts.warm_up > 0, "Warm-up argument must be positive!"
            lr_offset = opts.warm_up
            warm_up_lr = opts.learning_rate
            logging.info("Warm-up of {} epochs enabled!".format(opts.warm_up))
            self.extend(WarmUp(
                opts.warm_up, model,
                opts.learning_rate, warm_up_lr))

        ### LR shift ###
        if opts.cosine_schedule:
            lr_shift_ext = CosineAnnealingLearningRate(
                attr="alpha" if is_adam else "lr",
                lr=opts.learning_rate,
                target=opts.lr_target,
                epochs=opts.epochs,
                offset=lr_offset
            )
            self.extend(lr_shift_ext)
        else:
            lr_shift_ext = lr_shift(optimizer,
                init=opts.learning_rate,
                rate=opts.lr_decrease_rate, target=opts.lr_target)
            self.extend(lr_shift_ext, trigger=(opts.lr_shift, 'epoch'))

        ### Code below is only for "main" Trainers ###
        if no_observe:
            return

        self.extend(extensions.observe_lr(), trigger=intervals.log)
        self.extend(extensions.LogReport(trigger=intervals.log))

        ### Snapshotting ###
        self.setup_snapshots(opts, clf.model, intervals.snapshot)

        ### Reports and Plots ###
        print_values, plot_values = self.reportables(opts, model, evaluator)
        self.extend(extensions.PrintReport(print_values), trigger=intervals.print)
        for name, values in plot_values.items():
            ext = extensions.PlotReport(values, 'epoch', file_name='{}.png'.format(name))
            self.extend(ext)

        ### Progress bar ###
        if not opts.no_progress:
            self.extend(extensions.ProgressBar(update_interval=1))

    def setup_snapshots(self, opts, obj, trigger):
        if opts.no_snapshot:
            logging.warning("Models are not snapshotted!")
        else:
            # e.g. "ft_model_epoch010.npz" after epoch 10
            dump_fmt = "ft_model_epoch{0.updater.epoch:03d}.npz"
            self.extend(extensions.snapshot_object(obj, dump_fmt), trigger=trigger)
            logging.info("Snapshot format: \"{}\"".format(dump_fmt))

    def reportables(self, opts, model, evaluator):
        eval_name = lambda name: f"{evaluator.default_name}/{name}"

        print_values = [
            "elapsed_time",
            "epoch",
            # "lr",
            "main/accuracy", eval_name("main/accuracy"),
            "main/loss", eval_name("main/loss"),
        ]

        plot_values = {
            "accuracy": [
                "main/accuracy", eval_name("main/accuracy"),
            ],
            "loss": [
                "main/loss", eval_name("main/loss"),
            ],
        }

        # if opts.triplet_loss:
        #     print_values.extend(["main/t_loss", eval_name("main/t_loss")])
        #     plot_values.update({
        #         "t_loss": [
        #             "main/t_loss", eval_name("main/t_loss"),
        #         ]
        #     })

        # if opts.use_parts:
        #     print_values.extend(["main/logL", eval_name("main/logL")])
        #     plot_values.update({
        #         "logL": [
        #             "main/logL", eval_name("main/logL"),
        #         ]
        #     })

        # if not opts.no_global:
        #     print_values.extend([
        #         "main/glob_accu", eval_name("main/glob_accu"),
        #         # "main/glob_loss", eval_name("main/glob_loss"),
        #         "main/part_accu", eval_name("main/part_accu"),
        #         # "main/part_loss", eval_name("main/part_loss"),
        #     ])
        #     plot_values["accuracy"].extend([
        #         "main/part_accu", eval_name("main/part_accu"),
        #         "main/glob_accu", eval_name("main/glob_accu"),
        #     ])
        #     plot_values["loss"].extend([
        #         "main/part_loss", eval_name("main/part_loss"),
        #         "main/glob_loss", eval_name("main/glob_loss"),
        #     ])

        return print_values, plot_values

    def output_directory(self, opts):
        # e.g. "<opts.output>/<base_model>/2021-01-31-12.30.00";
        # the base model sub-folder is only added for non-default weights
        result = opts.output
        if self.base_model != self._default_base_model:
            result = join(result, self.base_model)
        result = join(result, datetime.now().strftime("%Y-%m-%d-%H.%M.%S"))
        return result

    def run(self, init_eval=True):
        if init_eval:
            logging.info("Evaluating initial model ...")
            evaluator = self.get_extension("val")
            init_perf = evaluator(self)
            logging.info("Initial accuracy: {val/main/accuracy:.3%} initial loss: {val/main/loss:.3f}".format(
                **{key: float(value) for key, value in init_perf.items()}
            ))
            if self._only_eval:
                return
        return super(Trainer, self).run()


class SacredTrainer(Trainer):

    def __init__(self, ex, *args, **kwargs):
        super(SacredTrainer, self).__init__(*args, **kwargs)
        self.extend(SacredReport(ex=ex, trigger=default_intervals.log))


class AlphaPoolingTrainer(SacredTrainer):

    def __init__(self, opts, updater, *args, **kwargs):
        super(AlphaPoolingTrainer, self).__init__(opts=opts, updater=updater, *args, **kwargs)
        model = updater.get_optimizer("main").target.model

        ### Alternating training of CNN and FC layers (only for alpha-pooling) ###
        if opts.switch_epochs:
            self.extend(SwitchTrainables(
                opts.switch_epochs,
                model=model,
                pooling=model.pool))

    def reportables(self, opts, model, evaluator):
        print_values, plot_values = super(AlphaPoolingTrainer, self).reportables(opts, model, evaluator)

        # scale the learning rate of the alpha parameter by kappa;
        # Adam stores its learning rate in the "alpha" hyperparameter
        alpha_update_rule = model.pool.alpha.update_rule
        if _is_adam(opts):
            alpha_update_rule.hyperparam.alpha *= opts.kappa
        else:
            alpha_update_rule.hyperparam.lr *= opts.kappa

        self.extend(extensions.observe_value("alpha", observe_alpha), trigger=default_intervals.print)
        print_values.append("alpha")
        plot_values["alpha"] = ["alpha"]

        return print_values, plot_values
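

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It wires
# a minimal Chainer setup into the Trainer above under these assumptions:
#   * `opts` is a hypothetical options object exposing exactly the attributes
#     read by Trainer.__init__ (only_eval, epochs, output, optimizer,
#     learning_rate, lr_target, lr_shift, lr_decrease_rate, cosine_schedule,
#     warm_up, no_progress, no_snapshot),
#   * the classifier exposes its backbone as `.model`, because the Trainer
#     accesses `optimizer.target.model`,
#   * the evaluator reports under the name "val", because run() looks it up
#     via `self.get_extension("val")`.
# DummyClassifier, ValEvaluator and the random dataset are invented stand-ins.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np
    import chainer.functions as F
    import chainer.links as L

    class DummyClassifier(chainer.Chain):
        """Toy classifier that exposes its backbone as `.model`."""

        def __init__(self, n_classes=10):
            super(DummyClassifier, self).__init__()
            with self.init_scope():
                self.model = L.Linear(None, 32)
                self.fc = L.Linear(None, n_classes)

        def __call__(self, x, t):
            y = self.fc(F.relu(self.model(x)))
            loss = F.softmax_cross_entropy(y, t)
            chainer.report({"loss": loss, "accuracy": F.accuracy(y, t)}, self)
            return loss

    class ValEvaluator(extensions.Evaluator):
        # report under "val/..." so that run() and reportables() find the values
        default_name = "val"

    # hypothetical options object; attr_dict provides attribute access to the values
    opts = attr_dict(
        only_eval=False, epochs=2, output="output",
        optimizer="sgd", learning_rate=0.01, lr_target=1e-5,
        lr_shift=1, lr_decrease_rate=0.1, cosine_schedule=False,
        warm_up=None, no_progress=True, no_snapshot=True,
    )

    # tiny random dataset: 64 samples, 8 features, 10 classes
    X = np.random.rand(64, 8).astype(np.float32)
    y = np.random.randint(0, 10, size=64).astype(np.int32)
    dataset = chainer.datasets.TupleDataset(X, y)

    clf = DummyClassifier()
    optimizer = chainer.optimizers.MomentumSGD(lr=opts.learning_rate)
    optimizer.setup(clf)

    train_it = chainer.iterators.SerialIterator(dataset, batch_size=16)
    val_it = chainer.iterators.SerialIterator(dataset, batch_size=16, repeat=False, shuffle=False)

    updater = chainer.training.updaters.StandardUpdater(train_it, optimizer)
    evaluator = ValEvaluator(val_it, clf)

    trainer = Trainer(opts, updater, evaluator=evaluator)
    trainer.run(init_eval=True)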