base.py

import chainer

from chainer import functions as F
from chainer.serializers import npz
from typing import Callable
from typing import Dict
from typing import Optional

from cvmodelz.models.base import BaseModel


class Classifier(chainer.Chain):

    def __init__(self, model: BaseModel, *,
                 layer_name: Optional[str] = None,
                 loss_func: Callable = F.softmax_cross_entropy,
                 only_head: bool = False,
                 ):
        super().__init__()
        self.layer_name = layer_name or model.clf_layer_name
        self.loss_func = loss_func

        with self.init_scope():
            self.setup(model)

        if only_head:
            self.enable_only_head()

    def setup(self, model: BaseModel) -> None:
        self.model = model

    def report(self, **values) -> None:
        chainer.report(values, self)

    def enable_only_head(self) -> None:
        # freeze all parameters, then re-enable updates
        # for the classification layer only
        self.model.disable_update()
        self.model.clf_layer.enable_update()

    @property
    def n_classes(self) -> int:
        return self.model.clf_layer.W.shape[0]
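
    # A quick sanity check on the shape convention above (a hedged note:
    # this assumes `clf_layer` is a chainer.links.Linear, which stores
    # its weight matrix W with shape (out_size, in_size), so W.shape[0]
    # is the number of output classes):
    #
    #     import chainer.links as L
    #     L.Linear(2048, 1000).W.shape  # -> (1000, 2048), i.e. 1000 classes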
    def save(self, weights_file: str) -> None:
        npz.save_npz(weights_file, self)
    def load(self, weights_file: str, n_classes: int, *, finetune: bool = False) -> None:
        """Loading a classifier covers the following use cases:

            (0) No loading.
                All weights are initialized randomly.
                (This is what happens if this method is never called.)

            (1) Loading from default pre-trained weights.
                The weights are loaded directly into the model; any
                additional layers that do not belong to the model are
                initialized randomly.

            (2) Loading from a saved classifier.
                All weights are loaded as-is from the given file.
        """
        try:
            # Case (2): the file was produced by Classifier.save
            self.load_classifier(weights_file)
        except KeyError:
            # Case (1): fall back to loading plain model weights
            self.load_model(weights_file, n_classes=n_classes, finetune=finetune)
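
    # Which branch runs depends on the npz contents (a sketch, assuming
    # Chainer's usual serialization layout): a file written by
    # Classifier.save stores parameters under this Chain's attribute
    # paths (e.g. "model/..."), so strict loading succeeds (case 2);
    # a plain backbone checkpoint lacks those keys, strict load_npz
    # raises KeyError, and we fall back to load_model (case 1).
    # The file names below are purely illustrative:
    #
    #     clf.load("saved_classifier.npz", n_classes=200)   # case (2)
    #     clf.load("pretrained_model.npz", n_classes=200)   # case (1)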
    def load_classifier(self, weights_file: str) -> None:
        npz.load_npz(weights_file, self, strict=True)

    def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False) -> None:
        if finetune:
            model_loader = self.model.load_for_finetune
        else:
            model_loader = self.model.load_for_inference

        model_loader(weights=weights_file, n_classes=n_classes, strict=True)
    @property
    def feat_size(self) -> int:
        return self.model.meta.feature_size

    @property
    def output_size(self) -> int:
        return self.feat_size

    def loss(self, pred: chainer.Variable, y: chainer.Variable) -> chainer.Variable:
        return self.model.loss(pred, y, loss_func=self.loss_func)

    def evaluations(self, pred: chainer.Variable, y: chainer.Variable) -> Dict[str, chainer.Variable]:
        return dict(accuracy=self.model.accuracy(pred, y))

    def forward(self, X: chainer.Variable, y: chainer.Variable) -> chainer.Variable:
        pred = self.model(X, layer_name=self.layer_name)
        loss = self.loss(pred, y)
        evaluations = self.evaluations(pred, y)

        self.report(loss=loss, **evaluations)
        return loss
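
# --- Usage sketch (illustrative, not part of the original file) ---
# `SomeModel` stands in for a concrete cvmodelz BaseModel subclass; how
# such a model is actually constructed is an assumption here, and the
# file name is a placeholder.
#
#     model = SomeModel()                      # any cvmodelz BaseModel
#     clf = Classifier(model, only_head=True)  # update only the clf layer
#     clf.load("weights.npz", n_classes=200, finetune=True)
#     loss = clf(X, y)                         # forward(); reports loss & accuracy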