12345678910111213141516171819202122232425262728293031323334353637 |
- import unittest
- import test_utils
- from cvmodelz.models import ModelFactory
class FactoryTests(unittest.TestCase):
    """Tests for ``ModelFactory`` model construction and weight loading.

    The methods below take a model ``key`` argument and do not start with
    ``test``, so unittest's default loader does not collect them directly;
    they are meant to be registered once per model key via
    ``test_utils.add_tests``.
    """

    def model_creation(self, key):
        """Check that ``ModelFactory.new(key)`` returns a (non-None) model."""
        model = ModelFactory.new(key)
        self.assertIsNotNone(model)

    def cv2model_load(self, key):
        """Check that pretrained weights are really loaded for model ``key``.

        Builds one model with ``pretrained=False`` and two with
        ``pretrained=True``. For every named parameter, it requires that:

        * each pretrained copy differs from the non-pretrained one in at
          least one element, and
        * both pretrained copies are element-wise identical.
        """
        model_rnd = ModelFactory.new(key, pretrained=False)
        model_loaded1 = ModelFactory.new(key, pretrained=True)
        model_loaded2 = ModelFactory.new(key, pretrained=True)

        params_rnd = dict(model_rnd.namedparams())
        params_loaded1 = dict(model_loaded1.namedparams())
        params_loaded2 = dict(model_loaded2.namedparams())

        # Without any parameters the per-parameter loop below would pass
        # vacuously and hide a broken model definition.
        self.assertTrue(
            params_rnd, "model {!r} has no parameters".format(key))

        # All three instances must expose the same parameter names. Otherwise
        # a lookup below would raise KeyError (reported as an error rather
        # than a failure), or extra pretrained parameters would go unchecked.
        self.assertEqual(set(params_rnd), set(params_loaded1))
        self.assertEqual(set(params_rnd), set(params_loaded2))

        for name, param in params_rnd.items():
            # subTest names the offending parameter in the report.
            with self.subTest(param=name):
                loaded1 = params_loaded1[name]
                loaded2 = params_loaded2[name]
                self.assertTrue((param.array != loaded1.array).any())
                self.assertTrue((param.array != loaded2.array).any())
                self.assertTrue((loaded1.array == loaded2.array).all())
# Register a model-creation test for every model the factory knows about.
_all_models = ModelFactory.get_all_models()
test_utils.add_tests(FactoryTests.model_creation, model_list=_all_models)

# Register the pretrained-weight loading test only for the "chainercv2"
# model group.
_cv2_models = ModelFactory.get_models(["chainercv2"])
test_utils.add_tests(FactoryTests.cv2model_load, model_list=_cv2_models)
|