Browse Source

fixed some loading tests

Dimitri Korsch 4 years ago
parent
commit
de4307bc23

+ 5 - 1
cvmodelz/models/base.py

@@ -15,7 +15,8 @@ from cvmodelz.models.meta_info import ModelInfo
 
 class BaseModel(abc.ABC, chainer.Chain):
 
-	def __init__(self, pooling: Callable = PoolingType.G_AVG.value(), input_size=None, *args, **kwargs):
+	def __init__(self, pooling: Callable = PoolingType.G_AVG.value(),
+		input_size=None, *args, **kwargs):
 		super(BaseModel, self).__init__(*args, **kwargs)
 		self.init_model_info()
 
@@ -98,3 +99,6 @@ class BaseModel(abc.ABC, chainer.Chain):
 
 			npz.load_npz(weights, self.model_instance,
 				path=path, strict=strict, ignore_names=ignore_names)
+
+	def save(self, path, *args, **kwargs):
+		npz.save_npz(path, self, *args, **kwargs)

+ 8 - 10
cvmodelz/models/factory.py

@@ -50,7 +50,7 @@ class ModelFactory(abc.ABC):
 		raise NotImplementedError("instance creation is not supported!")
 
 	@classmethod
-	def new(cls, model_type, *args, **kwargs):
+	def new(cls, model_type, **kwargs):
 
 		key, cls_name = model_type.split(".")
 
@@ -60,16 +60,14 @@ class ModelFactory(abc.ABC):
 		else:
 			raise ValueError(f"Could not find {model_type}!")
 
-		if model_cls in cls.supported["chainer"]:
-			kwargs["pretrained_model"] = kwargs.get("pretrained_model", None)
-			kwargs.pop("input_size", None)
+		if model_cls in cls.supported["chainercv2"]:
+			n_classes = kwargs.pop("n_classes", 1000)
+			pretrained = kwargs.pop("pretrained_model", None) == "auto"
+			model = model_cls(classes=n_classes, pretrained=pretrained)
+			kwargs["model"] = model
+			model_cls = ModelWrapper
 
-		elif model_cls in cls.supported["chainercv2"]:
-			kwargs["pretrained"] = kwargs.get("pretrained", False)
-			input_size = kwargs.pop("input_size", None)
-			return ModelWrapper(model_cls(*args, **kwargs), input_size=input_size)
-
-		return model_cls(*args, **kwargs)
+		return model_cls(**kwargs)
 
 
 	@classmethod

+ 1 - 1
cvmodelz/models/pretrained/__init__.py

@@ -1,5 +1,5 @@
 from cvmodelz.models.pretrained.base import PretrainedModelMixin
-from cvmodelz.models.pretrained.inception.inception_v3 import InceptionV3
+from cvmodelz.models.pretrained.inception import InceptionV3
 from cvmodelz.models.pretrained.resnet import ResNet101
 from cvmodelz.models.pretrained.resnet import ResNet152
 from cvmodelz.models.pretrained.resnet import ResNet35

+ 6 - 0
cvmodelz/models/pretrained/inception/__init__.py

@@ -0,0 +1,6 @@
+from cvmodelz.models.pretrained.inception.inception_v3 import InceptionV3
+
+
+__all__ = [
+	"InceptionV3",
+]

+ 2 - 2
cvmodelz/models/pretrained/inception/inception_v3.py

@@ -99,9 +99,9 @@ class InceptionV3(PretrainedModelMixin, chainer.Chain):
 
 
 	def load(self, weights, *args, **kwargs):
-		if weights.endswith(".h5"):
+		if isinstance(weights, str) and weights.endswith(".h5"):
 			self._load_from_keras(weights)
-		elif weights.endswith(".ckpt.npz"):
+		elif isinstance(weights, str) and weights.endswith(".ckpt.npz"):
 			self._load_from_ckpt_weights(weights)
 		else:
 			return super(InceptionV3, self).load(weights, *args, **kwargs)

+ 9 - 2
cvmodelz/models/wrapper.py

@@ -53,8 +53,15 @@ class ModelWrapper(BaseModel):
 
 		return OrderedDict(links)
 
