loading.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import io
  2. import test_utils
  3. import unittest
  4. from contextlib import contextmanager
  5. from cvmodelz.models import ModelFactory
  6. class ModelLoadingTests(unittest.TestCase):
  7. def cv2model_load_pretrained(self, key):
  8. model_rnd = ModelFactory.new(key, pretrained_model=None)
  9. model_loaded1 = ModelFactory.new(key, pretrained_model="auto")
  10. model_loaded2 = ModelFactory.new(key, pretrained_model="auto")
  11. params_rnd = dict(model_rnd.namedparams())
  12. params_loaded1 = dict(model_loaded1.namedparams())
  13. params_loaded2 = dict(model_loaded2.namedparams())
  14. for name, param in params_rnd.items():
  15. loaded1 = params_loaded1[name]
  16. loaded2 = params_loaded2[name]
  17. self.assertTrue(( param.array != loaded1.array).any())
  18. self.assertTrue(( param.array != loaded2.array).any())
  19. self.assertTrue((loaded1.array == loaded2.array).all())
  20. def load_for_finetune(self, key):
  21. model = ModelFactory.new(key, n_classes=1000)
  22. new_model = ModelFactory.new(key, n_classes=1000)
  23. self.assertTrue(*test_utils.is_any_different(model, new_model))
  24. with test_utils.memory_file() as f:
  25. model.save(f)
  26. new_model.load_for_finetune(f, n_classes=200, strict=True)
  27. self.assertTrue(*test_utils.is_all_equal(model, new_model))
  28. def load_for_inference(self, key):
  29. model = ModelFactory.new(key, n_classes=200)
  30. new_model = ModelFactory.new(key, n_classes=1000)
  31. self.assertTrue(*test_utils.is_any_different(model, new_model))
  32. with test_utils.memory_file() as f:
  33. model.save(f)
  34. new_model.load_for_inference(f, n_classes=200, strict=True)
  35. self.assertTrue(*test_utils.is_all_equal(model, new_model, strict=True))
  36. test_utils.add_tests(ModelLoadingTests.load_for_finetune,
  37. model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))
  38. test_utils.add_tests(ModelLoadingTests.load_for_inference,
  39. model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))
  40. test_utils.add_tests(ModelLoadingTests.cv2model_load_pretrained,
  41. model_list=ModelFactory.get_models(["chainercv2"]))