Browse Source

added classifier loading tests. renamed modules of other tests

Dimitri Korsch 4 years ago
parent
commit
9f7bfcc42a

+ 48 - 2
cvmodelz/classifiers/base.py

@@ -2,6 +2,7 @@ import abc
 import chainer
 
 from chainer import functions as F
+from chainer.serializers import npz
 from typing import Dict
 from typing import Callable
 
@@ -38,8 +39,53 @@ class Classifier(chainer.Chain):
 	def n_classes(self) -> int:
 		return self.model.clf_layer.W.shape[0]
 
-	def loader(self, model_loader: Callable) -> Callable:
-		return model_loader
+	def save(self, weights_file):
+		npz.save_npz(weights_file, self)
+
+	def load(self, weights_file: str, n_classes: int, *, finetune: bool = False) -> None:
+		""" Loading a classifier has following use cases:
+
+			(0) No loading.
+				Here, all weights are initialized randomly.
+
+			(1) Loading from default pre-trained weights
+				Here, the weights are loaded directly to
+				the model. Any additional non-model-related
+				layers will be initialized randomly.
+
+			(2) Loading from a saved classifier.
+				Here, all weights are loaded as-is from
+				the given file.
+		"""
+
+		try:
+			# Case (2)
+			self.load_classifier(weights_file)
+
+		except KeyError as e:
+			# Case (1)
+			self.load_model(weights_file, n_classes=n_classes, finetune=finetune)
+
+		else:
+			# Case (2) succeeded; nothing more to do.
+			# (Case (0) corresponds to never calling load() at all.)
+			pass
+
+	def load_classifier(self, weights_file):
+		npz.load_npz(weights_file, self, strict=True)
+
+	def load_model(self, weights_file, n_classes, *, finetune: bool = False):
+		if finetune:
+			model_loader = self.model.load_for_finetune
+		else:
+			model_loader = self.model.load_for_inference
+
+		try:
+			model_loader(weights=weights_file, n_classes=n_classes, strict=True)
+		except KeyError as e:
+
+			breakpoint()
+			raise
+
 
 	@property
 	def feat_size(self) -> int:

+ 11 - 0
cvmodelz/classifiers/separate_model_classifier.py

@@ -24,6 +24,17 @@ class SeparateModelClassifier(Classifier):
 
 		self.separate_model = self.model.copy(mode="copy")
 
+	def load_for_finetune(self, *args, **kwargs) -> None:
+		super(SeparateModelClassifier, self).load_for_finetune(*args, **kwargs)
+
+	def load_for_inference(self, *args, **kwargs) -> None:
+		super(SeparateModelClassifier, self).load_for_inference(*args, **kwargs)
+
+
+
+	def load(self, *args, **kwargs) -> None:
+		raise NotImplementedError
+
 	def loader(self, model_loader: Callable) -> Callable:
 		super_loader = super(SeparateModelClassifier).loader(model_loader)
 

+ 2 - 0
tests/classifier_tests/__init__.py

@@ -0,0 +1,2 @@
+from classifier_tests.creation import *
+from classifier_tests.loading import *

+ 3 - 3
tests/classifier_tests.py → tests/classifier_tests/creation.py

@@ -5,7 +5,7 @@ import unittest
 from cvmodelz.classifiers import Classifier
 from cvmodelz.models import ModelFactory
 
-class ClassifierTests(unittest.TestCase):
+class ClassifierCreationTests(unittest.TestCase):
 
 
 	def new_clf(self, key, **kwargs):
@@ -30,8 +30,8 @@ class ClassifierTests(unittest.TestCase):
 
 
 
