import test_utils
import unittest

from cvmodelz.models import ModelFactory
from cvmodelz.models.pretrained.base import PretrainedModelMixin
from cvmodelz.models.wrapper import ModelWrapper
from cvmodelz.utils.links.pooling import GlobalAveragePooling


class ModelCreationsTests(unittest.TestCase):
	"""Checks for models built through ``ModelFactory.new``.

	Every check takes a model ``key`` that is forwarded to
	``ModelFactory.new``. The checks are passed to ``test_utils.add_tests``
	at module level together with a ``model_list``.
	"""

	def with_pooling_string(self, key):
		# Request the pooling by its string identifier and verify the type
		# of the resulting `pool` attribute.
		pooled = ModelFactory.new(key, pooling="g_avg")

		self.assertIsNotNone(pooled)
		self.assertIsInstance(pooled.pool, GlobalAveragePooling)

	def cv2model_creation(self, key):
		# A model created for this key is expected to be a ModelWrapper.
		wrapped = ModelFactory.new(key)

		self.assertIsNotNone(wrapped)
		self.assertIsInstance(wrapped, ModelWrapper)

	def pretrained_model_creation(self, key):
		# A model created for this key is expected to derive from
		# PretrainedModelMixin.
		pretrained = ModelFactory.new(key)

		self.assertIsNotNone(pretrained)
		self.assertIsInstance(pretrained, PretrainedModelMixin)

# Hand each check to test_utils.add_tests together with the models that
# ModelFactory.get_models returns for the listed source.
# NOTE(review): presumably add_tests generates one test per entry of
# model_list -- confirm in test_utils.
for _check, _sources in (
	(ModelCreationsTests.cv2model_creation, ["chainercv2"]),
	(ModelCreationsTests.pretrained_model_creation, ["cvmodelz"]),
	(ModelCreationsTests.with_pooling_string, ["cvmodelz"]),
):
	test_utils.add_tests(_check, model_list=ModelFactory.get_models(_sources))