"""Weight-loading tests for models created through ``ModelFactory``.

Covers three scenarios:

* loading the "auto" pretrained weights of chainercv2 models,
* re-loading a saved model for fine-tuning with a different class count,
* re-loading a saved model for inference.

The test methods are not named ``test_*``. They take a model ``key`` and are
registered per model by ``test_utils.add_tests`` at the bottom of the module.
"""
import io
import test_utils
import unittest

from contextlib import contextmanager

from cvmodelz.models import ModelFactory


class ModelLoadingTests(unittest.TestCase):
	"""Parametrised (per model key) checks of the model weight-loading paths."""

	def cv2model_load_pretrained(self, key):
		"""Check that "auto" pretrained loading replaces random init and is deterministic.

		Builds one randomly initialised model and two models with
		``pretrained_model="auto"``. Then, for every named parameter:

		* both pretrained models must differ from the random model somewhere;
		* the two pretrained models must be element-wise identical.

		Args:
			key: model identifier passed to ``ModelFactory.new``.
		"""
		model_rnd = ModelFactory.new(key, pretrained_model=None)
		model_loaded1 = ModelFactory.new(key, pretrained_model="auto")
		model_loaded2 = ModelFactory.new(key, pretrained_model="auto")

		params_rnd = dict(model_rnd.namedparams())
		params_loaded1 = dict(model_loaded1.namedparams())
		params_loaded2 = dict(model_loaded2.namedparams())

		for name, param in params_rnd.items():
			loaded1 = params_loaded1[name]
			loaded2 = params_loaded2[name]

			# Include model key and parameter name in failure messages;
			# otherwise a failure inside this loop is impossible to locate.
			self.assertTrue(
				(param.array != loaded1.array).any(),
				msg="{}: parameter {!r} equals its random init after "
					"pretrained loading (first load)".format(key, name))
			self.assertTrue(
				(param.array != loaded2.array).any(),
				msg="{}: parameter {!r} equals its random init after "
					"pretrained loading (second load)".format(key, name))
			self.assertTrue(
				(loaded1.array == loaded2.array).all(),
				msg="{}: parameter {!r} differs between two pretrained "
					"loads".format(key, name))

	def load_for_finetune(self, key):
		"""Round-trip a 1000-class model through ``load_for_finetune`` with 200 classes.

		Two independently initialised 1000-class models must differ. After the
		first is saved to an in-memory file and loaded into the second via
		``load_for_finetune(..., n_classes=200, strict=True)``, the models must
		compare equal under ``test_utils.is_all_equal`` (non-strict mode).

		Args:
			key: model identifier passed to ``ModelFactory.new``.
		"""
		model = ModelFactory.new(key, n_classes=1000)
		new_model = ModelFactory.new(key, n_classes=1000)

		# NOTE(review): the test_utils helpers appear to return an
		# (ok, message) pair that is splatted into assertTrue -- confirm.
		self.assertTrue(*test_utils.is_any_different(model, new_model))

		with test_utils.memory_file() as f:
			model.save(f)
			# Rewind so the load reads what was just written.
			f.seek(0)
			new_model.load_for_finetune(f, n_classes=200, strict=True)

		# Non-strict comparison: presumably tolerates the classifier that
		# was resized to 200 classes -- confirm in test_utils.is_all_equal.
		self.assertTrue(*test_utils.is_all_equal(model, new_model))

	def load_for_inference(self, key):
		"""Round-trip a 200-class model into a 1000-class one via ``load_for_inference``.

		The saved 200-class model and a fresh 1000-class model must differ.
		After ``load_for_inference(..., n_classes=200, strict=True)``, the two
		models must compare equal under ``test_utils.is_all_equal`` with
		``strict=True``.

		Args:
			key: model identifier passed to ``ModelFactory.new``.
		"""
		model = ModelFactory.new(key, n_classes=200)
		new_model = ModelFactory.new(key, n_classes=1000)

		self.assertTrue(*test_utils.is_any_different(model, new_model))

		with test_utils.memory_file() as f:
			model.save(f)
			# Rewind so the load reads what was just written.
			f.seek(0)
			new_model.load_for_inference(f, n_classes=200, strict=True)

		self.assertTrue(*test_utils.is_all_equal(model, new_model, strict=True))


# Register the parametrised methods above for each model key. add_tests
# presumably generates one test_* method per key on ModelLoadingTests --
# confirm in test_utils. get_models is called once per registration on
# purpose, in case it returns a single-use iterator.
test_utils.add_tests(ModelLoadingTests.load_for_finetune,
	model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))

test_utils.add_tests(ModelLoadingTests.load_for_inference,
	model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))

test_utils.add_tests(ModelLoadingTests.cv2model_load_pretrained,
	model_list=ModelFactory.get_models(["chainercv2"]))