123456789101112131415161718192021222324252627282930313233343536373839404142 |
- import unittest
- from contextlib import contextmanager
- from cvmodelz.models import ModelFactory
@contextmanager
def clear_print(msg):
    """Show *msg* on its own line while the wrapped block runs, then erase it.

    On normal exit, an ANSI "cursor up" escape (ESC [ A) moves back to the
    message line, overwrites it with spaces of the same length, and moves up
    once more. The trailing newline from ``print`` then leaves the cursor at the
    start of the blanked line, so the next output reuses it.

    If the wrapped block raises, nothing is erased and the message stays
    visible. This shows which step was running when the failure happened.
    """
    cursor_up = "\033[A"
    print(msg)
    yield
    padding = " " * len(msg)
    print(cursor_up + padding + cursor_up)
class FactoryTests(unittest.TestCase):
    """Smoke tests for ``ModelFactory``: model construction and pretrained-weight loading."""

    def test_model_creation(self):
        """Every key from ``ModelFactory.get_all_models()`` can be instantiated."""
        for key in ModelFactory.get_all_models():
            # subTest keeps the loop going after a failure and reports the
            # offending model key instead of stopping at the first broken one.
            with self.subTest(model=key):
                with clear_print(f"Creating {key}..."):
                    model = ModelFactory.new(key)
                    self.assertIsNotNone(model)

    def test_cv2model_load(self):
        """Check that pretrained "chainercv2" weights load correctly and deterministically.

        For each model, three instances are built: one with ``pretrained=False``
        and two with ``pretrained=True``. The test then checks four things:

        * all three instances expose the same parameter names;
        * each parameter has the same shape across the instances;
        * each parameter of the non-pretrained instance differs in at least one
          element from both pretrained instances;
        * the two pretrained instances hold element-wise identical values.
        """
        for key in ModelFactory.get_models(["chainercv2"]):
            with self.subTest(model=key):
                with clear_print(f"Loading default weights for {key}..."):
                    model_rnd = ModelFactory.new(key, pretrained=False)
                    model_loaded1 = ModelFactory.new(key, pretrained=True)
                    model_loaded2 = ModelFactory.new(key, pretrained=True)

                params_rnd = dict(model_rnd.namedparams())
                params_loaded1 = dict(model_loaded1.namedparams())
                params_loaded2 = dict(model_loaded2.namedparams())

                # Compare name sets up front. Otherwise a missing parameter
                # surfaces as a bare KeyError (a test error, not a failure)
                # that does not say which names differ.
                self.assertEqual(set(params_rnd), set(params_loaded1))
                self.assertEqual(set(params_rnd), set(params_loaded2))

                for name, param in params_rnd.items():
                    rnd = param.array
                    loaded1 = params_loaded1[name].array
                    loaded2 = params_loaded2[name].array

                    # Check shapes first: comparing arrays of different shapes
                    # with != would broadcast or raise instead of failing clearly.
                    self.assertEqual(rnd.shape, loaded1.shape, msg=name)
                    self.assertEqual(rnd.shape, loaded2.shape, msg=name)

                    self.assertTrue(
                        (rnd != loaded1).any(),
                        msg=f"{name}: first pretrained copy equals the random init",
                    )
                    self.assertTrue(
                        (rnd != loaded2).any(),
                        msg=f"{name}: second pretrained copy equals the random init",
                    )
                    self.assertTrue(
                        (loaded1 == loaded2).all(),
                        msg=f"{name}: pretrained weights differ between two loads",
                    )
|