|
@@ -3,6 +3,7 @@ import test_utils
|
|
import unittest
|
|
import unittest
|
|
|
|
|
|
from cvmodelz.classifiers import Classifier
|
|
from cvmodelz.classifiers import Classifier
|
|
|
|
+from cvmodelz.classifiers import MeanModelClassifier
|
|
from cvmodelz.models import ModelFactory
|
|
from cvmodelz.models import ModelFactory
|
|
|
|
|
|
class ClassifierCreationTests(unittest.TestCase):
|
|
class ClassifierCreationTests(unittest.TestCase):
|
|
@@ -28,10 +29,22 @@ class ClassifierCreationTests(unittest.TestCase):
|
|
self.assertEqual(loss.ndim, 0)
|
|
self.assertEqual(loss.ndim, 0)
|
|
self.assertEqual(loss.shape, ())
|
|
self.assertEqual(loss.shape, ())
|
|
|
|
|
|
|
|
+class MeanModelClassifierCreationTests(ClassifierCreationTests):
|
|
|
|
+
|
|
|
|
+ def new_clf(self, key, **kwargs):
|
|
|
|
+ model = ModelFactory.new(key, pretrained_model=None)
|
|
|
|
+ return model, MeanModelClassifier(model, **kwargs)
|
|
|
|
+
|
|
|
|
|
|
|
|
|
|
test_utils.add_tests(ClassifierCreationTests.creation,
|
|
test_utils.add_tests(ClassifierCreationTests.creation,
|
|
model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|
|
model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|
|
|
|
|
|
|
|
+test_utils.add_tests(MeanModelClassifierCreationTests.creation,
|
|
|
|
+ model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|
|
|
|
+
|
|
test_utils.add_tests(ClassifierCreationTests.loss_computation,
|
|
test_utils.add_tests(ClassifierCreationTests.loss_computation,
|
|
model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|
|
model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|
|
|
|
+
|
|
|
|
+test_utils.add_tests(MeanModelClassifierCreationTests.loss_computation,
|
|
|
|
+ model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|