|
@@ -4,6 +4,7 @@ import test_utils
|
|
import unittest
|
|
import unittest
|
|
|
|
|
|
from cvmodelz.classifiers import Classifier
|
|
from cvmodelz.classifiers import Classifier
|
|
|
|
+from cvmodelz.classifiers import MeanModelClassifier
|
|
from cvmodelz.models import ModelFactory
|
|
from cvmodelz.models import ModelFactory
|
|
|
|
|
|
class ClassifierLoadingTests(unittest.TestCase):
|
|
class ClassifierLoadingTests(unittest.TestCase):
|
|
@@ -28,6 +29,7 @@ class ClassifierLoadingTests(unittest.TestCase):
|
|
""" if finetune is True, then the shapes of the classification
|
|
""" if finetune is True, then the shapes of the classification
|
|
layer differ, hence, strict should be False """
|
|
layer differ, hence, strict should be False """
|
|
self.assertTrue(*test_utils.is_all_equal(model, clf.model, strict=not finetune))
|
|
self.assertTrue(*test_utils.is_all_equal(model, clf.model, strict=not finetune))
|
|
|
|
+ return model, clf
|
|
|
|
|
|
def load_for_finetune(self, key):
|
|
def load_for_finetune(self, key):
|
|
self.load_model(key, finetune=True)
|
|
self.load_model(key, finetune=True)
|
|
@@ -35,8 +37,29 @@ class ClassifierLoadingTests(unittest.TestCase):
|
|
def load_for_inference(self, key):
|
|
def load_for_inference(self, key):
|
|
self.load_model(key, finetune=False)
|
|
self.load_model(key, finetune=False)
|
|
|
|
|
|
|
|
+class MeanModelClassifierLoadingTests(ClassifierLoadingTests):
|
|
|
|
+
|
|
|
|
+ def new_clf(self, key, **kwargs):
|
|
|
|
+ model = ModelFactory.new(key, pretrained_model=None)
|
|
|
|
+ return model, MeanModelClassifier(model, **kwargs)
|
|
|
|
+
|
|
|
|
+ def load_for_finetune(self, key):
|
|
|
|
+ _, clf = self.load_model(key, finetune=True)
|
|
|
|
+ self.assertTrue(*test_utils.is_all_equal(clf.model, clf.separate_model, strict=True, exclude_clf=True))
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ def load_for_inference(self, key):
|
|
|
|
+ _, clf = self.load_model(key, finetune=False)
|
|
|
|
+ self.assertTrue(*test_utils.is_all_equal(clf.model, clf.separate_model, strict=True))
|
|
|
|
+
|
|
test_utils.add_tests(ClassifierLoadingTests.load_for_finetune,
|
|
test_utils.add_tests(ClassifierLoadingTests.load_for_finetune,
|
|
model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|
|
model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|
|
|
|
|
|
|
|
+test_utils.add_tests(MeanModelClassifierLoadingTests.load_for_finetune,
|
|
|
|
+ model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|
|
|
|
+
|
|
test_utils.add_tests(ClassifierLoadingTests.load_for_inference,
|
|
test_utils.add_tests(ClassifierLoadingTests.load_for_inference,
|
|
model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|
|
model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|
|
|
|
+
|
|
|
|
+test_utils.add_tests(MeanModelClassifierLoadingTests.load_for_inference,
|
|
|
|
+ model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|