model.py 5.7 KB

import abc
import chainer
import logging
from chainer import functions as F
from chainer.optimizer_hooks import Lasso
from chainer.optimizer_hooks import WeightDecay
from chainer_addons.functions import smoothed_cross_entropy
from chainer_addons.models import PrepareType
from chainer_addons.training import optimizer
from chainer_addons.training import optimizer_hooks
from chainercv2.models import model_store
from cvdatasets.dataset.image import Size
from cvdatasets.utils import pretty_print_dict
from cvmodelz.models import ModelFactory
from functools import partial
from pathlib import Path
from typing import Tuple


def check_param_for_decay(param):
    # selection criterion for SelectiveWeightDecay below:
    # parameters named "alpha" (the alpha-pooling parameter) are excluded from weight decay
    return param.name != "alpha"


def enable_only_head(chain: chainer.Chain):
    if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
        chain.enable_only_head()

    else:
        # fallback: freeze all parameters, then re-enable only the final FC layer
        chain.disable_update()
        chain.fc.enable_update()


class _ModelMixin(abc.ABC):
    """This mixin is responsible for optimizer creation, model creation,
    wrapping the model in a classifier, and loading the model weights.
    """

    def __init__(self, opts, classifier_cls, classifier_kwargs={}, model_kwargs={}, *args, **kwargs):
        super(_ModelMixin, self).__init__(opts=opts, *args, **kwargs)
        self.classifier_cls = classifier_cls
        self.classifier_kwargs = classifier_kwargs
        self.model_type = opts.model_type
        self.model_kwargs = model_kwargs

    @property
    def model_info(self):
        return self.data_info.MODELS[self.model_type]

    def init_model(self, opts):
        """Creates the backbone CNN model. This model is later wrapped by the classifier."""

        self.model = ModelFactory.new(self.model_type,
            input_size=Size(opts.input_size),
            **self.model_kwargs
        )

        if self.model_type.startswith("chainercv2"):
            opts.prepare_type = "chainercv2"

        self.prepare = partial(PrepareType[opts.prepare_type](self.model),
            swap_channels=opts.swap_channels,
            keep_ratio=getattr(opts, "center_crop_on_val", False),
        )

        logging.info(
            f"Created {self.model.__class__.__name__} model "
            f"with \"{opts.prepare_type}\" prepare function."
        )

    def init_classifier(self, opts):
        clf_class, kwargs = self.classifier_cls, self.classifier_kwargs

        self.clf = clf_class(
            model=self.model,
            loss_func=self._loss_func(opts),
            **kwargs)

        logging.info(
            f"Wrapped the model around {clf_class.__name__}"
            f" with kwargs: {pretty_print_dict(kwargs)}"
        )

    def _loss_func(self, opts):
        if getattr(opts, "l1_loss", False):
            return F.hinge

        label_smoothing = getattr(opts, "label_smoothing", 0)
        if label_smoothing > 0:
            assert label_smoothing < 1, "Label smoothing factor must be less than 1!"
            return partial(smoothed_cross_entropy, N=self.n_classes, eps=label_smoothing)

        return F.softmax_cross_entropy

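    # Note on the label-smoothing branch above: the exact behaviour of
    # chainer_addons' smoothed_cross_entropy is assumed here from its N/eps
    # arguments. With N classes and smoothing factor eps, the one-hot target is
    # typically mixed with a uniform distribution, q(k) = (1 - eps) * [k == t] + eps / N,
    # e.g. for N = 200 and eps = 0.1 the true class gets 0.9005 and every other
    # class 0.0005, which regularizes overconfident predictions.
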
    def init_optimizer(self, opts):
        """Creates an optimizer for the classifier"""
        if not hasattr(opts, "optimizer"):
            self.opt = None
            return

        opt_kwargs = {}
        if opts.optimizer == "rmsprop":
            opt_kwargs["alpha"] = 0.9

        if opts.optimizer in ["rmsprop", "adam"]:
            opt_kwargs["eps"] = 1e-6

        self.opt = optimizer(opts.optimizer,
            self.clf,
            opts.learning_rate,
            decay=0, gradient_clipping=False, **opt_kwargs
        )

        logging.info(
            f"Initialized {self.opt.__class__.__name__} optimizer"
            f" with initial LR {opts.learning_rate} and kwargs: {pretty_print_dict(opt_kwargs)}"
        )

        if opts.decay > 0:
            reg_kwargs = {}
            if opts.l1_loss:
                # Lasso = L1 regularization
                reg_cls = Lasso

            elif opts.pooling == "alpha":
                # decay all parameters except the "alpha" parameter (see check_param_for_decay)
                reg_cls = optimizer_hooks.SelectiveWeightDecay
                reg_kwargs["selection"] = check_param_for_decay

            else:
                reg_cls = WeightDecay

            logging.info(f"Adding {reg_cls.__name__} ({opts.decay:e})")
            self.opt.add_hook(reg_cls(opts.decay, **reg_kwargs))

        if getattr(opts, "only_head", False):
            assert not getattr(opts, "recurrent", False), \
                "Recurrent classifier is not supported with only_head option!"

            logging.warning("========= Fine-tuning only classifier layer! =========")
            enable_only_head(self.clf)

    def _get_loader(self, opts) -> Tuple[bool, str]:
        """Determines which weights to load and whether they are fine-tuned already.

        Returns a (finetune, weights) tuple: (None, None) when training from scratch,
        (False, weights) for already fine-tuned weights, and (True, weights) for
        pre-trained weights that still need fine-tuning.
        """

        if getattr(opts, "from_scratch", False):
            logging.info("Training a {0.__class__.__name__} model from scratch!".format(self.model))
            return None, None

        if getattr(opts, "load", None):
            weights = getattr(opts, "load", None)
            logging.info(f"Loading already fine-tuned weights from \"{weights}\"")
            return False, weights

        elif getattr(opts, "weights", None):
            weights = getattr(opts, "weights", None)
            logging.info(f"Loading custom pre-trained weights from \"{weights}\"")
            return True, weights

        else:
            weights = self._default_weights(opts)
            logging.info(f"Loading default pre-trained weights from \"{weights}\"")
            return True, weights

    def _default_weights(self, opts):
        if self.model_type.startswith("chainercv2"):
            model_name = self.model_type.split(".")[-1]
            return model_store.get_model_file(
                model_name=model_name,
                local_model_store_dir_path=str(Path.home() / ".chainer" / "models"))

        else:
            ds_info = self.data_info
            model_info = self.model_info

            base_dir = Path(ds_info.BASE_DIR)
            weights_dir = base_dir / ds_info.MODEL_DIR / model_info.folder
            weights = model_info.weights

            assert opts.pre_training in weights, \
                f"Weights for \"{opts.pre_training}\" pre-training were not found!"

            return str(weights_dir / weights[opts.pre_training])

    def load_weights(self, opts) -> None:
        finetune, weights = self._get_loader(opts)

        self.clf.load(weights,
            n_classes=self.n_classes,
            finetune=finetune,
            path=opts.load_path,
            strict=opts.load_strict,
            headless=opts.headless
        )
        self.clf.cleargrads()

        feat_size = self.model.meta.feature_size

        if hasattr(self.clf, "output_size"):
            feat_size = self.clf.output_size

        ### TODO: handle feature size!
        logging.info(f"Part features size after encoding: {feat_size}")
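
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): a concrete
# finetuner is expected to combine this mixin with a cooperating base class that
# accepts `opts` and provides `data_info` and `n_classes`. The names below
# (Finetuner, BaseTuner, Classifier) and the call order are hypothetical.
#
#   class Finetuner(_ModelMixin, BaseTuner):
#       pass
#
#   tuner = Finetuner(opts=opts, classifier_cls=Classifier)
#   tuner.init_model(opts)        # create the backbone CNN
#   tuner.init_classifier(opts)   # wrap it in the classifier with the selected loss
#   tuner.load_weights(opts)      # default / custom / already fine-tuned weights
#   tuner.init_optimizer(opts)    # optimizer, decay hooks and only_head handling
# ---------------------------------------------------------------------------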