1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
import io
import unittest
from contextlib import contextmanager
from typing import Iterator

import numpy as np
from chainer.serializers import npz

import test_utils
# NOTE: keep the star import before ModelFactory so that a same-named export
# from the star import cannot shadow the explicitly imported ModelFactory.
from cvmodelz.models.pretrained import *
from cvmodelz.models import ModelFactory
class PretrainedTests(unittest.TestCase):
    """Save/load round-trip checks for models built by ``ModelFactory``.

    ``load_for_finetune`` and ``load_for_inference`` take a model ``key`` and
    have no ``test_`` prefix, so unittest does not collect them directly. At
    module level they are handed to ``test_utils.add_tests`` together with a
    model list (presumably generating one test per key -- see ``test_utils``).
    """

    @contextmanager
    def mem_file(self) -> Iterator[io.BytesIO]:
        """Yield a fresh, empty in-memory binary buffer.

        The buffer is closed when the ``with`` block exits, releasing its
        contents instead of leaving that to garbage collection.
        """
        with io.BytesIO() as buf:
            yield buf

    def load_for_finetune(self, key):
        """Check that ``load_for_finetune`` restores the saved weights.

        Both models are built with 1000 classes (presumably so the checkpoint
        matches ``new_model`` before its head is resized); the load then
        requests ``n_classes=200``.

        Args:
            key: model identifier passed to ``ModelFactory.new``.
        """
        model = ModelFactory.new(key, n_classes=1000)
        new_model = ModelFactory.new(key, n_classes=1000)
        # The two freshly built models must differ up front; otherwise the
        # equality check after loading would prove nothing.
        self.assertTrue(test_utils.is_any_different(model, new_model))

        with self.mem_file() as f:
            model.save(f)
            f.seek(0)  # rewind so the loader reads from the beginning
            new_model.load_for_finetune(f, n_classes=200, strict=True)

        # Non-strict comparison (unlike load_for_inference below); presumably
        # it tolerates the classifier resized to 200 classes -- TODO confirm
        # against test_utils.is_all_equal.
        self.assertTrue(test_utils.is_all_equal(model, new_model))

    def load_for_inference(self, key):
        """Check that ``load_for_inference`` reproduces a saved 200-class model.

        The saved model has 200 classes while ``new_model`` is built with
        1000; after loading with ``n_classes=200`` the two must compare equal
        under ``strict=True``.

        Args:
            key: model identifier passed to ``ModelFactory.new``.
        """
        model = ModelFactory.new(key, n_classes=200)
        new_model = ModelFactory.new(key, n_classes=1000)
        # Guard against a vacuous pass: the models must start out different.
        self.assertTrue(test_utils.is_any_different(model, new_model))

        with self.mem_file() as f:
            model.save(f)
            f.seek(0)  # rewind so the loader reads from the beginning
            new_model.load_for_inference(f, n_classes=200, strict=True)

        self.assertTrue(test_utils.is_all_equal(model, new_model, strict=True))
# Register both round-trip checks for every model in the "chainercv2" and
# "cvmodelz" collections. The methods are looked up by name inside the loop and
# the model list is requested anew for each registration, mirroring two
# back-to-back add_tests calls.
for _loader_name in ("load_for_finetune", "load_for_inference"):
    test_utils.add_tests(
        getattr(PretrainedTests, _loader_name),
        model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]),
    )
del _loader_name
|