factory_tests.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import unittest
  2. from contextlib import contextmanager
  3. from cvmodelz.models import ModelFactory
  4. @contextmanager
  5. def clear_print(msg):
  6. print(msg)
  7. yield
  8. print("\033[A{}\033[A".format(" "*len(msg)))
  9. class FactoryTests(unittest.TestCase):
  10. def test_model_creation(self):
  11. for key in ModelFactory.get_all_models():
  12. with clear_print(f"Creating {key}..."):
  13. model = ModelFactory.new(key)
  14. self.assertIsNotNone(model)
  15. def test_cv2model_load(self):
  16. for key in ModelFactory.get_models(["chainercv2"]):
  17. with clear_print(f"Loading default weights for {key}..."):
  18. model_rnd = ModelFactory.new(key, pretrained=False)
  19. model_loaded1 = ModelFactory.new(key, pretrained=True)
  20. model_loaded2 = ModelFactory.new(key, pretrained=True)
  21. params_rnd = dict(model_rnd.namedparams())
  22. params_loaded1 = dict(model_loaded1.namedparams())
  23. params_loaded2 = dict(model_loaded2.namedparams())
  24. for name, param in params_rnd.items():
  25. loaded1 = params_loaded1[name]
  26. loaded2 = params_loaded2[name]
  27. self.assertTrue(( param.array != loaded1.array).any())
  28. self.assertTrue(( param.array != loaded2.array).any())
  29. self.assertTrue((loaded1.array == loaded2.array).all())