loading.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import io
  2. import numpy as np
  3. import test_utils
  4. import unittest
  5. from cvmodelz.classifiers import Classifier
  6. from cvmodelz.models import ModelFactory
  class ClassifierLoadingTests(unittest.TestCase):
      """Check that a ``Classifier`` can load weights saved from a plain model.

      The ``load_for_*`` methods do not start with ``test``, so unittest's
      default loader does not collect them by itself. They are registered at
      module level through ``test_utils.add_tests`` together with a list of
      model keys (presumably producing one generated test per key -- see
      ``test_utils``).
      """

      def new_clf(self, key, **kwargs):
          """Create a fresh model (``pretrained_model=None``) and wrap it.

          ``n_classes`` is not passed to ``ModelFactory.new`` here, so the
          factory's own default applies.

          Args:
              key: model identifier passed to ``ModelFactory.new``.
              **kwargs: forwarded unchanged to the ``Classifier`` constructor.

          Returns:
              A ``(model, classifier)`` tuple, where ``classifier`` wraps
              ``model``.
          """
          model = ModelFactory.new(key, pretrained_model=None)
          return model, Classifier(model, **kwargs)

      def load_model(self, key, finetune):
          """Save a model to a memory file and load it into a new Classifier.

          Args:
              key: model identifier passed to ``ModelFactory.new``.
              finetune: if True, the saved model has 1000 classes and is loaded
                  with ``finetune=True`` into a 200-class setup. If False, the
                  saved model already has 200 classes.
          """
          init_cls = 1000 if finetune else 200
          final_cls = 200

          model = ModelFactory.new(key, pretrained_model=None, n_classes=init_cls)
          _, clf = self.new_clf(key)

          # Precondition: the freshly created classifier must differ from the
          # source model. Otherwise the equality check after loading would
          # pass trivially. The helper's result is unpacked into assertTrue's
          # (expr, msg) arguments.
          self.assertTrue(*test_utils.is_any_different(model, clf.model))

          with test_utils.memory_file() as f:
              model.save(f)
              clf.load(f, n_classes=final_cls, finetune=finetune)

          # When fine-tuning, the classification layers have different shapes
          # (init_cls vs. final_cls outputs), so the strict comparison is
          # disabled in that case.
          self.assertTrue(
              *test_utils.is_all_equal(model, clf.model, strict=not finetune))

      def load_for_finetune(self, key):
          """Loading scenario: 1000-class weights fine-tuned to 200 classes."""
          self.load_model(key, finetune=True)

      def load_for_inference(self, key):
          """Loading scenario: 200-class weights loaded as-is (no fine-tuning)."""
          self.load_model(key, finetune=False)
  # Register both loading scenarios against the models returned for the
  # "cvmodelz" and "chainercv2" groups. get_models() is queried separately
  # for each scenario, in the same order as the individual calls it replaces.
  for _loader in (ClassifierLoadingTests.load_for_finetune,
                  ClassifierLoadingTests.load_for_inference):
      test_utils.add_tests(
          _loader,
          model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
  # Do not leak the loop variable into the module namespace.
  del _loader