loading.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  import io
  import unittest
  from contextlib import contextmanager
  from typing import Iterator

  from cvmodelz.models import ModelFactory

  import test_utils
  class LoadingTests(unittest.TestCase):
      """Weight save/load round-trip checks for models built by ``ModelFactory``.

      The check methods take a model ``key`` and are deliberately not named
      ``test_*``, so unittest's default loader does not collect them on their
      own. Instead they are handed to ``test_utils.add_tests`` at module level
      together with a model list (presumably to generate one test per model —
      confirm in ``test_utils``).
      """

      @contextmanager
      def mem_file(self) -> Iterator[io.BytesIO]:
          """Yield a fresh, empty in-memory binary buffer.

          The buffer is closed when the ``with`` block exits. The return
          annotation is ``Iterator[...]`` because the undecorated function is a
          generator; ``contextmanager`` turns it into a context manager.
          """
          with io.BytesIO() as buf:
              yield buf

      def cv2model_load_pretrained(self, key):
          """Check ``pretrained_model="auto"`` against a random init for ``key``.

          Builds one model with ``pretrained_model=None`` and two with
          ``pretrained_model="auto"``. For every parameter name of the first
          model, both "auto" models must differ from it in at least one element
          and must be element-wise equal to each other (i.e. loading is
          deterministic).
          """
          model_rnd = ModelFactory.new(key, pretrained_model=None)
          model_loaded1 = ModelFactory.new(key, pretrained_model="auto")
          model_loaded2 = ModelFactory.new(key, pretrained_model="auto")

          params_rnd = dict(model_rnd.namedparams())
          params_loaded1 = dict(model_loaded1.namedparams())
          params_loaded2 = dict(model_loaded2.namedparams())

          for name, param in params_rnd.items():
              loaded1 = params_loaded1[name]
              loaded2 = params_loaded2[name]
              # msg=name: a bare "False is not true" does not say which
              # parameter failed the comparison.
              self.assertTrue((param.array != loaded1.array).any(), msg=name)
              self.assertTrue((param.array != loaded2.array).any(), msg=name)
              self.assertTrue((loaded1.array == loaded2.array).all(), msg=name)

      def load_for_finetune(self, key):
          """Round-trip weights through ``save`` and ``load_for_finetune``.

          Two independently created 1000-class models must start out different.
          After the first is saved and loaded into the second via
          ``load_for_finetune(..., n_classes=200, strict=True)``, the two must
          compare equal under ``test_utils.is_all_equal`` (default arguments).
          """
          model = ModelFactory.new(key, n_classes=1000)
          new_model = ModelFactory.new(key, n_classes=1000)
          self.assertTrue(test_utils.is_any_different(model, new_model),
                          msg="fresh models should differ before loading")

          with self.mem_file() as f:
              model.save(f)
              f.seek(0)  # rewind so the load reads what was just written
              # NOTE(review): n_classes=200 presumably replaces the classifier,
              # which the comparison below (called without strict=True) is
              # expected to tolerate — confirm in test_utils.
              new_model.load_for_finetune(f, n_classes=200, strict=True)

          self.assertTrue(test_utils.is_all_equal(model, new_model),
                          msg="weights differ after load_for_finetune")

      def load_for_inference(self, key):
          """Round-trip weights through ``save`` and ``load_for_inference``.

          A 200-class model is saved and loaded into a differently initialized
          1000-class model via ``load_for_inference(..., n_classes=200,
          strict=True)``. The two must then be equal under
          ``test_utils.is_all_equal(..., strict=True)``.
          """
          model = ModelFactory.new(key, n_classes=200)
          new_model = ModelFactory.new(key, n_classes=1000)
          self.assertTrue(test_utils.is_any_different(model, new_model),
                          msg="fresh models should differ before loading")

          with self.mem_file() as f:
              model.save(f)
              f.seek(0)  # rewind so the load reads what was just written
              new_model.load_for_inference(f, n_classes=200, strict=True)

          self.assertTrue(
              test_utils.is_all_equal(model, new_model, strict=True),
              msg="weights differ after load_for_inference")
  # Pair each model-parametrized check on LoadingTests with the backends whose
  # models it should run against, then hand both to test_utils.add_tests.
  # The method is looked up by name inside the loop so that each lookup,
  # get_models call and add_tests call happens in the same order as three
  # separate statements would.
  # NOTE(review): presumably add_tests generates one test per model in
  # model_list — confirm in test_utils.
  for _check_name, _backends in (
          ("load_for_finetune", ["chainercv2", "cvmodelz"]),
          ("load_for_inference", ["chainercv2", "cvmodelz"]),
          ("cv2model_load_pretrained", ["chainercv2"]),
  ):
      _check = getattr(LoadingTests, _check_name)
      test_utils.add_tests(_check,
                           model_list=ModelFactory.get_models(_backends))

  # Keep the loop variables out of the module namespace.
  del _check_name, _backends, _check