|
@@ -0,0 +1,63 @@
|
|
|
+import abc
|
|
|
+import chainer
|
|
|
+
|
|
|
+from chainer import functions as F
|
|
|
+from typing import Dict
|
|
|
+from typing import Callable
|
|
|
+
|
|
|
+from cvmodelz.models.base import BaseModel
|
|
|
+
|
|
|
+class Classifier(chainer.Chain):
|
|
|
+
|
|
|
+ def __init__(self, model: BaseModel, *,
|
|
|
+ layer_name: str = None,
|
|
|
+ loss_func: Callable = F.softmax_cross_entropy,
|
|
|
+ only_head: bool = False,
|
|
|
+ ):
|
|
|
+ super(BaseClassifier, self).__init__()
|
|
|
+ self.layer_name = layer_name or model.meta.clf_layer_name
|
|
|
+ self.loss_func = loss_func
|
|
|
+
|
|
|
+ with self.init_scope():
|
|
|
+ self.setup(model)
|
|
|
+
|
|
|
+ if only_head:
|
|
|
+ self.enable_only_head()
|
|
|
+
|
|
|
+ def setup(self, model: BaseModel) -> None:
|
|
|
+ self.model = model
|
|
|
+
|
|
|
+ def report(self, **values) -> None:
|
|
|
+ chainer.report(values, self)
|
|
|
+
|
|
|
+ def enable_only_head(self) -> None:
|
|
|
+ self.model.disable_update()
|
|
|
+ self.model.clf_layer.enable_update()
|
|
|
+
|
|
|
+ def loader(self, model_loader: Callable) -> Callable:
|
|
|
+ return model_loader
|
|
|
+
|
|
|
+ @property
|
|
|
+ def feat_size(self) -> int:
|
|
|
+ return self.model.meta.feature_size
|
|
|
+
|
|
|
+ @property
|
|
|
+ def output_size(self) -> int:
|
|
|
+ return self.feat_size
|
|
|
+
|
|
|
+ def loss(self, pred: chainer.Variable, y: chainer.Variable) -> chainer.Variable:
|
|
|
+ return self.model.loss(pred, y, loss_func=self.loss_func)
|
|
|
+
|
|
|
+ def evaluations(self, pred: chainer.Variable, y: chainer.Variable) -> Dict[str, chainer.Variable]:
|
|
|
+ return dict(accuracy=self.model.accuracy(pred, y))
|
|
|
+
|
|
|
+ def forward(self, X: chainer.Variable, y: chainer.Variable) -> chainer.Variable:
|
|
|
+ pred = self.model(X, layer_name=self.layer_name)
|
|
|
+
|
|
|
+ loss = self.loss(pred, y)
|
|
|
+ evaluations = self.evaluations(pred, y)
|
|
|
+
|
|
|
+ self.report(loss=loss, **evaluations)
|
|
|
+ return loss
|
|
|
+
|
|
|
+
|