separate_model_classifier.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import abc
  2. import chainer
  3. from chainer import functions as F
  4. from typing import Callable
  5. from typing import Dict
  6. from cvmodelz.classifiers.base import Classifier
  7. from cvmodelz.models.base import BaseModel
  8. class SeparateModelClassifier(Classifier):
  9. """
  10. Abstract Classifier, that holds two separate models.
  11. The user has to define, how these models operate on the
  12. input data. Hence, the forward method is abstract!
  13. """
  14. @abc.abstractmethod
  15. def forward(self, *args, **kwargs) -> chainer.Variable:
  16. super().forward(*args, **kwargs)
  17. def setup(self, model: BaseModel) -> None:
  18. super().setup(model)
  19. self.separate_model = self.model.copy(mode="copy")
  20. def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False, **kwargs) -> None:
  21. for model in [self.model, self.separate_model]:
  22. model_loader = self.get_model_loader(finetune=finetune, model=model)
  23. kwargs["strict"] = kwargs.get("strict", True)
  24. model_loader(weights=weights_file, n_classes=n_classes, **kwargs)
  25. def enable_only_head(self) -> None:
  26. super().enable_only_head()
  27. self.separate_model.disable_update()
  28. self.separate_model.fc.enable_update()
  29. class MeanModelClassifier(SeparateModelClassifier):
  30. def evaluations(self, pred0: chainer.Variable, pred1: chainer.Variable, y: chainer.Variable) -> Dict[str, chainer.Variable]:
  31. accu0 = self.model.accuracy(pred0, y)
  32. accu1 = self.separate_model.accuracy(pred1, y)
  33. mean_pred = (F.softmax(pred0) + F.softmax(pred1)) / 2
  34. accu = self.model.accuracy(mean_pred, y)
  35. return dict(
  36. accu0=accu0,
  37. accu1=accu1,
  38. accuracy=accu,
  39. )
  40. def forward(self, X: chainer.Variable, y: chainer.Variable) -> chainer.Variable:
  41. pred0 = self.model(X, layer_name=self.layer_name)
  42. pred1 = self.separate_model(X, layer_name=self.layer_name)
  43. loss0, loss1 = self.loss(pred0, y), self.loss(pred1, y)
  44. loss = (loss0 + loss1) / 2
  45. evaluations = self.evaluations(pred0, pred1, y)
  46. self.report(
  47. loss0=loss0,
  48. loss1=loss1,
  49. loss=loss,
  50. **evaluations
  51. )
  52. return loss