Bläddra i källkod

added some classifier implementations

Dimitri Korsch 4 år sedan
förälder
incheckning
bb60371760

+ 0 - 0
cvmodelz/classifiers/__init__.py


+ 63 - 0
cvmodelz/classifiers/base.py

@@ -0,0 +1,63 @@
+import abc
+import chainer
+
+from chainer import functions as F
+from typing import Dict
+from typing import Callable
+
+from cvmodelz.models.base import BaseModel
+
+class Classifier(chainer.Chain):
+
+	def __init__(self, model: BaseModel, *,
+		layer_name: str = None,
+		loss_func: Callable = F.softmax_cross_entropy,
+		only_head: bool = False,
+		):
+		super(BaseClassifier, self).__init__()
+		self.layer_name = layer_name or model.meta.clf_layer_name
+		self.loss_func = loss_func
+
+		with self.init_scope():
+			self.setup(model)
+
+		if only_head:
+			self.enable_only_head()
+
+	def setup(self, model: BaseModel) -> None:
+		self.model = model
+
+	def report(self, **values) -> None:
+		chainer.report(values, self)
+
+	def enable_only_head(self) -> None:
+		self.model.disable_update()
+		self.model.clf_layer.enable_update()
+
+	def loader(self, model_loader: Callable) -> Callable:
+		return model_loader
+
+	@property
+	def feat_size(self) -> int:
+		return self.model.meta.feature_size
+
+	@property
+	def output_size(self) -> int:
+		return self.feat_size
+
+	def loss(self, pred: chainer.Variable, y: chainer.Variable) -> chainer.Variable:
+		return self.model.loss(pred, y, loss_func=self.loss_func)
+
+	def evaluations(self, pred: chainer.Variable, y: chainer.Variable) -> Dict[str, chainer.Variable]:
+		return dict(accuracy=self.model.accuracy(pred, y))
+
+	def forward(self, X: chainer.Variable, y: chainer.Variable) -> chainer.Variable:
+		pred = self.model(X, layer_name=self.layer_name)
+
+		loss = self.loss(pred, y)
+		evaluations = self.evaluations(pred, y)
+
+		self.report(loss=loss, **evaluations)
+		return loss
+
+

+ 52 - 0
cvmodelz/classifiers/separate_model_classifier.py

@@ -0,0 +1,52 @@
+import abc
+import chainer
+
+from typing import Callable
+
+from cvmodelz.models.base import BaseModel
+from cvmodelz.classifiers.base import Classifier
+
+
+
+class SeparateModelClassifier(Classifier):
+	"""
+		Abstract Classifier, that holds two separate models.
+		The user has to define, how these models operate on the
+		input data. Hence, the forward method is abstract!
+	"""
+
+	@abc.abstractmethod
+	def forward(self, *args, **kwargs) -> chainer.Variable:
+		super(SeparateModelClassifier, self).forward(*args, **kwargs)
+
+	def setup(self, model: BaseModel) -> None:
+		super(SeparateModelClassifier, self).setup(model)
+
+		self.separate_model = self.model.copy(mode="copy")
+
+	def loader(self, model_loader: Callable) -> Callable:
+		super_loader = super(SeparateModelClassifier).loader(model_loader)
+
+		def inner_loader(n_classes: int, feat_size: int) -> None:
+			# use the given feature size here
+			super_loader(n_classes=n_classes, feat_size=feat_size)
+
+			# use the given feature size first ...
+			self.separate_model.reinitialize_clf(
+				n_classes=n_classes,
+				feat_size=feat_size)
+
+			# then copy model params ...
+			self.separate_model.copyparams(self.model)
+
+			# now use the default feature size to re-init the classifier
+			self.separate_model.reinitialize_clf(
+				n_classes=n_classes,
+				feat_size=self.feat_size)
+
+		return inner_loader
+
+	def enable_only_head(self) -> None:
+		super(SeparateModelClassifier, self).enable_only_head()
+		self.separate_model.disable_update()
+		self.separate_model.fc.enable_update()

+ 1 - 1
cvmodelz/models/base.py

@@ -10,7 +10,7 @@ from typing import Callable
 from cvmodelz import utils
 from cvmodelz import utils
 from cvmodelz.models.meta_info import ModelInfo
 from cvmodelz.models.meta_info import ModelInfo
 
 
-class BaseModel(abc.ABC):
+class BaseModel(abc.ABC, chainer.Chain):
 
 
 	def __init__(self, pooling: Callable = PoolingType.G_AVG.value(), input_size=None, *args, **kwargs):
 	def __init__(self, pooling: Callable = PoolingType.G_AVG.value(), input_size=None, *args, **kwargs):
 		super(BaseModel, self).__init__(*args, **kwargs)
 		super(BaseModel, self).__init__(*args, **kwargs)

+ 1 - 1
cvmodelz/models/wrapper.py

@@ -8,7 +8,7 @@ from cvmodelz.models.meta_info import ModelInfo
 
 
 
 
 
 
-class ModelWrapper(BaseModel, chainer.Chain):
+class ModelWrapper(BaseModel):
 	"""
 	"""
 		This class is designed to wrap around chainercv2 models
 		This class is designed to wrap around chainercv2 models
 		and provide the loading API of the BaseModel class.
 		and provide the loading API of the BaseModel class.