"""Creation tests for models built through ``ModelFactory``."""
import test_utils
import unittest

from cvmodelz.models import ModelFactory
from cvmodelz.models.pretrained.base import PretrainedModelMixin
from cvmodelz.models.wrapper import ModelWrapper
from cvmodelz.utils.links.pooling import GlobalAveragePooling


class ModelCreationsTests(unittest.TestCase):
    """Check the types of models returned by ``ModelFactory.new``.

    The public check methods take a model ``key`` argument. Their names do
    not start with ``test``, so unittest does not collect them directly.
    Instead, they are passed to ``test_utils.add_tests`` at the bottom of
    this module. That helper presumably generates one test per key; see
    ``test_utils`` to confirm.
    """

    def _new_model(self, key, **kwargs):
        """Build ``key`` through the factory and assert that something came back.

        Any keyword arguments are forwarded unchanged to ``ModelFactory.new``.
        Returns the created model so callers can make further assertions.
        """
        built = ModelFactory.new(key, **kwargs)
        self.assertIsNotNone(built)
        return built

    def with_pooling_string(self, key):
        """Check that pooling given by name ("g_avg") becomes GlobalAveragePooling."""
        pool_layer = self._new_model(key, pooling="g_avg").pool
        self.assertIsInstance(pool_layer, GlobalAveragePooling)

    def cv2model_creation(self, key):
        """Check that building ``key`` with default arguments yields a ModelWrapper."""
        self.assertIsInstance(self._new_model(key), ModelWrapper)

    def pretrained_model_creation(self, key):
        """Check that building ``key`` with default arguments yields a PretrainedModelMixin."""
        self.assertIsInstance(self._new_model(key), PretrainedModelMixin)


# Each entry pairs a check method with the model groups passed to
# ModelFactory.get_models. Registration happens in this listed order.
_REGISTRATIONS = (
    (ModelCreationsTests.cv2model_creation, ["chainercv2"]),
    (ModelCreationsTests.pretrained_model_creation, ["cvmodelz"]),
    (ModelCreationsTests.with_pooling_string, ["cvmodelz"]),
)

for _check, _groups in _REGISTRATIONS:
    test_utils.add_tests(_check, model_list=ModelFactory.get_models(_groups))

# Drop the loop variables so they do not linger at module level.
del _check, _groups