import unittest

import test_utils

from cvmodelz.models import ModelFactory


class FactoryTests(unittest.TestCase):
    """Tests for model construction and weight loading via ``ModelFactory``.

    The check methods below are not named ``test_*``, so unittest does not
    collect them on its own.  They are registered per model key through
    ``test_utils.add_tests`` at the bottom of this module (presumably one
    generated test method per key -- confirm in ``test_utils``).
    """

    def model_creation(self, key):
        """Assert that ``ModelFactory.new(key)`` returns a model object.

        Args:
            key: model identifier passed to ``ModelFactory.new``.
        """
        model = ModelFactory.new(key)
        self.assertIsNotNone(model)

    def cv2model_load(self, key):
        """Assert that ``pretrained=True`` really loads (deterministic) weights.

        Builds one model with ``pretrained=False`` and two with
        ``pretrained=True`` for the same ``key``, then for every named
        parameter checks that:

        * the non-pretrained value differs from each pretrained value in at
          least one element (loading changed the parameter), and
        * the two pretrained copies have the same shape and identical values
          (loading is reproducible).

        Args:
            key: model identifier passed to ``ModelFactory.new``.
        """
        model_rnd = ModelFactory.new(key, pretrained=False)
        model_loaded1 = ModelFactory.new(key, pretrained=True)
        model_loaded2 = ModelFactory.new(key, pretrained=True)

        params_rnd = dict(model_rnd.namedparams())
        params_loaded1 = dict(model_loaded1.namedparams())
        params_loaded2 = dict(model_loaded2.namedparams())

        # All three models must expose exactly the same parameter names.
        # Without this, extra parameters in the pretrained models would go
        # unchecked, and a missing one would only surface as a bare KeyError
        # inside the loop below.
        self.assertEqual(set(params_rnd), set(params_loaded1),
                         msg="%s: parameter names differ (1st pretrained model)" % key)
        self.assertEqual(set(params_rnd), set(params_loaded2),
                         msg="%s: parameter names differ (2nd pretrained model)" % key)

        for name, param in params_rnd.items():
            loaded1 = params_loaded1[name]
            loaded2 = params_loaded2[name]

            self.assertTrue(
                (param.array != loaded1.array).any(),
                msg="%s: parameter %r unchanged by loading (1st pretrained model)" % (key, name))
            self.assertTrue(
                (param.array != loaded2.array).any(),
                msg="%s: parameter %r unchanged by loading (2nd pretrained model)" % (key, name))

            # Compare shapes first so that broadcasting cannot make two
            # differently-shaped arrays look equal element-wise.
            self.assertEqual(
                loaded1.array.shape, loaded2.array.shape,
                msg="%s: parameter %r has different shapes in the two pretrained models" % (key, name))
            self.assertTrue(
                (loaded1.array == loaded2.array).all(),
                msg="%s: parameter %r differs between the two pretrained models" % (key, name))


test_utils.add_tests(FactoryTests.model_creation,
                     model_list=ModelFactory.get_all_models())

test_utils.add_tests(FactoryTests.cv2model_load,
                     model_list=ModelFactory.get_models(["chainercv2"]))


if __name__ == "__main__":
    unittest.main()