factory_tests.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. import unittest
  2. import test_utils
  3. from cvmodelz.models import ModelFactory
  class FactoryTests(unittest.TestCase):
      """Checks around ``ModelFactory.new``.

      Both methods take a model ``key`` in addition to ``self`` and do not
      use the ``test_`` prefix, so they are not picked up by name; they are
      handed to ``test_utils.add_tests`` at module level instead.
      """

      def model_creation(self, key):
          """Assert that ``ModelFactory.new(key)`` returns something other than ``None``."""
          self.assertIsNotNone(ModelFactory.new(key))

      def cv2model_load(self, key):
          """Compare a non-pretrained model against two pretrained instances.

          One model is created with ``pretrained=False`` and two with
          ``pretrained=True``. For every named parameter of the first model
          the method asserts that:

          * at least one entry differs from the same parameter in each of
            the pretrained models, and
          * the two pretrained models agree on every entry.
          """
          # Construction order is kept: non-pretrained first, then both
          # pretrained instances.
          random_model = ModelFactory.new(key, pretrained=False)
          first_pretrained = ModelFactory.new(key, pretrained=True)
          second_pretrained = ModelFactory.new(key, pretrained=True)

          random_params, first_params, second_params = (
              dict(model.namedparams())
              for model in (random_model, first_pretrained, second_pretrained)
          )

          for name, random_param in random_params.items():
              # Look up both counterparts before asserting, so a missing
              # name fails with KeyError ahead of any comparison.
              first = first_params[name]
              second = second_params[name]

              differs_from_first = (random_param.array != first.array).any()
              self.assertTrue(differs_from_first)

              differs_from_second = (random_param.array != second.array).any()
              self.assertTrue(differs_from_second)

              pretrained_match = (first.array == second.array).all()
              self.assertTrue(pretrained_match)
  # Hand the creation check to test_utils.add_tests together with every model
  # the factory reports. NOTE(review): add_tests presumably generates one
  # test per entry of model_list; confirm against test_utils.
  test_utils.add_tests(
      FactoryTests.model_creation,
      model_list=ModelFactory.get_all_models(),
  )

  # The pretrained-loading check only receives the models returned for the
  # "chainercv2" group.
  test_utils.add_tests(
      FactoryTests.cv2model_load,
      model_list=ModelFactory.get_models(["chainercv2"]),
  )