pretrained_tests.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import io
  2. import numpy as np
  3. import unittest
  4. import test_utils
  5. from chainer.serializers import npz
  6. from contextlib import contextmanager
  7. from cvmodelz.models.pretrained import *
  8. from cvmodelz.models import ModelFactory
  9. class PretrainedTests(unittest.TestCase):
  10. @contextmanager
  11. def mem_file(self) -> io.BytesIO:
  12. yield io.BytesIO()
  13. def load_for_finetune(self, key):
  14. model = ModelFactory.new(key, n_classes=1000)
  15. new_model = ModelFactory.new(key, n_classes=1000)
  16. self.assertTrue(test_utils.is_any_different(model, new_model))
  17. with self.mem_file() as f:
  18. model.save(f)
  19. f.seek(0)
  20. new_model.load_for_finetune(f, n_classes=200, strict=True)
  21. self.assertTrue(test_utils.is_all_equal(model, new_model))
  22. def load_for_inference(self, key):
  23. model = ModelFactory.new(key, n_classes=200)
  24. new_model = ModelFactory.new(key, n_classes=1000)
  25. self.assertTrue(test_utils.is_any_different(model, new_model))
  26. with self.mem_file() as f:
  27. model.save(f)
  28. f.seek(0)
  29. new_model.load_for_inference(f, n_classes=200, strict=True)
  30. self.assertTrue(test_utils.is_all_equal(model, new_model, strict=True))
  31. test_utils.add_tests(PretrainedTests.load_for_finetune,
  32. model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))
  33. test_utils.add_tests(PretrainedTests.load_for_inference,
  34. model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))