12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- import abc
- import chainer
- from typing import Callable
- from cvmodelz.models.base import BaseModel
- from cvmodelz.classifiers.base import Classifier
class SeparateModelClassifier(Classifier):
    """
    Abstract Classifier, that holds two separate models.

    The user has to define, how these models operate on the
    input data. Hence, the forward method is abstract!

    The second model is stored in ``self.separate_model`` and is created
    in :meth:`setup` as a copy of ``self.model``.
    """

    @abc.abstractmethod
    def forward(self, *args, **kwargs) -> chainer.Variable:
        """Abstract forward pass; subclasses define how both models are used.

        The default body only delegates to the parent's ``forward`` with
        the same arguments.
        """
        super(SeparateModelClassifier, self).forward(*args, **kwargs)

    def setup(self, model: BaseModel) -> None:
        """Run the parent setup and create the separate model.

        Args:
            model: the model handed over to the parent's ``setup``.
        """
        super(SeparateModelClassifier, self).setup(model)

        # mode="copy" is forwarded to the model's copy method as-is;
        # presumably this yields independent parameters -- confirm
        # against the chainer Link.copy documentation.
        self.separate_model = self.model.copy(mode="copy")

    def loader(self, model_loader: Callable) -> Callable:
        """Wrap the parent's loader so that the separate model is set up too.

        Args:
            model_loader: the loading callable passed on to the parent's
                ``loader``.

        Returns:
            A callable ``inner_loader(n_classes, feat_size)`` that first
            runs the parent's loader and afterwards copies the loaded
            parameters into ``self.separate_model``.
        """
        # NOTE: super() needs the instance as second argument; the
        # one-argument form returns an unbound super object, on which the
        # lookup of ``loader`` fails with an AttributeError.
        super_loader = super(SeparateModelClassifier, self).loader(model_loader)

        def inner_loader(n_classes: int, feat_size: int) -> None:
            # use the given feature size here
            super_loader(n_classes=n_classes, feat_size=feat_size)

            # use the given feature size first ...
            self.separate_model.reinitialize_clf(
                n_classes=n_classes,
                feat_size=feat_size)

            # then copy model params ...
            self.separate_model.copyparams(self.model)

            # now use the default feature size to re-init the classifier
            # (self.feat_size is not defined in this class; presumably
            # provided by the parent classifier -- TODO confirm)
            self.separate_model.reinitialize_clf(
                n_classes=n_classes,
                feat_size=self.feat_size)

        return inner_loader

    def enable_only_head(self) -> None:
        """Apply the parent's head-only setting, then do the same for the separate model.

        For the separate model, updates are disabled on the whole model
        and re-enabled on its ``fc`` attribute only.
        """
        super(SeparateModelClassifier, self).enable_only_head()
        self.separate_model.disable_update()
        self.separate_model.fc.enable_update()
|