# loading.py
  1. import io
  2. import numpy as np
  3. import test_utils
  4. import unittest
  5. from cvmodelz.classifiers import Classifier
  6. from cvmodelz.classifiers import MeanModelClassifier
  7. from cvmodelz.models import ModelFactory
  8. class ClassifierLoadingTests(unittest.TestCase):
  9. def new_clf(self, key, **kwargs):
  10. model = ModelFactory.new(key, pretrained_model=None)
  11. return model, Classifier(model, **kwargs)
  12. def load_model(self, key, finetune):
  13. init_cls = 1000 if finetune else 200
  14. final_cls = 200
  15. model = ModelFactory.new(key, pretrained_model=None, n_classes=init_cls)
  16. _, clf = self.new_clf(key)
  17. self.assertTrue(*test_utils.is_any_different(model, clf.model))
  18. with test_utils.memory_file() as f:
  19. model.save(f)
  20. clf.load(f, n_classes=final_cls, finetune=finetune)
  21. """ if finetune is True, then the shapes of the classification
  22. layer differ, hence, strict should be False """
  23. self.assertTrue(*test_utils.is_all_equal(model, clf.model, strict=not finetune))
  24. return model, clf
  25. def load_for_finetune(self, key):
  26. self.load_model(key, finetune=True)
  27. def load_for_inference(self, key):
  28. self.load_model(key, finetune=False)
  29. class MeanModelClassifierLoadingTests(ClassifierLoadingTests):
  30. def new_clf(self, key, **kwargs):
  31. model = ModelFactory.new(key, pretrained_model=None)
  32. return model, MeanModelClassifier(model, **kwargs)
  33. def load_for_finetune(self, key):
  34. _, clf = self.load_model(key, finetune=True)
  35. self.assertTrue(*test_utils.is_all_equal(clf.model, clf.separate_model, strict=True, exclude_clf=True))
  36. def load_for_inference(self, key):
  37. _, clf = self.load_model(key, finetune=False)
  38. self.assertTrue(*test_utils.is_all_equal(clf.model, clf.separate_model, strict=True))
  39. test_utils.add_tests(ClassifierLoadingTests.load_for_finetune,
  40. model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
  41. test_utils.add_tests(MeanModelClassifierLoadingTests.load_for_finetune,
  42. model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
  43. test_utils.add_tests(ClassifierLoadingTests.load_for_inference,
  44. model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
  45. test_utils.add_tests(MeanModelClassifierLoadingTests.load_for_inference,
  46. model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))