1234567891011121314151617181920212223242526272829 |
- import test_utils
- import unittest
- from cvmodelz.models import ModelFactory
- from cvmodelz.models.pretrained.base import PretrainedModelMixin
- from cvmodelz.models.wrapper import ModelWrapper
class ModelCreationsTests(unittest.TestCase):
    """Check that ``ModelFactory.new`` builds models of the expected type.

    The two checker methods are not named ``test_*``, so the default
    unittest loader does not collect them by itself; they are handed to
    ``test_utils.add_tests`` at module level together with a list of model
    keys (NOTE(review): presumably one generated test per key -- confirm
    against test_utils).
    """

    def _check_new_model(self, key, expected_type):
        """Create the model registered under *key* and validate it.

        Asserts that the factory returned something (not ``None``) and that
        the result is an instance of *expected_type*.
        """
        created = ModelFactory.new(key)
        self.assertIsNotNone(created)
        self.assertIsInstance(created, expected_type)

    def cv2model_creation(self, key):
        """The model created for *key* must be a ``ModelWrapper``."""
        self._check_new_model(key, ModelWrapper)

    def pretrained_model_creation(self, key):
        """The model created for *key* must be a ``PretrainedModelMixin``."""
        self._check_new_model(key, PretrainedModelMixin)
# Attach the creation checks to the model groups they apply to:
# keys from the "chainercv2" group are checked with cv2model_creation
# (expects ModelWrapper), keys from "cvmodelz" with pretrained_model_creation
# (expects PretrainedModelMixin).
# NOTE(review): test_utils.add_tests is not visible here; it presumably
# generates one test method per model key -- confirm.
for _checker, _group in (
    (ModelCreationsTests.cv2model_creation, "chainercv2"),
    (ModelCreationsTests.pretrained_model_creation, "cvmodelz"),
):
    test_utils.add_tests(_checker, model_list=ModelFactory.get_models([_group]))

# Keep the module namespace free of the loop temporaries.
del _checker, _group
|