creation.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. import test_utils
  2. import unittest
  3. from cvmodelz.models import ModelFactory
  4. from cvmodelz.models.pretrained.base import PretrainedModelMixin
  5. from cvmodelz.models.wrapper import ModelWrapper
  6. from cvmodelz.utils.links.pooling import GlobalAveragePooling
  class ModelCreationsTests(unittest.TestCase):
      """Model-construction checks, each parametrised by a ModelFactory key.

      None of the checks below carries the ``test`` prefix, so unittest's
      default loader does not collect them on their own.  Instead the module
      passes each one to ``test_utils.add_tests`` together with a list of
      model keys (presumably generating one test per key -- see test_utils).
      """

      def _build(self, key, **factory_kwargs):
          """Create model ``key`` through ModelFactory and assert it is not None.

          Extra keyword arguments are forwarded unchanged to
          ``ModelFactory.new``.  Returns the created model.
          """
          built = ModelFactory.new(key, **factory_kwargs)
          self.assertIsNotNone(built)
          return built

      def with_pooling_string(self, key):
          """Passing ``pooling="g_avg"`` must give a GlobalAveragePooling ``pool``."""
          built = self._build(key, pooling="g_avg")
          self.assertIsInstance(built.pool, GlobalAveragePooling)

      def cv2model_creation(self, key):
          """The created model must be a ``ModelWrapper`` instance.

          Registered for the keys of ``ModelFactory.get_models(["chainercv2"])``.
          """
          built = self._build(key)
          self.assertIsInstance(built, ModelWrapper)

      def pretrained_model_creation(self, key):
          """The created model must be a ``PretrainedModelMixin`` instance.

          Registered for the keys of ``ModelFactory.get_models(["cvmodelz"])``.
          """
          built = self._build(key)
          self.assertIsInstance(built, PretrainedModelMixin)
  # Register every key-parametrised check from ModelCreationsTests with
  # test_utils, pairing it with the model keys ModelFactory.get_models
  # returns for the listed group names.  Each pair gets its own list literal,
  # and registration happens in the same order as individual calls would.
  for _check, _groups in (
      (ModelCreationsTests.cv2model_creation, ["chainercv2"]),
      (ModelCreationsTests.pretrained_model_creation, ["cvmodelz"]),
      (ModelCreationsTests.with_pooling_string, ["cvmodelz"]),
  ):
      test_utils.add_tests(_check, model_list=ModelFactory.get_models(_groups))
  # Drop the loop temporaries so they do not linger as module globals.
  del _check, _groups