|
@@ -1,10 +1,12 @@
|
|
|
import abc
|
|
|
import chainer
|
|
|
|
|
|
+from chainer import functions as F
|
|
|
from typing import Callable
|
|
|
+from typing import Dict
|
|
|
|
|
|
-from cvmodelz.models.base import BaseModel
|
|
|
from cvmodelz.classifiers.base import Classifier
|
|
|
+from cvmodelz.models.base import BaseModel
|
|
|
|
|
|
|
|
|
|
|
@@ -61,3 +63,41 @@ class SeparateModelClassifier(Classifier):
|
|
|
super(SeparateModelClassifier, self).enable_only_head()
|
|
|
self.separate_model.disable_update()
|
|
|
self.separate_model.fc.enable_update()
|
|
|
+
|
|
|
+
|
|
|
+class MeanModelClassifier(SeparateModelClassifier):
|
|
|
+
|
|
|
+ def evaluations(self, pred0: chainer.Variable, pred1: chainer.Variable, y: chainer.Variable) -> Dict[str, chainer.Variable]:
|
|
|
+ accu0 = self.model.accuracy(pred0, y)
|
|
|
+ accu1 = self.separate_model.accuracy(pred1, y)
|
|
|
+
|
|
|
+ mean_pred = (F.softmax(pred0) + F.softmax(pred1)) / 2
|
|
|
+
|
|
|
+ accu = self.model.accuracy(mean_pred, y)
|
|
|
+
|
|
|
+ return dict(
|
|
|
+ accu0=accu0,
|
|
|
+ accu1=accu1,
|
|
|
+ accuracy=accu,
|
|
|
+ )
|
|
|
+
|
|
|
+ def forward(self, X: chainer.Variable, y: chainer.Variable) -> chainer.Variable:
|
|
|
+
|
|
|
+ pred0 = self.model(X, layer_name=self.layer_name)
|
|
|
+ pred1 = self.separate_model(X, layer_name=self.layer_name)
|
|
|
+
|
|
|
+ loss0, loss1 = self.loss(pred0, y), self.loss(pred1, y)
|
|
|
+ loss = (loss0 + loss1) / 2
|
|
|
+
|
|
|
+ evaluations = self.evaluations(pred0, pred1, y)
|
|
|
+
|
|
|
+ self.report(
|
|
|
+ loss0=loss0,
|
|
|
+ loss1=loss1,
|
|
|
+ loss=loss,
|
|
|
+ **evaluations
|
|
|
+ )
|
|
|
+ return loss
|
|
|
+
|
|
|
+
|
|
|
+
|