123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- import io
- import test_utils
- import unittest
- from contextlib import contextmanager
- from cvmodelz.models import ModelFactory
class ModelLoadingTests(unittest.TestCase):
    """Weight save/load checks for models built by ``ModelFactory``.

    None of these methods is named ``test_*``, so unittest does not collect
    them directly. Each takes a model ``key`` and is passed, together with a
    model list, to ``test_utils.add_tests`` at module level -- presumably to
    generate one concrete test per model (TODO confirm in test_utils).
    """

    def cv2model_load_pretrained(self, key):
        """Check that ``pretrained_model="auto"`` loads fixed, non-random weights.

        A randomly initialised model (``pretrained_model=None``) is compared
        with two models created with ``pretrained_model="auto"``. Every
        parameter of each loaded model must differ from the random one in at
        least one element, and the two loaded models must be element-wise
        identical.

        Args:
            key: model identifier passed to ``ModelFactory.new``.
        """
        model_rnd = ModelFactory.new(key, pretrained_model=None)
        model_loaded1 = ModelFactory.new(key, pretrained_model="auto")
        model_loaded2 = ModelFactory.new(key, pretrained_model="auto")

        params_rnd = dict(model_rnd.namedparams())
        params_loaded1 = dict(model_loaded1.namedparams())
        params_loaded2 = dict(model_loaded2.namedparams())

        # Report parameters absent from a loaded model as a named test
        # failure instead of letting the lookups below raise a bare KeyError.
        for label, params_loaded in (("loaded1", params_loaded1),
                                     ("loaded2", params_loaded2)):
            missing = sorted(params_rnd.keys() - params_loaded.keys())
            self.assertFalse(
                missing,
                "{}: parameters missing from {}: {}".format(key, label, missing))

        for name, param in params_rnd.items():
            rnd = param.array
            loaded1 = params_loaded1[name].array
            loaded2 = params_loaded2[name].array

            # Messages identify the failing parameter; without them a failure
            # inside this loop only reports "False is not true".
            self.assertTrue(
                (rnd != loaded1).any(),
                "{}: parameter {!r} equals its random init (loaded1)".format(
                    key, name))
            self.assertTrue(
                (rnd != loaded2).any(),
                "{}: parameter {!r} equals its random init (loaded2)".format(
                    key, name))
            self.assertTrue(
                (loaded1 == loaded2).all(),
                "{}: parameter {!r} differs between two pretrained loads".format(
                    key, name))

    def load_for_finetune(self, key):
        """Round-trip weights through ``save`` / ``load_for_finetune``.

        Two independently initialised 1000-class models must initially differ.
        After saving the first one and loading it into the second with
        ``load_for_finetune(..., n_classes=200, strict=True)``, the two must
        compare equal via ``test_utils.is_all_equal``.

        Args:
            key: model identifier passed to ``ModelFactory.new``.
        """
        model = ModelFactory.new(key, n_classes=1000)
        new_model = ModelFactory.new(key, n_classes=1000)

        # The helpers' results are unpacked into assertTrue -- presumably a
        # (passed, message) pair; confirm in test_utils.
        self.assertTrue(*test_utils.is_any_different(model, new_model))

        with test_utils.memory_file() as f:
            model.save(f)
            new_model.load_for_finetune(f, n_classes=200, strict=True)

        # Called without strict=True, unlike load_for_inference below.
        # NOTE(review): presumably non-strict comparison tolerates the
        # classifier being resized to 200 classes -- confirm in test_utils.
        self.assertTrue(*test_utils.is_all_equal(model, new_model))

    def load_for_inference(self, key):
        """Round-trip weights through ``save`` / ``load_for_inference``.

        A 200-class model is saved and loaded into a fresh 1000-class model via
        ``load_for_inference(..., n_classes=200, strict=True)``. Afterwards the
        two models must be equal under ``is_all_equal(..., strict=True)``.

        Args:
            key: model identifier passed to ``ModelFactory.new``.
        """
        model = ModelFactory.new(key, n_classes=200)
        new_model = ModelFactory.new(key, n_classes=1000)

        self.assertTrue(*test_utils.is_any_different(model, new_model))

        with test_utils.memory_file() as f:
            model.save(f)
            new_model.load_for_inference(f, n_classes=200, strict=True)

        self.assertTrue(*test_utils.is_all_equal(model, new_model, strict=True))
# Register each test template against the model backends it should run on.
# get_models is still called once per registration, in the original order.
for _template, _backends in (
    (ModelLoadingTests.load_for_finetune, ["chainercv2", "cvmodelz"]),
    (ModelLoadingTests.load_for_inference, ["chainercv2", "cvmodelz"]),
    (ModelLoadingTests.cv2model_load_pretrained, ["chainercv2"]),
):
    test_utils.add_tests(_template,
                         model_list=ModelFactory.get_models(_backends))

# Keep the module namespace as it was: no leftover loop variables.
del _template, _backends
|