# factory_tests.py (1.5 KB)

  1. import unittest
  2. import test_utils
  3. from cvmodelz.models import ModelFactory
  4. from cvmodelz.models.pretrained.base import PretrainedModelMixin
  5. from cvmodelz.models.wrapper import ModelWrapper
  6. class FactoryTests(unittest.TestCase):
  7. def cv2model_creation(self, key):
  8. model = ModelFactory.new(key)
  9. self.assertIsNotNone(model)
  10. self.assertIsInstance(model, ModelWrapper)
  11. def pretrained_model_creation(self, key):
  12. model = ModelFactory.new(key)
  13. self.assertIsNotNone(model)
  14. self.assertIsInstance(model, PretrainedModelMixin)
  15. def cv2model_load(self, key):
  16. model_rnd = ModelFactory.new(key, pretrained_model=None)
  17. model_loaded1 = ModelFactory.new(key, pretrained_model="auto")
  18. model_loaded2 = ModelFactory.new(key, pretrained_model="auto")
  19. params_rnd = dict(model_rnd.namedparams())
  20. params_loaded1 = dict(model_loaded1.namedparams())
  21. params_loaded2 = dict(model_loaded2.namedparams())
  22. for name, param in params_rnd.items():
  23. loaded1 = params_loaded1[name]
  24. loaded2 = params_loaded2[name]
  25. self.assertTrue(( param.array != loaded1.array).any())
  26. self.assertTrue(( param.array != loaded2.array).any())
  27. self.assertTrue((loaded1.array == loaded2.array).all())
  28. test_utils.add_tests(FactoryTests.cv2model_creation,
  29. model_list=ModelFactory.get_models(["chainercv2"]))
  30. test_utils.add_tests(FactoryTests.pretrained_model_creation,
  31. model_list=ModelFactory.get_models(["cvmodelz"]))
  32. test_utils.add_tests(FactoryTests.cv2model_load,
  33. model_list=ModelFactory.get_models(["chainercv2"]))