import io
import test_utils
import unittest

from contextlib import contextmanager

from cvmodelz.models import ModelFactory

class ModelLoadingTests(unittest.TestCase):
	"""Weight-loading checks for models built through ``ModelFactory``.

	None of the methods below start with ``test``, so unittest does not
	collect them directly; each takes a model ``key`` and is handed to
	``test_utils.add_tests`` at module level together with a model list
	(presumably generating one test per key — see ``test_utils``).
	"""


	def cv2model_load_pretrained(self, key):
		"""Check that pretrained weights replace the random initialization.

		Builds one model with ``pretrained_model=None`` and two with
		``pretrained_model="auto"``. Every parameter of the random model must
		differ in at least one element from the same-named parameter of both
		pretrained models, and the two pretrained models must be element-wise
		identical, i.e. loading is deterministic.

		Args:
			key: model identifier forwarded to ``ModelFactory.new``.
		"""

		model_rnd = ModelFactory.new(key, pretrained_model=None)
		model_loaded1 = ModelFactory.new(key, pretrained_model="auto")
		model_loaded2 = ModelFactory.new(key, pretrained_model="auto")

		params_rnd = dict(model_rnd.namedparams())
		params_loaded1 = dict(model_loaded1.namedparams())
		params_loaded2 = dict(model_loaded2.namedparams())

		for name, param in params_rnd.items():
			loaded1 = params_loaded1[name]
			loaded2 = params_loaded2[name]

			# Pass an explicit message (as the sibling checks do through
			# test_utils) so a failure names the offending parameter instead
			# of only reporting "False is not true".
			self.assertTrue((param.array != loaded1.array).any(),
				"{}: 1st pretrained model equals the random init".format(name))
			self.assertTrue((param.array != loaded2.array).any(),
				"{}: 2nd pretrained model equals the random init".format(name))
			self.assertTrue((loaded1.array == loaded2.array).all(),
				"{}: the two pretrained models differ".format(name))


	def load_for_finetune(self, key):
		"""Round-trip weights through ``load_for_finetune``.

		Builds two independently initialized 1000-class models, asserts that
		they differ, saves the first into a ``test_utils.memory_file`` and
		loads it into the second with ``load_for_finetune(f, n_classes=200,
		strict=True)``, then asserts the two models compare equal via
		``test_utils.is_all_equal``.

		NOTE(review): the ``test_utils`` helpers appear to return a
		``(result, message)`` pair that is unpacked into ``assertTrue``. After
		fine-tune loading, ``new_model`` presumably has a 200-class head while
		``model`` keeps 1000 classes, which would explain why this comparison
		is not ``strict`` (unlike ``load_for_inference``) — confirm in
		``test_utils``.

		Args:
			key: model identifier forwarded to ``ModelFactory.new``.
		"""

		model = ModelFactory.new(key, n_classes=1000)
		new_model = ModelFactory.new(key, n_classes=1000)

		# Sanity check: the two fresh models must start out different,
		# otherwise the equality check below would pass trivially.
		self.assertTrue(*test_utils.is_any_different(model, new_model))

		with test_utils.memory_file() as f:
			model.save(f)
			new_model.load_for_finetune(f, n_classes=200, strict=True)

		self.assertTrue(*test_utils.is_all_equal(model, new_model))

	def load_for_inference(self, key):
		"""Round-trip weights through ``load_for_inference``.

		Saves a freshly built 200-class model and loads it into a freshly
		built 1000-class model with ``load_for_inference(f, n_classes=200,
		strict=True)``, then requires the two models to be equal under
		``test_utils.is_all_equal(..., strict=True)``.

		Args:
			key: model identifier forwarded to ``ModelFactory.new``.
		"""

		model = ModelFactory.new(key, n_classes=200)
		new_model = ModelFactory.new(key, n_classes=1000)

		# Sanity check: the models must differ before loading.
		self.assertTrue(*test_utils.is_any_different(model, new_model))

		with test_utils.memory_file() as f:
			model.save(f)
			new_model.load_for_inference(f, n_classes=200, strict=True)

		self.assertTrue(*test_utils.is_all_equal(model, new_model, strict=True))



# Register each loading check for every model key of the listed model
# libraries. ModelFactory.get_models is invoked anew for every registration
# (never shared), in the same order as the checks are listed here.
for _check, _libraries in (
	(ModelLoadingTests.load_for_finetune, ["chainercv2", "cvmodelz"]),
	(ModelLoadingTests.load_for_inference, ["chainercv2", "cvmodelz"]),
	(ModelLoadingTests.cv2model_load_pretrained, ["chainercv2"]),
):
	test_utils.add_tests(_check,
		model_list=ModelFactory.get_models(_libraries))

# Do not leave the loop variables behind in the module namespace.
del _check, _libraries