optimizer.py

import abc
import chainer
import logging

from chainer.optimizer_hooks import Lasso
from chainer.optimizer_hooks import WeightDecay

from chainer_addons.training import optimizer as new_optimizer
from chainer_addons.training.optimizer_hooks import SelectiveWeightDecay

from cvdatasets.utils import pretty_print_dict
from cvfinetune.finetuner.mixins.base import BaseMixin
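

# Selection predicate used by SelectiveWeightDecay further below: every
# parameter except the one named "alpha" (presumably the trainable
# alpha-pooling parameter) is subject to weight decay.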
def check_param_for_decay(param):
    return param.name != "alpha"


def enable_only_head(chain: chainer.Chain):
    if hasattr(chain, "enable_only_head") and callable(chain.enable_only_head):
        chain.enable_only_head()
    else:
        chain.disable_update()
        chain.fc.enable_update()
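

# Thin factory: stores the optimizer name and its default kwargs, and defers
# the actual construction (via chainer_addons' optimizer factory, imported as
# "new_optimizer") until it is called with the target link and learning rate.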
class _OptimizerCreator:

    def __init__(self, opt, **kwargs):
        super().__init__()
        self.opt = opt
        self.kwargs = kwargs

    def __call__(self, *args, **kwargs):
        if self.opt is None:
            return None

        kwargs = dict(self.kwargs, **kwargs)
        return new_optimizer(self.opt, *args, **kwargs)


class _OptimizerMixin(BaseMixin):

    def __init__(self, *args,
                 optimizer: str,
                 learning_rate: float = 1e-3,
                 weight_decay: float = 5e-4,
                 eps: float = 1e-2,
                 only_head: bool = False,
                 **kwargs):
        super().__init__(*args, **kwargs)

        optimizer_kwargs = dict(decay=0, gradient_clipping=False)
        if optimizer in ["rmsprop", "adam"]:
            optimizer_kwargs["eps"] = eps

        self._opt_creator = _OptimizerCreator(optimizer, **optimizer_kwargs)
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self._only_head = only_head

    def init_optimizer(self):
        """Creates an optimizer for the classifier."""
        self._check_attr("clf")
        self._check_attr("_pooling")
        self._check_attr("_l1_loss")

        self.opt = self._opt_creator(self.clf, self.learning_rate)

        if self.opt is None:
            logging.warning("========= No optimizer was initialized! =========")
            return

        kwargs = self._opt_creator.kwargs
        logging.info(
            f"Initialized {type(self.opt).__name__} optimizer"
            f" with initial LR {self.learning_rate} and kwargs: {pretty_print_dict(kwargs)}"
        )

        self.init_regularizer()

        if self._only_head:
            logging.warning("========= Fine-tuning only classifier layer! =========")
            enable_only_head(self.clf)
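
    # Regularizer selection: Lasso (L1) when an L1 loss is configured,
    # SelectiveWeightDecay (skipping the "alpha" parameter) for alpha-pooling,
    # and plain WeightDecay otherwise; nothing is added for weight_decay <= 0.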
    def init_regularizer(self, **kwargs):
        if self.weight_decay <= 0:
            return

        if self._l1_loss:
            cls = Lasso
        elif self._pooling == "alpha":
            cls = SelectiveWeightDecay
            kwargs["selection"] = check_param_for_decay
        else:
            cls = WeightDecay

        logging.info(f"Adding {cls.__name__} ({self.weight_decay:e})")
        self.opt.add_hook(cls(self.weight_decay, **kwargs))
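

if __name__ == "__main__":
    # Illustrative sketch, not part of the original module: the plain-Chainer
    # equivalent of what init_optimizer() / init_regularizer() wire up, using
    # only the standard chainer API (the chainer_addons factory and
    # SelectiveWeightDecay are not exercised here). A Linear link stands in
    # for the classifier; values mirror the mixin defaults, and Adam's alpha
    # is its learning rate.
    model = chainer.links.Linear(10, 2)

    opt = chainer.optimizers.Adam(alpha=1e-3, eps=1e-2)
    opt.setup(model)

    # roughly what init_regularizer() does in the default case
    # (no L1 loss, non-alpha pooling): add a WeightDecay hook at 5e-4
    opt.add_hook(WeightDecay(5e-4))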