# base.py
import abc

import chainer.functions as F
import chainer.links as L
import numpy as np
from chainer.initializers import HeNormal
from chainer.serializers import npz

from cvmodelz import utils
  5. class BaseModel(abc.ABC):
  6. @abc.abstractmethod
  7. def __call__(self, X, layer_name=None):
  8. pass
  9. @abc.abstractproperty
  10. def model_instance(self):
  11. raise NotImplementedError()
  12. @property
  13. def clf_layer_name(self):
  14. return self.meta.classifier_layers[-1]
  15. @property
  16. def clf_layer(self):
  17. return utils.get_attr_from_path(self.model_instance, self.clf_layer_name)
  18. def loss(self, pred, gt, loss_func=F.softmax_cross_entropy):
  19. return loss_func(pred, gt)
  20. def accuracy(self, pred, gt):
  21. return F.accuracy(pred, gt)
  22. def reinitialize_clf(self, n_classes, feat_size=None, initializer=None):
  23. if initializer is None or not callable(initializer):
  24. initializer = HeNormal(scale=1.0)
  25. clf_layer = self.clf_layer
  26. assert isinstance(clf_layer, L.Linear)
  27. w_shape = (n_classes, feat_size or clf_layer.W.shape[1])
  28. dtype = clf_layer.W.dtype
  29. clf_layer.W.data = np.zeros(w_shape, dtype=dtype)
  30. clf_layer.b.data = np.zeros(w_shape[0], dtype=dtype)
  31. initializer(clf_layer.W.data)
  32. def load_for_finetune(self, weights, n_classes, *, path="", strict=False, headless=False, **kwargs):
  33. """
  34. The weights should be pre-trained on a bigger
  35. dataset (eg. ImageNet). The classification layer is
  36. reinitialized after all other weights are loaded
  37. """
  38. self.load(weights, path=path, strict=strict, headless=headless)
  39. self.reinitialize_clf(n_classes, **kwargs)
  40. def load_for_inference(self, weights, n_classes, *, path="", strict=False, headless=False, **kwargs):
  41. """
  42. In this use case we are loading already fine-tuned
  43. weights. This means, we need to reinitialize the
  44. classification layer first and then load the weights.
  45. """
  46. self.reinitialize_clf(n_classes, **kwargs)
  47. self.load(weights, path=path, strict=strict, headless=headless)
  48. def load(self, weights, *, path="", strict=False, headless=False):
  49. if weights not in [None, "auto"]:
  50. ignore_names = None
  51. if headless:
  52. ignore_names = lambda name: name.startswith(path + self.clf_layer_name)
  53. npz.load_npz(weights, self.model_instance,
  54. path=path, strict=strict, ignore_names=ignore_names)