12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
import abc

import chainer.functions as F
import chainer.links as L
import numpy as np
from chainer.initializers import HeNormal
from chainer.serializers import npz

from cvmodelz import utils
class BaseModel(abc.ABC):
    """Abstract base for chainer model wrappers.

    Subclasses provide the forward pass (``__call__``) and the wrapped
    chainer chain (``model_instance``); this base supplies loss/accuracy
    helpers, classifier-layer access, and weight-loading utilities.
    Assumes subclasses expose a ``self.meta`` object with a
    ``classifier_layers`` sequence (not visible here — confirm in subclass).
    """

    @abc.abstractmethod
    def __call__(self, X, layer_name=None):
        """Run a forward pass on batch ``X``.

        Args:
            X: input batch.
            layer_name: optional name of the layer whose output to return.
        """
        pass

    # NOTE: abc.abstractproperty is deprecated since Python 3.3;
    # @property stacked on @abc.abstractmethod is the modern equivalent.
    @property
    @abc.abstractmethod
    def model_instance(self):
        """The underlying chainer chain holding the actual parameters."""
        raise NotImplementedError()

    @property
    def clf_layer_name(self):
        """Name (attribute path) of the final classification layer."""
        return self.meta.classifier_layers[-1]

    @property
    def clf_layer(self):
        """The final classification layer object itself."""
        return utils.get_attr_from_path(self.model_instance, self.clf_layer_name)

    def loss(self, pred, gt, loss_func=F.softmax_cross_entropy):
        """Compute ``loss_func(pred, gt)`` (softmax cross-entropy by default)."""
        return loss_func(pred, gt)

    def accuracy(self, pred, gt):
        """Classification accuracy of predictions ``pred`` against labels ``gt``."""
        return F.accuracy(pred, gt)

    def reinitialize_clf(self, n_classes, feat_size=None, initializer=None):
        """Re-create the classifier weights for ``n_classes`` outputs.

        Args:
            n_classes: number of output classes (new first dim of ``W``).
            feat_size: input feature size; defaults to the current ``W`` width.
            initializer: callable weight initializer; anything non-callable
                (including ``None``) falls back to ``HeNormal(scale=1.0)``.
        """
        if initializer is None or not callable(initializer):
            initializer = HeNormal(scale=1.0)

        clf_layer = self.clf_layer
        # Only dense classifier heads are supported here.
        assert isinstance(clf_layer, L.Linear)

        w_shape = (n_classes, feat_size or clf_layer.W.shape[1])
        dtype = clf_layer.W.dtype
        # Allocate fresh arrays: W gets the initializer, b stays zero.
        clf_layer.W.data = np.zeros(w_shape, dtype=dtype)
        clf_layer.b.data = np.zeros(w_shape[0], dtype=dtype)
        initializer(clf_layer.W.data)

    def load_for_finetune(self, weights, n_classes, *, path="", strict=False, headless=False, **kwargs):
        """
        The weights should be pre-trained on a bigger
        dataset (eg. ImageNet). The classification layer is
        reinitialized after all other weights are loaded
        """
        self.load(weights, path=path, strict=strict, headless=headless)
        self.reinitialize_clf(n_classes, **kwargs)

    def load_for_inference(self, weights, n_classes, *, path="", strict=False, headless=False, **kwargs):
        """
        In this use case we are loading already fine-tuned
        weights. This means, we need to reinitialize the
        classification layer first and then load the weights.
        """
        self.reinitialize_clf(n_classes, **kwargs)
        self.load(weights, path=path, strict=strict, headless=headless)

    def load(self, weights, *, path="", strict=False, headless=False):
        """Load ``weights`` (an npz file) into the wrapped chain.

        ``weights in (None, "auto")`` is treated as "nothing to load".
        With ``headless=True`` the classifier layer's parameters are
        skipped, so the head can keep a different shape than the file's.
        """
        if weights not in [None, "auto"]:
            ignore_names = None
            if headless:
                ignore_names = lambda name: name.startswith(path + self.clf_layer_name)
            npz.load_npz(weights, self.model_instance,
                         path=path, strict=strict, ignore_names=ignore_names)
|