"""Tests for models created through ``cvmodelz.models.ModelFactory``.

The checks are plain (non-``test_``-prefixed) methods that take a model
``key``. They are handed to ``test_utils.add_tests`` at the bottom of this
module together with a list from ``ModelFactory.get_models`` — presumably
generating one concrete test per model key (TODO confirm in test_utils).
"""
import unittest

import test_utils

from cvmodelz.models import ModelFactory
from cvmodelz.models.pretrained.base import PretrainedModelMixin
from cvmodelz.models.wrapper import ModelWrapper


class FactoryTests(unittest.TestCase):
    """Key-parametrised checks for models built by ``ModelFactory.new``."""

    def cv2model_creation(self, key):
        """Assert that ``ModelFactory.new(key)`` yields a ``ModelWrapper``."""
        model = ModelFactory.new(key)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, ModelWrapper)

    def pretrained_model_creation(self, key):
        """Assert that ``ModelFactory.new(key)`` yields a ``PretrainedModelMixin``."""
        model = ModelFactory.new(key)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, PretrainedModelMixin)

    def cv2model_load(self, key):
        """Assert that ``pretrained_model="auto"`` really loads weights.

        One model is built with ``pretrained_model=None`` and two with
        ``pretrained_model="auto"``. For every named parameter of the first
        model:

        * each loaded model's array must differ from it in at least one
          element, and
        * the two loaded models' arrays must be element-wise equal.
        """
        model_rnd = ModelFactory.new(key, pretrained_model=None)
        model_loaded1 = ModelFactory.new(key, pretrained_model="auto")
        model_loaded2 = ModelFactory.new(key, pretrained_model="auto")

        params_rnd = dict(model_rnd.namedparams())
        params_loaded1 = dict(model_loaded1.namedparams())
        params_loaded2 = dict(model_loaded2.namedparams())

        for name, param in params_rnd.items():
            # subTest reports every offending parameter instead of stopping
            # at the first mismatch.
            with self.subTest(param=name):
                # A missing parameter becomes a failure naming the parameter
                # rather than an unexplained KeyError.
                self.assertIn(name, params_loaded1,
                    msg="{}: missing in first loaded model".format(name))
                self.assertIn(name, params_loaded2,
                    msg="{}: missing in second loaded model".format(name))

                loaded1 = params_loaded1[name]
                loaded2 = params_loaded2[name]

                self.assertTrue((param.array != loaded1.array).any(),
                    msg="{}: first loaded model equals the non-pretrained model".format(name))
                self.assertTrue((param.array != loaded2.array).any(),
                    msg="{}: second loaded model equals the non-pretrained model".format(name))
                self.assertTrue((loaded1.array == loaded2.array).all(),
                    msg="{}: the two loaded models differ".format(name))


test_utils.add_tests(FactoryTests.cv2model_creation,
    model_list=ModelFactory.get_models(["chainercv2"]))

test_utils.add_tests(FactoryTests.pretrained_model_creation,
    model_list=ModelFactory.get_models(["cvmodelz"]))

test_utils.add_tests(FactoryTests.cv2model_load,
    model_list=ModelFactory.get_models(["chainercv2"]))