classifier.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import abc
  2. import chainer
  3. import chainer.functions as F
  4. import chainer.links as L
  5. import logging
  6. from chainer_addons.models.classifier import Classifier as C
  7. class Classifier(C):
  8. def __init__(self, *args, **kwargs):
  9. super(Classifier, self).__init__(*args, **kwargs)
  10. assert hasattr(self, "model"), \
  11. "This classifiert has no \"model\" attribute!"
  12. @property
  13. def feat_size(self):
  14. if hasattr(self.model.pool, "output_dim") and self.model.pool.output_dim is not None:
  15. return self.model.pool.output_dim
  16. return self.model.meta.feature_size
  17. @property
  18. def output_size(self):
  19. return self.feat_size
  20. class SeparateModelClassifier(Classifier):
  21. """Classifier, that holds two separate models"""
  22. def __init__(self, *args, **kwargs):
  23. super(SeparateModelClassifier, self).__init__(*args, **kwargs)
  24. with self.init_scope():
  25. self.init_separate_model()
  26. @abc.abstractmethod
  27. def __call__(self, *args, **kwargs):
  28. super(SeparateModelClassifier, self).__call__(*args, **kwargs)
  29. def init_separate_model(self):
  30. if hasattr(self, "separate_model"):
  31. logging.warn("Global Model already initialized! Skipping further execution!")
  32. return
  33. self.separate_model = self.model.copy(mode="copy")
  34. def loader(self, model_loader):
  35. def inner(n_classes, feat_size):
  36. # use the given feature size here
  37. model_loader(n_classes=n_classes, feat_size=feat_size)
  38. # use the given feature size first ...
  39. self.separate_model.reinitialize_clf(
  40. n_classes=n_classes,
  41. feat_size=feat_size)
  42. # then copy model params ...
  43. self.separate_model.copyparams(self.model)
  44. # now use the default feature size to re-init the classifier
  45. self.separate_model.reinitialize_clf(
  46. n_classes=n_classes,
  47. feat_size=self.feat_size)
  48. return inner