1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950 |
- import unittest
- import test_utils
- from cvmodelz.models import ModelFactory
- from cvmodelz.models.pretrained.base import PretrainedModelMixin
- from cvmodelz.models.wrapper import ModelWrapper
class FactoryTests(unittest.TestCase):
    """Checks for ``ModelFactory``.

    The methods of this class are templates, not tests in their own right:
    each takes a model ``key`` and is registered once per model key through
    ``test_utils.add_tests`` at the bottom of this module. That is why they do
    not carry the ``test_`` prefix unittest would otherwise collect.
    """

    def cv2model_creation(self, key):
        """Assert that ``ModelFactory.new(key)`` yields a ``ModelWrapper``."""
        model = ModelFactory.new(key)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, ModelWrapper)

    def pretrained_model_creation(self, key):
        """Assert that ``ModelFactory.new(key)`` yields a ``PretrainedModelMixin``."""
        model = ModelFactory.new(key)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, PretrainedModelMixin)

    def cv2model_load(self, key):
        """Assert that ``pretrained_model="auto"`` loads weights deterministically.

        Builds one model with ``pretrained_model=None`` and two with
        ``pretrained_model="auto"``, then checks, per named parameter, that:

        * each loaded model differs from the ``None`` model somewhere, and
        * the two loaded models are element-wise identical.
        """
        model_rnd = ModelFactory.new(key, pretrained_model=None)
        model_loaded1 = ModelFactory.new(key, pretrained_model="auto")
        model_loaded2 = ModelFactory.new(key, pretrained_model="auto")

        params_rnd = dict(model_rnd.namedparams())
        params_loaded1 = dict(model_loaded1.namedparams())
        params_loaded2 = dict(model_loaded2.namedparams())

        # All three models come from the same key, so they should expose the
        # same parameter names. Checking this first turns a missing parameter
        # into a readable test failure instead of a KeyError inside the loop.
        self.assertEqual(set(params_rnd), set(params_loaded1),
                         "parameter names differ between random and loaded model")
        self.assertEqual(set(params_rnd), set(params_loaded2),
                         "parameter names differ between random and loaded model")

        for name, param in params_rnd.items():
            # subTest reports every failing parameter by name instead of
            # stopping at the first one with an anonymous assertion.
            with self.subTest(param=name):
                loaded1 = params_loaded1[name]
                loaded2 = params_loaded2[name]
                self.assertTrue((param.array != loaded1.array).any(),
                                "loaded parameter equals the randomly initialized one")
                self.assertTrue((param.array != loaded2.array).any(),
                                "loaded parameter equals the randomly initialized one")
                self.assertTrue((loaded1.array == loaded2.array).all(),
                                "two loads of the same weights produced different values")
# Register each FactoryTests template once per model key. Each pair gives the
# template method and the model sources whose keys ModelFactory.get_models
# should enumerate for it.
for _template, _sources in (
    (FactoryTests.cv2model_creation, ["chainercv2"]),
    (FactoryTests.pretrained_model_creation, ["cvmodelz"]),
    (FactoryTests.cv2model_load, ["chainercv2"]),
):
    test_utils.add_tests(_template, model_list=ModelFactory.get_models(_sources))

# Drop the loop variables so the module namespace matches plain registration.
del _template, _sources
|