from typing import Callable
from typing import Dict

import chainer
from chainer import functions as F
from chainer.serializers import npz

from cvmodelz.models.base import BaseModel


class Classifier(chainer.Chain):

    def __init__(self, model: BaseModel, *,
                 layer_name: str = None,
                 loss_func: Callable = F.softmax_cross_entropy,
                 only_head: bool = False,
                 ):
        super().__init__()

        self.layer_name = layer_name or model.clf_layer_name
        self.loss_func = loss_func

        with self.init_scope():
            self.setup(model)

        if only_head:
            self.enable_only_head()

    def setup(self, model: BaseModel) -> None:
        self.model = model

    def report(self, **values) -> None:
        chainer.report(values, self)

    def enable_only_head(self) -> None:
        # freeze the entire model, then re-enable updates
        # for the classification layer only
        self.model.disable_update()
        self.model.clf_layer.enable_update()

    @property
    def n_classes(self) -> int:
        # number of rows in the classification layer's weight matrix
        return self.model.clf_layer.W.shape[0]

    def save(self, weights_file: str) -> None:
        npz.save_npz(weights_file, self)

    def load(self, weights_file: str, n_classes: int, *, finetune: bool = False) -> None:
        """Loading a classifier covers the following use cases:

        (0) No loading.
            All weights are initialized randomly.

        (1) Loading from default pre-trained weights.
            The weights are loaded directly into the model; any
            additional layer that is not part of the model is
            initialized randomly.

        (2) Loading from a saved classifier.
            All weights are loaded as-is from the given file.
        """
        if weights_file is None:
            # Case (0): nothing to load, the weights stay randomly initialized.
            return

        try:
            # Case (2): the file contains a complete classifier.
            self.load_classifier(weights_file)
        except KeyError:
            # Case (1): fall back to loading the model weights only.
            self.load_model(weights_file, n_classes=n_classes, finetune=finetune)

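    # A minimal sketch of the loading modes described above; the file
    # names below are hypothetical and serve only as illustration:
    #
    #   clf = Classifier(model)
    #   clf.load(None, n_classes=10)            # case (0): keep random init
    #   clf.load("clf.npz", n_classes=10)       # case (2): full classifier
    #   clf.load("imagenet.npz", n_classes=10)  # case (1): model weights only,
    #                                           # used when (2) fails with a KeyError
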
    def load_classifier(self, weights_file: str) -> None:
        npz.load_npz(weights_file, self, strict=True)

    def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False) -> None:
        if finetune:
            model_loader = self.model.load_for_finetune
        else:
            model_loader = self.model.load_for_inference

        model_loader(weights=weights_file, n_classes=n_classes, strict=True)

    @property
    def feat_size(self) -> int:
        return self.model.meta.feature_size

    @property
    def output_size(self) -> int:
        return self.feat_size

    def loss(self, pred: chainer.Variable, y: chainer.Variable) -> chainer.Variable:
        return self.model.loss(pred, y, loss_func=self.loss_func)

    def evaluations(self, pred: chainer.Variable, y: chainer.Variable) -> Dict[str, chainer.Variable]:
        return dict(accuracy=self.model.accuracy(pred, y))

    def forward(self, X: chainer.Variable, y: chainer.Variable) -> chainer.Variable:
        pred = self.model(X, layer_name=self.layer_name)

        loss = self.loss(pred, y)
        evaluations = self.evaluations(pred, y)

        self.report(loss=loss, **evaluations)
        return loss
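

if __name__ == "__main__":
    # A minimal end-to-end sketch, illustrative rather than part of the
    # module. It assumes cvmodelz's ModelFactory can build a BaseModel;
    # the model key, batch shape, and random inputs are assumptions made
    # purely for demonstration.
    import numpy as np

    from cvmodelz.models import ModelFactory

    model = ModelFactory.new("cvmodelz.ResNet50")  # assumed model key
    clf = Classifier(model, only_head=True)

    X = np.random.rand(4, 3, 224, 224).astype(np.float32)
    y = np.random.randint(0, clf.n_classes, size=4).astype(np.int32)

    loss = clf(X, y)  # forward() computes the loss and reports the accuracy
    loss.backward()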