123456789101112131415161718192021222324252627282930313233343536373839 |
- import test_utils
- import unittest
- from cvmodelz.models import ModelFactory
- from cvmodelz.models.pretrained.base import PretrainedModelMixin
- from cvmodelz.models.wrapper import ModelWrapper
- from cvmodelz.utils.links.pooling import GlobalAveragePooling
class ModelCreationsTests(unittest.TestCase):
    """Smoke tests for building models through ``ModelFactory``.

    Each method below receives one model ``key`` besides ``self``. None of the
    method names start with ``test``, so unittest's default loader does not
    collect them by themselves. They are handed to ``test_utils.add_tests``
    together with a list of model keys, presumably to generate one concrete
    test per key (TODO confirm against ``test_utils``).
    """

    def with_pooling_string(self, key):
        """Build *key* with ``pooling="g_avg"`` and check the pooling link.

        The model must be created, and its ``pool`` attribute must be a
        ``GlobalAveragePooling`` instance.
        """
        net = ModelFactory.new(key, pooling="g_avg")
        self.assertIsNotNone(net)
        self.assertIsInstance(net.pool, GlobalAveragePooling)

    def cv2model_creation(self, key):
        """Build *key* with default arguments and expect a ``ModelWrapper``."""
        wrapped = ModelFactory.new(key)
        self.assertIsNotNone(wrapped)
        self.assertIsInstance(wrapped, ModelWrapper)

    def pretrained_model_creation(self, key):
        """Build *key* with default arguments and expect ``PretrainedModelMixin``."""
        pretrained = ModelFactory.new(key)
        self.assertIsNotNone(pretrained)
        self.assertIsInstance(pretrained, PretrainedModelMixin)
# Register each template method of ModelCreationsTests against the models
# listed by ModelFactory for one source ("chainercv2" or "cvmodelz").
# Method names are resolved inside the loop so that the lookup, the
# get_models() call and the add_tests() call keep their original order.
for _template_name, _model_source in (
    ("cv2model_creation", "chainercv2"),
    ("pretrained_model_creation", "cvmodelz"),
    ("with_pooling_string", "cvmodelz"),
):
    test_utils.add_tests(getattr(ModelCreationsTests, _template_name),
                         model_list=ModelFactory.get_models([_model_source]))

# Do not leave loop temporaries behind in the module namespace.
del _template_name, _model_source
|