|
@@ -0,0 +1,37 @@
|
|
|
+import numpy as np
|
|
|
+import test_utils
|
|
|
+import unittest
|
|
|
+
|
|
|
+from cvmodelz.classifiers import Classifier
|
|
|
+from cvmodelz.models import ModelFactory
|
|
|
+
|
|
|
+class ClassifierTests(unittest.TestCase):
|
|
|
+
|
|
|
+
|
|
|
+ def new_clf(self, key, **kwargs):
|
|
|
+ model = ModelFactory.new(key, pretrained_model=None)
|
|
|
+ return model, Classifier(model, **kwargs)
|
|
|
+
|
|
|
+ def creation(self, key):
|
|
|
+ model, clf = self.new_clf(key)
|
|
|
+ self.assertIs(clf.model, model)
|
|
|
+
|
|
|
+ def loss_computation(self, key):
|
|
|
+ model, clf = self.new_clf(key)
|
|
|
+
|
|
|
+ in_size = clf.model.meta.input_size
|
|
|
+ X = clf.xp.ones((4, 3, in_size, in_size), dtype=np.float32)
|
|
|
+ y = clf.xp.random.choice(clf.n_classes, size=4)
|
|
|
+
|
|
|
+ loss = clf(X, y)
|
|
|
+ self.assertIsNotNone(loss)
|
|
|
+ self.assertEqual(loss.ndim, 0)
|
|
|
+ self.assertEqual(loss.shape, ())
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+test_utils.add_tests(ClassifierTests.creation,
|
|
|
+ model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|
|
|
+
|
|
|
+test_utils.add_tests(ClassifierTests.loss_computation,
|
|
|
+ model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
|