|
@@ -1,42 +1,37 @@
|
|
|
import unittest
|
|
|
-
|
|
|
-from contextlib import contextmanager
|
|
|
+import test_utils
|
|
|
|
|
|
from cvmodelz.models import ModelFactory
|
|
|
|
|
|
|
|
|
-@contextmanager
|
|
|
-def clear_print(msg):
|
|
|
- print(msg)
|
|
|
- yield
|
|
|
- print("\033[A{}\033[A".format(" "*len(msg)))
|
|
|
|
|
|
class FactoryTests(unittest.TestCase):
|
|
|
|
|
|
- def test_model_creation(self):
|
|
|
- for key in ModelFactory.get_all_models():
|
|
|
+ def model_creation(self, key):
|
|
|
+
|
|
|
+ model = ModelFactory.new(key)
|
|
|
+ self.assertIsNotNone(model)
|
|
|
|
|
|
- with clear_print(f"Creating {key}..."):
|
|
|
- model = ModelFactory.new(key)
|
|
|
+ def cv2model_load(self, key):
|
|
|
|
|
|
- self.assertIsNotNone(model)
|
|
|
+ model_rnd = ModelFactory.new(key, pretrained=False)
|
|
|
+ model_loaded1 = ModelFactory.new(key, pretrained=True)
|
|
|
+ model_loaded2 = ModelFactory.new(key, pretrained=True)
|
|
|
|
|
|
- def test_cv2model_load(self):
|
|
|
- for key in ModelFactory.get_models(["chainercv2"]):
|
|
|
- with clear_print(f"Loading default weights for {key}..."):
|
|
|
+ params_rnd = dict(model_rnd.namedparams())
|
|
|
+ params_loaded1 = dict(model_loaded1.namedparams())
|
|
|
+ params_loaded2 = dict(model_loaded2.namedparams())
|
|
|
|
|
|
- model_rnd = ModelFactory.new(key, pretrained=False)
|
|
|
- model_loaded1 = ModelFactory.new(key, pretrained=True)
|
|
|
- model_loaded2 = ModelFactory.new(key, pretrained=True)
|
|
|
+ for name, param in params_rnd.items():
|
|
|
+ loaded1 = params_loaded1[name]
|
|
|
+ loaded2 = params_loaded2[name]
|
|
|
|
|
|
- params_rnd = dict(model_rnd.namedparams())
|
|
|
- params_loaded1 = dict(model_loaded1.namedparams())
|
|
|
- params_loaded2 = dict(model_loaded2.namedparams())
|
|
|
+ self.assertTrue(( param.array != loaded1.array).any())
|
|
|
+ self.assertTrue(( param.array != loaded2.array).any())
|
|
|
+ self.assertTrue((loaded1.array == loaded2.array).all())
|
|
|
|
|
|
- for name, param in params_rnd.items():
|
|
|
- loaded1 = params_loaded1[name]
|
|
|
- loaded2 = params_loaded2[name]
|
|
|
+test_utils.add_tests(FactoryTests.model_creation,
|
|
|
+ model_list=ModelFactory.get_all_models())
|
|
|
|
|
|
- self.assertTrue(( param.array != loaded1.array).any())
|
|
|
- self.assertTrue(( param.array != loaded2.array).any())
|
|
|
- self.assertTrue((loaded1.array == loaded2.array).all())
|
|
|
+test_utils.add_tests(FactoryTests.cv2model_load,
|
|
|
+ model_list=ModelFactory.get_models(["chainercv2"]))
|