|
@@ -2,21 +2,30 @@ import unittest
|
|
|
import test_utils
|
|
|
|
|
|
from cvmodelz.models import ModelFactory
|
|
|
-
|
|
|
+from cvmodelz.models.pretrained.base import PretrainedModelMixin
|
|
|
+from cvmodelz.models.wrapper import ModelWrapper
|
|
|
|
|
|
|
|
|
class FactoryTests(unittest.TestCase):
|
|
|
|
|
|
- def model_creation(self, key):
|
|
|
+ def cv2model_creation(self, key):
|
|
|
+
|
|
|
+ model = ModelFactory.new(key)
|
|
|
+ self.assertIsNotNone(model)
|
|
|
+
|
|
|
+ self.assertIsInstance(model, ModelWrapper)
|
|
|
|
|
|
+ def pretrained_model_creation(self, key):
|
|
|
model = ModelFactory.new(key)
|
|
|
self.assertIsNotNone(model)
|
|
|
|
|
|
+ self.assertIsInstance(model, PretrainedModelMixin)
|
|
|
+
|
|
|
def cv2model_load(self, key):
|
|
|
|
|
|
- model_rnd = ModelFactory.new(key, pretrained=False)
|
|
|
- model_loaded1 = ModelFactory.new(key, pretrained=True)
|
|
|
- model_loaded2 = ModelFactory.new(key, pretrained=True)
|
|
|
+ model_rnd = ModelFactory.new(key, pretrained_model=None)
|
|
|
+ model_loaded1 = ModelFactory.new(key, pretrained_model="auto")
|
|
|
+ model_loaded2 = ModelFactory.new(key, pretrained_model="auto")
|
|
|
|
|
|
params_rnd = dict(model_rnd.namedparams())
|
|
|
params_loaded1 = dict(model_loaded1.namedparams())
|
|
@@ -30,8 +39,12 @@ class FactoryTests(unittest.TestCase):
|
|
|
self.assertTrue(( param.array != loaded2.array).any())
|
|
|
self.assertTrue((loaded1.array == loaded2.array).all())
|
|
|
|
|
|
-test_utils.add_tests(FactoryTests.model_creation,
|
|
|
- model_list=ModelFactory.get_all_models())
|
|
|
+
|
|
|
+test_utils.add_tests(FactoryTests.cv2model_creation,
|
|
|
+ model_list=ModelFactory.get_models(["chainercv2"]))
|
|
|
+
|
|
|
+test_utils.add_tests(FactoryTests.pretrained_model_creation,
|
|
|
+ model_list=ModelFactory.get_models(["cvmodelz"]))
|
|
|
|
|
|
test_utils.add_tests(FactoryTests.cv2model_load,
|
|
|
model_list=ModelFactory.get_models(["chainercv2"]))
|