import abc
from typing import Callable

import chainer

from cvmodelz.classifiers.base import Classifier
from cvmodelz.models.base import BaseModel


class SeparateModelClassifier(Classifier):
    """
    Abstract Classifier, that holds two separate models.

    Besides ``self.model`` (managed by the parent ``Classifier``), this class
    keeps a second model in ``self.separate_model``, created in ``setup`` as
    a copy of ``self.model``.

    The user has to define, how these models operate on the input data.
    Hence, the forward method is abstract!
    """

    @abc.abstractmethod
    def forward(self, *args, **kwargs) -> chainer.Variable:
        """Define how both models process the input; must be overridden.

        The base implementation only delegates to the parent's ``forward``
        with the same arguments, so that subclasses can chain to it via
        ``super()``. Its result is not returned here.
        """
        super(SeparateModelClassifier, self).forward(*args, **kwargs)

    def setup(self, model: BaseModel) -> None:
        """Run the parent setup and create the separate model.

        :param model: model passed on to the parent's ``setup``.
        """
        super(SeparateModelClassifier, self).setup(model)

        # NOTE(review): presumably chainer's ``Link.copy(mode="copy")``, which
        # gives the copy its own parameter arrays -- confirm for BaseModel.
        self.separate_model = self.model.copy(mode="copy")

    def loader(self, model_loader: Callable) -> Callable:
        """Wrap the parent's loader so the separate model is kept in sync.

        :param model_loader: loader callable handed to the parent's
            ``loader`` method.
        :return: a callable taking ``n_classes`` and ``feat_size``. It runs
            the parent's loader first and then copies the parameters of
            ``self.model`` into ``self.separate_model``.
        """
        # The two-argument form of super() is required here: the one-argument
        # form ``super(SeparateModelClassifier)`` yields an unbound super
        # object, on which the ``loader`` lookup raises AttributeError.
        super_loader = super(SeparateModelClassifier, self).loader(model_loader)

        def inner_loader(n_classes: int, feat_size: int) -> None:
            # use the given feature size here
            super_loader(n_classes=n_classes, feat_size=feat_size)

            # use the given feature size first, so that the classifier of the
            # separate model is re-initialized with the same arguments as the
            # one of the just loaded model ...
            self.separate_model.reinitialize_clf(
                n_classes=n_classes, feat_size=feat_size)

            # then copy model params ...
            self.separate_model.copyparams(self.model)

            # now use the default feature size to re-init the classifier
            self.separate_model.reinitialize_clf(
                n_classes=n_classes, feat_size=self.feat_size)

        return inner_loader

    def enable_only_head(self) -> None:
        """Restrict parameter updates to the classification heads.

        After the parent's ``enable_only_head``, updates are disabled for
        the whole separate model and re-enabled only for its ``fc`` link.
        """
        super(SeparateModelClassifier, self).enable_only_head()
        self.separate_model.disable_update()
        self.separate_model.fc.enable_update()