-	def load_for_inference(self, *args, path="", **kwargs):
-		return super(ModelWrapper, self).load_for_inference(*args, path=f"{path}wrapped/", **kwargs)
+	def load(self, *args, path="", **kwargs):
+		paths = [path, f"{path}wrapped/"]
+		for _path in paths:
+			try:
+				return super(ModelWrapper, self).load(*args, path=_path, **kwargs)
+			except KeyError as e:
+				pass
+
+		raise RuntimeError(f"tried to load weights with paths {paths}, but did not succeeed")
 
 	def __call__(self, X, layer_name=None):
 		if layer_name is None:

+ 20 - 7
tests/model_tests/factory_tests.py

@@ -2,21 +2,30 @@ import unittest
 import test_utils
 
 from cvmodelz.models import ModelFactory
-
+from cvmodelz.models.pretrained.base import PretrainedModelMixin
+from cvmodelz.models.wrapper import ModelWrapper
 
 
 class FactoryTests(unittest.TestCase):
 
-	def model_creation(self, key):
+	def cv2model_creation(self, key):
+
+		model = ModelFactory.new(key)
+		self.assertIsNotNone(model)
+
+		self.assertIsInstance(model, ModelWrapper)
 
+	def pretrained_model_creation(self, key):
 		model = ModelFactory.new(key)
 		self.assertIsNotNone(model)
 
+		self.assertIsInstance(model, PretrainedModelMixin)
+
 	def cv2model_load(self, key):
 
-		model_rnd = ModelFactory.new(key, pretrained=False)
-		model_loaded1 = ModelFactory.new(key, pretrained=True)
-		model_loaded2 = ModelFactory.new(key, pretrained=True)
+		model_rnd = ModelFactory.new(key, pretrained_model=None)
+		model_loaded1 = ModelFactory.new(key, pretrained_model="auto")
+		model_loaded2 = ModelFactory.new(key, pretrained_model="auto")
 
 		params_rnd = dict(model_rnd.namedparams())
 		params_loaded1 = dict(model_loaded1.namedparams())
@@ -30,8 +39,12 @@ class FactoryTests(unittest.TestCase):
 			self.assertTrue((  param.array != loaded2.array).any())
 			self.assertTrue((loaded1.array == loaded2.array).all())
 
-test_utils.add_tests(FactoryTests.model_creation,
-	model_list=ModelFactory.get_all_models())
+
+test_utils.add_tests(FactoryTests.cv2model_creation,
+	model_list=ModelFactory.get_models(["chainercv2"]))
+
+test_utils.add_tests(FactoryTests.pretrained_model_creation,
+	model_list=ModelFactory.get_models(["cvmodelz"]))
 
 test_utils.add_tests(FactoryTests.cv2model_load,
 	model_list=ModelFactory.get_models(["chainercv2"]))

+ 7 - 6
tests/model_tests/pretrained_tests.py

@@ -1,4 +1,5 @@
 import io
+import numpy as np
 import unittest
 import test_utils
 
@@ -17,15 +18,15 @@ class PretrainedTests(unittest.TestCase):
 
 	def load_for_finetune(self, key):
 
-		model = ModelFactory.new(key)
-		new_model = ModelFactory.new(key)
+		model = ModelFactory.new(key, n_classes=1000)
+		new_model = ModelFactory.new(key, n_classes=1000)
 
 		self.assertTrue(test_utils.is_any_different(model, new_model))
 
 		with self.mem_file() as f:
-			npz.save_npz(f, model)
+			model.save(f)
 			f.seek(0)
-			new_model.load_for_finetune(f, n_classes=200)
+			new_model.load_for_finetune(f, n_classes=200, strict=True)
 
 		self.assertTrue(test_utils.is_all_equal(model, new_model))
 
@@ -37,9 +38,9 @@ class PretrainedTests(unittest.TestCase):
 		self.assertTrue(test_utils.is_any_different(model, new_model))
 
 		with self.mem_file() as f:
-			npz.save_npz(f, model)
+			model.save(f)
 			f.seek(0)
-			new_model.load_for_inference(f, n_classes=200)
+			new_model.load_for_inference(f, n_classes=200, strict=True)
 
 		self.assertTrue(test_utils.is_all_equal(model, new_model, strict=True))