Переглянути джерело

added some classifier implementations

Dimitri Korsch 4 роки тому
батько
коміт
bb60371760

+ 0 - 0
cvmodelz/classifiers/__init__.py


+ 63 - 0
cvmodelz/classifiers/base.py

@@ -0,0 +1,63 @@
+import abc
+import chainer
+
+from chainer import functions as F
+from typing import Dict
+from typing import Callable
+
+from cvmodelz.models.base import BaseModel
+
+class Classifier(chainer.Chain):
+
+	def __init__(self, model: BaseModel, *,
+		layer_name: str = None,
+		loss_func: Callable = F.softmax_cross_entropy,
+		only_head: bool = False,
+		):
+		super(BaseClassifier, self).__init__()
+		self.layer_name = layer_name or model.meta.clf_layer_name
+		self.loss_func = loss_func
+
+		with self.init_scope():
+			self.setup(model)
+
+		if only_head:
+			self.enable_only_head()
+
+	def setup(self, model: BaseModel) -> None:
+		self.model = model
+
+	def report(self, **values) -> None:
+		chainer.report(values, self)
+
+	def enable_only_head(self) -> None:
+		self.model.disable_update()
+		self.model.clf_layer.enable_update()
+
+	def loader(self, model_loader: Callable) -> Callable:
+		return model_loader
+
+	@property
+	def feat_size(self) -> int:
+		return self.model.meta.feature_size
+
+	@property
+	def output_size(self) -> int:
+		return self.feat_size
+
+	def loss(self, pred: chainer.Variable, y: chainer.Variable) -> chainer.Variable:
+		return self.model.loss(pred, y, loss_func=self.loss_func)
+
+	def evaluations(self, pred: chainer.Variable, y: chainer.Variable) -> Dict[str, chainer.Variable]:
+		return dict(accuracy=self.model.accuracy(pred, y))
+
+	def forward(self, X: chainer.Variable, y: chainer.Variable) -> chainer.Variable:
+		pred = self.model(X, layer_name=self.layer_name)
+
+		loss = self.loss(pred, y)
+		evaluations = self.evaluations(pred, y)
+
+		self.report(loss=loss, **evaluations)
+		return loss
+
+

+ 52 - 0
cvmodelz/classifiers/separate_model_classifier.py

@@ -0,0 +1,52 @@
+import abc
+import chainer
+
+from typing import Callable
+
+from cvmodelz.models.base import BaseModel
+from cvmodelz.classifiers.base import Classifier
+
+
+
+class SeparateModelClassifier(Classifier):
+	"""
+		Abstract Classifier, that holds two separate models.
+		The user has to define, how these models operate on the
+		input data. Hence, the forward method is abstract!
+	"""
+
+	@abc.abstractmethod
+	def forward(self, *args, **kwargs) -> chainer.Variable:
+		super(SeparateModelClassifier, self).forward(*args, **kwargs)
+
+	def setup(self, model: BaseModel) -> None:
+		super(SeparateModelClassifier, self).setup(model)
+
+		self.separate_model = self.model.copy(mode="copy")
+
+	def loader(self, model_loader: Callable) -> Callable:
+		super_loader = super(SeparateModelClassifier).loader(model_loader)
+
+		def inner_loader(n_classes: int, feat_size: int) -> None:
+			# use the given feature size here
+			super_loader(n_classes=n_classes, feat_size=feat_size)
+
+			# use the given feature size first ...
+			self.separate_model.reinitialize_clf(
+				n_classes=n_classes,
+				feat_size=feat_size)
+
+			# then copy model params ...
+			self.separate_model.copyparams(self.model)
+
+			# now use the default feature size to re-init the classifier
+			self.separate_model.reinitialize_clf(
+				n_classes=n_classes,
+				feat_size=self.feat_size)
+
+		return inner_loader
+
+	def enable_only_head(self) -> None:
+		super(SeparateModelClassifier, self).enable_only_head()
+		self.separate_model.disable_update()
+		self.separate_model.fc.enable_update()

+ 1 - 1
cvmodelz/models/base.py

@@ -10,7 +10,7 @@ from typing import Callable
 from cvmodelz import utils
 from cvmodelz.models.meta_info import ModelInfo
 
-class BaseModel(abc.ABC):
+class BaseModel(abc.ABC, chainer.Chain):
 
 	def __init__(self, pooling: Callable = PoolingType.G_AVG.value(), input_size=None, *args, **kwargs):
 		super(BaseModel, self).__init__(*args, **kwargs)

+ 1 - 1
cvmodelz/models/wrapper.py

@@ -8,7 +8,7 @@ from cvmodelz.models.meta_info import ModelInfo
 
 
 
-class ModelWrapper(BaseModel, chainer.Chain):
+class ModelWrapper(BaseModel):
 	"""
 		This class is designed to wrap around chainercv2 models
 		and provide the loading API of the BaseModel class.