"""Smoke tests for ``cvmodelz.models.ModelFactory``.

Every model key known to the factory is instantiated. Every ``chainercv2``
model is additionally checked for loading its default pretrained weights.
"""
import unittest

from contextlib import contextmanager

from cvmodelz.models import ModelFactory


@contextmanager
def clear_print(msg):
    """Print ``msg`` as a progress line and erase it once the block succeeds.

    ``"\\033[A"`` is the ANSI "cursor up" escape sequence. After the block
    finishes, the cursor moves up onto the message line and overwrites it
    with blanks. It then moves up once more, so the trailing newline of the
    final ``print`` returns the cursor to the start of the erased line.

    If the block raises, the exception propagates before the erase step
    runs, so the message stays visible next to the failure. Output that does
    not interpret ANSI escapes (e.g. some CI logs) shows the raw codes.
    """
    print(msg)
    yield
    print("\033[A{}\033[A".format(" " * len(msg)))


class FactoryTests(unittest.TestCase):
    """Tests that exercise every model registered in ``ModelFactory``."""

    def test_model_creation(self):
        """Each model key can be instantiated with the factory defaults."""
        for key in ModelFactory.get_all_models():
            # subTest records each failing model separately instead of
            # aborting the loop at the first failure.
            with self.subTest(model=key), clear_print(f"Creating {key}..."):
                model = ModelFactory.new(key)
                self.assertIsNotNone(model)

    def test_cv2model_load(self):
        """Pretrained ``chainercv2`` weights load deterministically.

        For every parameter, the loaded weights must differ somewhere from a
        randomly initialised model, and two independent loads must agree
        exactly.
        """
        for key in ModelFactory.get_models(["chainercv2"]):
            msg = f"Loading default weights for {key}..."
            with self.subTest(model=key), clear_print(msg):
                self._check_pretrained_weights(key)

    def _check_pretrained_weights(self, key):
        """Compare a random-init model of ``key`` against two pretrained loads."""
        model_rnd = ModelFactory.new(key, pretrained=False)
        model_loaded1 = ModelFactory.new(key, pretrained=True)
        model_loaded2 = ModelFactory.new(key, pretrained=True)

        params_rnd = dict(model_rnd.namedparams())
        params_loaded1 = dict(model_loaded1.namedparams())
        params_loaded2 = dict(model_loaded2.namedparams())

        for name, param in params_rnd.items():
            with self.subTest(param=name):
                # A missing parameter should be a test failure that names
                # the parameter, not an uncaught KeyError.
                self.assertIn(name, params_loaded1)
                self.assertIn(name, params_loaded2)

                rnd = param.array
                loaded1 = params_loaded1[name].array
                loaded2 = params_loaded2[name].array

                # Pretrained weights must differ from the random init somewhere...
                self.assertTrue(
                    (rnd != loaded1).any(),
                    msg="first pretrained load equals the random initialisation",
                )
                self.assertTrue(
                    (rnd != loaded2).any(),
                    msg="second pretrained load equals the random initialisation",
                )
                # ...and loading the same weights twice must be deterministic.
                self.assertTrue(
                    (loaded1 == loaded2).all(),
                    msg="two pretrained loads produced different weights",
                )


if __name__ == "__main__":
    unittest.main()