12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
import io
import unittest
from contextlib import contextmanager
from typing import Iterator

from chainer.serializers import npz
from cvmodelz.models.pretrained import *
from cvmodelz.models import ModelFactory

import test_utils
class PretrainedTests(unittest.TestCase):
    """Serialization round-trip checks for models built by ``ModelFactory``.

    The two check methods are deliberately NOT named ``test_*``, so unittest
    does not collect them directly. They are passed as templates to
    ``test_utils.add_tests`` together with a model list (presumably
    producing one test per model key; see ``test_utils.add_tests``).
    """

    @contextmanager
    def mem_file(self) -> Iterator[io.BytesIO]:
        """Yield an empty in-memory binary buffer and close it on exit."""
        with io.BytesIO() as buffer:
            yield buffer

    def load_for_finetune(self, key):
        """Save one model to npz and reload it into another via ``load_for_finetune``.

        Both models are created with ``ModelFactory.new(key)`` defaults.
        After loading with ``n_classes=200``, the two must compare equal
        under ``test_utils.is_all_equal`` (non-strict mode).

        Args:
            key: model identifier passed to ``ModelFactory.new``.
        """
        model = ModelFactory.new(key)
        new_model = ModelFactory.new(key)
        # Precondition: without it, the final equality check could pass trivially.
        self.assertTrue(test_utils.is_any_different(model, new_model),
                        "freshly created models should not be identical")

        with self.mem_file() as f:
            npz.save_npz(f, model)
            f.seek(0)  # rewind so the loader reads from the start of the buffer
            new_model.load_for_finetune(f, n_classes=200)

        # NOTE(review): non-strict comparison, presumably tolerating the
        # classifier that finetuning re-sizes to 200 classes; confirm in test_utils.
        self.assertTrue(test_utils.is_all_equal(model, new_model),
                        "models differ after load_for_finetune")

    def load_for_inference(self, key):
        """Save a 200-class model to npz and reload it via ``load_for_inference``.

        The target model is created with ``n_classes=1000`` and loaded with
        ``n_classes=200``. Afterwards it must compare equal to the source
        model under ``test_utils.is_all_equal(..., strict=True)``.

        Args:
            key: model identifier passed to ``ModelFactory.new``.
        """
        model = ModelFactory.new(key, n_classes=200)
        new_model = ModelFactory.new(key, n_classes=1000)
        # Precondition: without it, the final equality check could pass trivially.
        self.assertTrue(test_utils.is_any_different(model, new_model),
                        "freshly created models should not be identical")

        with self.mem_file() as f:
            npz.save_npz(f, model)
            f.seek(0)  # rewind so the loader reads from the start of the buffer
            new_model.load_for_inference(f, n_classes=200)

        self.assertTrue(test_utils.is_all_equal(model, new_model, strict=True),
                        "models differ after load_for_inference")
# Keys passed to ModelFactory.get_models to select the models under test.
_MODEL_KEYS = ["chainercv2", "cvmodelz"]

# Register generated tests for each template method of PretrainedTests.
# A fresh model list is requested per registration, so the two calls do not
# share (or exhaust) the same sequence object.
test_utils.add_tests(PretrainedTests.load_for_finetune,
                     model_list=ModelFactory.get_models(_MODEL_KEYS))
test_utils.add_tests(PretrainedTests.load_for_inference,
                     model_list=ModelFactory.get_models(_MODEL_KEYS))

if __name__ == "__main__":
    unittest.main()
|