separate_model_classifier.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import abc
  2. import chainer
  3. from typing import Callable
  4. from cvmodelz.models.base import BaseModel
  5. from cvmodelz.classifiers.base import Classifier
  6. class SeparateModelClassifier(Classifier):
  7. """
  8. Abstract Classifier, that holds two separate models.
  9. The user has to define, how these models operate on the
  10. input data. Hence, the forward method is abstract!
  11. """
  12. @abc.abstractmethod
  13. def forward(self, *args, **kwargs) -> chainer.Variable:
  14. super(SeparateModelClassifier, self).forward(*args, **kwargs)
  15. def setup(self, model: BaseModel) -> None:
  16. super(SeparateModelClassifier, self).setup(model)
  17. self.separate_model = self.model.copy(mode="copy")
  18. def loader(self, model_loader: Callable) -> Callable:
  19. super_loader = super(SeparateModelClassifier).loader(model_loader)
  20. def inner_loader(n_classes: int, feat_size: int) -> None:
  21. # use the given feature size here
  22. super_loader(n_classes=n_classes, feat_size=feat_size)
  23. # use the given feature size first ...
  24. self.separate_model.reinitialize_clf(
  25. n_classes=n_classes,
  26. feat_size=feat_size)
  27. # then copy model params ...
  28. self.separate_model.copyparams(self.model)
  29. # now use the default feature size to re-init the classifier
  30. self.separate_model.reinitialize_clf(
  31. n_classes=n_classes,
  32. feat_size=self.feat_size)
  33. return inner_loader
  34. def enable_only_head(self) -> None:
  35. super(SeparateModelClassifier, self).enable_only_head()
  36. self.separate_model.disable_update()
  37. self.separate_model.fc.enable_update()