import io
import unittest

import test_utils

from chainer.serializers import npz
from contextlib import contextmanager
from typing import Iterator

# NOTE(review): wildcard import kept as in the original; presumably needed so
# the pretrained model classes are defined/registered before ModelFactory is
# queried below — confirm before removing.
from cvmodelz.models.pretrained import *
from cvmodelz.models import ModelFactory


# Model sources whose models the generated tests are run against.
_MODEL_SOURCES = ["chainercv2", "cvmodelz"]


class PretrainedTests(unittest.TestCase):
	"""Round-trip tests for loading serialized weights into models.

	``load_for_finetune`` and ``load_for_inference`` are not test methods by
	themselves (no ``test_`` prefix). They are passed to
	``test_utils.add_tests`` at module level, which presumably generates one
	test per model key in ``model_list`` — confirm in ``test_utils``.
	"""

	@contextmanager
	def mem_file(self) -> Iterator[io.BytesIO]:
		"""Yield an in-memory binary buffer that is closed when the block exits."""
		with io.BytesIO() as buffer:
			yield buffer

	def load_for_finetune(self, key):
		"""Save a model to memory and reload it via ``load_for_finetune``.

		Args:
			key: model key forwarded to ``ModelFactory.new``.
		"""
		model = ModelFactory.new(key)
		new_model = ModelFactory.new(key)
		# Sanity check that two freshly created instances differ, so the
		# equality check after loading is meaningful.
		self.assertTrue(test_utils.is_any_different(model, new_model))

		with self.mem_file() as f:
			npz.save_npz(f, model)
			# Rewind so the loader reads the buffer from its start.
			f.seek(0)
			new_model.load_for_finetune(f, n_classes=200)

		# Non-strict comparison; presumably tolerates the classifier that the
		# loader sets up for n_classes=200 — confirm in test_utils.is_all_equal.
		self.assertTrue(test_utils.is_all_equal(model, new_model))

	def load_for_inference(self, key):
		"""Save a 200-class model to memory and reload it via ``load_for_inference``.

		The target model is created with 1000 classes, and ``n_classes=200``
		is passed to the loader.

		Args:
			key: model key forwarded to ``ModelFactory.new``.
		"""
		model = ModelFactory.new(key, n_classes=200)
		new_model = ModelFactory.new(key, n_classes=1000)
		# Sanity check that the two instances differ before loading.
		self.assertTrue(test_utils.is_any_different(model, new_model))

		with self.mem_file() as f:
			npz.save_npz(f, model)
			# Rewind so the loader reads the buffer from its start.
			f.seek(0)
			new_model.load_for_inference(f, n_classes=200)

		# strict=True: presumably all parameters, classifier included, must
		# match — confirm in test_utils.is_all_equal.
		self.assertTrue(test_utils.is_all_equal(model, new_model, strict=True))


# Called separately for each registration in case get_models returns a
# one-shot iterable.
test_utils.add_tests(PretrainedTests.load_for_finetune,
	model_list=ModelFactory.get_models(_MODEL_SOURCES))

test_utils.add_tests(PretrainedTests.load_for_inference,
	model_list=ModelFactory.get_models(_MODEL_SOURCES))