pretrained_tests.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
import io
import unittest
from contextlib import contextmanager
from typing import Iterator

from chainer.serializers import npz

import test_utils
from cvmodelz.models.pretrained import *
from cvmodelz.models import ModelFactory
  8. class PretrainedTests(unittest.TestCase):
  9. @contextmanager
  10. def mem_file(self) -> io.BytesIO:
  11. yield io.BytesIO()
  12. def load_for_finetune(self, key):
  13. model = ModelFactory.new(key)
  14. new_model = ModelFactory.new(key)
  15. self.assertTrue(test_utils.is_any_different(model, new_model))
  16. with self.mem_file() as f:
  17. npz.save_npz(f, model)
  18. f.seek(0)
  19. new_model.load_for_finetune(f, n_classes=200)
  20. self.assertTrue(test_utils.is_all_equal(model, new_model))
  21. def load_for_inference(self, key):
  22. model = ModelFactory.new(key, n_classes=200)
  23. new_model = ModelFactory.new(key, n_classes=1000)
  24. self.assertTrue(test_utils.is_any_different(model, new_model))
  25. with self.mem_file() as f:
  26. npz.save_npz(f, model)
  27. f.seek(0)
  28. new_model.load_for_inference(f, n_classes=200)
  29. self.assertTrue(test_utils.is_all_equal(model, new_model, strict=True))
  30. test_utils.add_tests(PretrainedTests.load_for_finetune,
  31. model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))
  32. test_utils.add_tests(PretrainedTests.load_for_inference,
  33. model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))