-test_utils.add_tests(ClassifierTests.creation,
+test_utils.add_tests(ClassifierCreationTests.creation,
 	model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
 
-test_utils.add_tests(ClassifierTests.loss_computation,
+test_utils.add_tests(ClassifierCreationTests.loss_computation,
 	model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))

+ 43 - 0
tests/classifier_tests/loading.py

@@ -0,0 +1,43 @@
+import io
+import numpy as np
+import test_utils
+import unittest
+
+from cvmodelz.classifiers import Classifier
+from cvmodelz.models import ModelFactory
+
+class ClassifierLoadingTests(unittest.TestCase):
+
+	def new_clf(self, key, **kwargs):
+		model = ModelFactory.new(key, pretrained_model=None)
+		return model, Classifier(model, **kwargs)
+
+	def load_model(self, key, finetune):
+		init_cls = 1000 if finetune else 200
+		final_cls = 200
+
+		model = ModelFactory.new(key, pretrained_model=None, n_classes=init_cls)
+		_, clf = self.new_clf(key)
+
+		self.assertTrue(*test_utils.is_any_different(model, clf.model))
+
+		with test_utils.memory_file() as f:
+			model.save(f)
+			f.seek(0)
+			clf.load(f, n_classes=final_cls, finetune=finetune)
+
+		""" if finetune is True, then the shapes of the classification
+		layer differ, hence, strict should be False """
+		self.assertTrue(*test_utils.is_all_equal(model, clf.model, strict=not finetune))
+
+	def load_for_finetune(self, key):
+		self.load_model(key, finetune=True)
+
+	def load_for_inference(self, key):
+		self.load_model(key, finetune=False)
+
+test_utils.add_tests(ClassifierLoadingTests.load_for_finetune,
+	model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
+
+test_utils.add_tests(ClassifierLoadingTests.load_for_inference,
+	model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))

+ 3 - 3
tests/model_tests/creation.py

@@ -6,7 +6,7 @@ from cvmodelz.models.pretrained.base import PretrainedModelMixin
 from cvmodelz.models.wrapper import ModelWrapper
 
 
-class CreationsTests(unittest.TestCase):
+class ModelCreationsTests(unittest.TestCase):
 
 	def cv2model_creation(self, key):
 
@@ -21,9 +21,9 @@ class CreationsTests(unittest.TestCase):
 
 		self.assertIsInstance(model, PretrainedModelMixin)
 
-test_utils.add_tests(CreationsTests.cv2model_creation,
+test_utils.add_tests(ModelCreationsTests.cv2model_creation,
 	model_list=ModelFactory.get_models(["chainercv2"]))
 
-test_utils.add_tests(CreationsTests.pretrained_model_creation,
+test_utils.add_tests(ModelCreationsTests.pretrained_model_creation,
 	model_list=ModelFactory.get_models(["cvmodelz"]))
 

+ 10 - 13
tests/model_tests/loading.py

@@ -6,11 +6,8 @@ from contextlib import contextmanager
 
 from cvmodelz.models import ModelFactory
 
-class LoadingTests(unittest.TestCase):
+class ModelLoadingTests(unittest.TestCase):
 
-	@contextmanager
-	def mem_file(self) -> io.BytesIO:
-		yield io.BytesIO()
 
 	def cv2model_load_pretrained(self, key):
 
@@ -36,36 +33,36 @@ class LoadingTests(unittest.TestCase):
 		model = ModelFactory.new(key, n_classes=1000)
 		new_model = ModelFactory.new(key, n_classes=1000)
 
-		self.assertTrue(test_utils.is_any_different(model, new_model))
+		self.assertTrue(*test_utils.is_any_different(model, new_model))
 
-		with self.mem_file() as f:
+		with test_utils.memory_file() as f:
 			model.save(f)
 			f.seek(0)
 			new_model.load_for_finetune(f, n_classes=200, strict=True)
 
-		self.assertTrue(test_utils.is_all_equal(model, new_model))
+		self.assertTrue(*test_utils.is_all_equal(model, new_model))
 
 	def load_for_inference(self, key):
 
 		model = ModelFactory.new(key, n_classes=200)
 		new_model = ModelFactory.new(key, n_classes=1000)
 
-		self.assertTrue(test_utils.is_any_different(model, new_model))
+		self.assertTrue(*test_utils.is_any_different(model, new_model))
 
-		with self.mem_file() as f:
+		with test_utils.memory_file() as f:
 			model.save(f)
 			f.seek(0)
 			new_model.load_for_inference(f, n_classes=200, strict=True)
 
-		self.assertTrue(test_utils.is_all_equal(model, new_model, strict=True))
+		self.assertTrue(*test_utils.is_all_equal(model, new_model, strict=True))
 
 
 
-test_utils.add_tests(LoadingTests.load_for_finetune,
+test_utils.add_tests(ModelLoadingTests.load_for_finetune,
 	model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))
 
-test_utils.add_tests(LoadingTests.load_for_inference,
+test_utils.add_tests(ModelLoadingTests.load_for_inference,
 	model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))
 
-test_utils.add_tests(LoadingTests.cv2model_load_pretrained,
+test_utils.add_tests(ModelLoadingTests.cv2model_load_pretrained,
 	model_list=ModelFactory.get_models(["chainercv2"]))

+ 13 - 7
tests/test_utils.py

@@ -1,5 +1,6 @@
-import inspect
 import chainer
+import inspect
+import io
 
 from contextlib import contextmanager
 from functools import partial
@@ -34,14 +35,14 @@ def is_all_equal(model0: chainer.Chain, model1: chainer.Chain, strict: bool = Fa
 		param0, param1 = params0[name], params1[name]
 		if param0.shape != param1.shape:
 			if strict:
-				return False
+				return False, f"shape of {name} was not the same: {param0.shape} != {param1.shape}"
 			else:
 				continue
 
 		if not (param0.array == param1.array).all():
-			return False
+			return False, f"array of {name} was not the same"
 
-	return True
+	return True, "All equal!"
 
 def is_any_different(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False) -> bool:
 	params0 = dict(model0.namedparams())
@@ -52,13 +53,18 @@ def is_any_different(model0: chainer.Chain, model1: chainer.Chain, strict: bool
 		# print(name)
 		if param0.shape != param1.shape:
 			if strict:
-				return False
+				return False, f"shape of {name} was not the same: {param0.shape} != {param1.shape}"
 			else:
 				continue
 		if (param0.array != param1.array).any():
-			return True
+			return True, f"Difference in array {name} found."
+
+	return False, "All equal!"
 
-	return False
+
+@contextmanager
+def memory_file() -> io.BytesIO:
+	yield io.BytesIO()
 
 @contextmanager
 def clear_print(msg):