
added tests for MeanModelClassifier loading

Dimitri Korsch, 4 years ago
parent commit a68f82b9f2

+ 2 - 0
.coveragerc

@@ -19,3 +19,5 @@ exclude_lines =

 omit:
 	tests/*
+	cvmodelz/utils/*
+	cvmodelz/_version.py
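For reference, the resulting omit section of .coveragerc after this change:

	omit:
		tests/*
		cvmodelz/utils/*
		cvmodelz/_version.py

Coverage measurement now also ignores the cvmodelz/utils helpers and the generated cvmodelz/_version.py.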

+ 19 - 19
cvmodelz/classifiers/base.py

@@ -47,15 +47,16 @@ class Classifier(chainer.Chain):
 		""" Loading a classifier has following use cases:

 			(0) No loading.
-				Here the all weights are initilized randomly.
+				- All weights are initilized randomly.

 			(1) Loading from default pre-trained weights
-				Here, the weights are loaded directly to
-				the model. Any additional not model-related
+				- The weights are loaded directly to
+				the model.
+				- Any additional not model-related
 				layer will be initialized randomly.

 			(2) Loading from a saved classifier.
-				Here, all weights are loaded as-it-is from
+				- All weights are loaded as-it-is from
 				the given file.
 		"""

@@ -64,14 +65,15 @@ class Classifier(chainer.Chain):
 			self.load_classifier(weights_file)

 		except KeyError as e:
-			# Case (1)
-			self.load_model(weights_file, n_classes=n_classes, finetune=finetune)
-
-		else:
-			# Case (0)
 			pass

-	def load_classifier(self, weights_file):
+		# Case (1)
+		self.load_model(weights_file, n_classes=n_classes, finetune=finetune)
+
+		# else:
+		# 	# Case (0)
+		# 	pass
+
+	def load_classifier(self, weights_file: str):

 		if isinstance(weights_file, io.BufferedIOBase):
@@ -80,19 +82,17 @@

 		npz.load_npz(weights_file, self, strict=True)

-	def load_model(self, weights_file, n_classes, *, finetune: bool = False):
+	def get_model_loader(self, finetune: bool = False, model: BaseModel = None):
+		model = model or self.model
 		if finetune:
-			model_loader = self.model.load_for_finetune
+			return model.load_for_finetune
 		else:
-			model_loader = self.model.load_for_inference
-
-		try:
-			model_loader(weights=weights_file, n_classes=n_classes, strict=True)
-		except KeyError as e:
+			return model.load_for_inference

-			breakpoint()
-			raise

+	def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False):
+		model_loader = self.get_model_loader(finetune=finetune, model=self.model)
+		model_loader(weights=weights_file, n_classes=n_classes, strict=True)

 	@property
 	def feat_size(self) -> int:
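Taken together, these hunks replace the old try/except-with-breakpoint logic: get_model_loader() picks the bound loading method of a given model (load_for_finetune vs. load_for_inference), and load() now falls through to load_model() whenever the weights file is not a saved classifier. A condensed sketch of the resulting flow (not a verbatim copy; the Classifier class and its model attribute are as in the repo):

	def load(self, weights_file, n_classes, *, finetune=False):
		try:
			# Case (2): the file stores a whole classifier; load it as-is.
			self.load_classifier(weights_file)
		except KeyError:
			pass  # not a classifier dump

		# Case (1): load the weights into the wrapped model only; any
		# additional, non-model layers keep their random initialization.
		self.load_model(weights_file, n_classes=n_classes, finetune=finetune)

	def load_model(self, weights_file, n_classes, *, finetune=False):
		# Returns model.load_for_finetune or model.load_for_inference.
		loader = self.get_model_loader(finetune=finetune, model=self.model)
		loader(weights=weights_file, n_classes=n_classes, strict=True)

Note that load_model() now runs unconditionally after the try/except, i.e. also when a full classifier was already loaded from the file.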

+ 4 - 26
cvmodelz/classifiers/separate_model_classifier.py

@@ -32,32 +32,10 @@ class SeparateModelClassifier(Classifier):
 	def load_for_inference(self, *args, **kwargs) -> None:
 		super().load_for_inference(*args, **kwargs)

-
-
-	def load(self, *args, **kwargs) -> None:
-		raise NotImplementedError
-
-	def loader(self, model_loader: Callable) -> Callable:
-		super_loader = super(SeparateModelClassifier).loader(model_loader)
-
-		def inner_loader(n_classes: int, feat_size: int) -> None:
-			# use the given feature size here
-			super_loader(n_classes=n_classes, feat_size=feat_size)
-
-			# use the given feature size first ...
-			self.separate_model.reinitialize_clf(
-				n_classes=n_classes,
-				feat_size=feat_size)
-
-			# then copy model params ...
-			self.separate_model.copyparams(self.model)
-
-			# now use the default feature size to re-init the classifier
-			self.separate_model.reinitialize_clf(
-				n_classes=n_classes,
-				feat_size=self.feat_size)
-
-		return inner_loader
+	def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False) -> None:
+		for model in [self.model, self.separate_model]:
+			model_loader = self.get_model_loader(finetune=finetune, model=model)
+			model_loader(weights=weights_file, n_classes=n_classes, strict=True)

 	def enable_only_head(self) -> None:
 		super().enable_only_head()
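The hand-rolled inner_loader is gone entirely: the override delegates to get_model_loader() from the base class and applies the same weights to both chains, so model and separate_model start out with identical parameters. A hypothetical usage sketch (the model key, weights path, and class count are placeholders, not from this commit):

	from cvmodelz.models import ModelFactory
	from cvmodelz.classifiers import MeanModelClassifier

	model = ModelFactory.new("resnet50", pretrained_model=None)  # placeholder key
	clf = MeanModelClassifier(model)

	# Loads the same weights into both clf.model and clf.separate_model.
	clf.load_model("weights.npz", n_classes=200, finetune=False)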

+ 23 - 0
tests/classifier_tests/loading.py

@@ -4,6 +4,7 @@ import test_utils
 import unittest

 from cvmodelz.classifiers import Classifier
+from cvmodelz.classifiers import MeanModelClassifier
 from cvmodelz.models import ModelFactory

 class ClassifierLoadingTests(unittest.TestCase):
@@ -28,6 +29,7 @@ class ClassifierLoadingTests(unittest.TestCase):
 		""" if finetune is True, then the shapes of the classification
 		layer differ, hence, strict should be False """
 		self.assertTrue(*test_utils.is_all_equal(model, clf.model, strict=not finetune))
+		return model, clf

 	def load_for_finetune(self, key):
 		self.load_model(key, finetune=True)
@@ -35,8 +37,29 @@ class ClassifierLoadingTests(unittest.TestCase):
 	def load_for_inference(self, key):
 		self.load_model(key, finetune=False)

+class MeanModelClassifierLoadingTests(ClassifierLoadingTests):
+
+	def new_clf(self, key, **kwargs):
+		model = ModelFactory.new(key, pretrained_model=None)
+		return model, MeanModelClassifier(model, **kwargs)
+
+	def load_for_finetune(self, key):
+		_, clf = self.load_model(key, finetune=True)
+		self.assertTrue(*test_utils.is_all_equal(clf.model, clf.separate_model, strict=True, exclude_clf=True))
+
+
+	def load_for_inference(self, key):
+		_, clf = self.load_model(key, finetune=False)
+		self.assertTrue(*test_utils.is_all_equal(clf.model, clf.separate_model, strict=True))
+
 test_utils.add_tests(ClassifierLoadingTests.load_for_finetune,
 	model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))

+test_utils.add_tests(MeanModelClassifierLoadingTests.load_for_finetune,
+	model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
+
 test_utils.add_tests(ClassifierLoadingTests.load_for_inference,
 	model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
+
+test_utils.add_tests(MeanModelClassifierLoadingTests.load_for_inference,
+	model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
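A small idiom worth noting here: is_all_equal() returns a (bool, message) tuple (see test_utils.py below), which the tests unpack directly into assertTrue via the * operator, so the diagnostic message is only displayed on failure. A self-contained illustration:

	import unittest

	def is_all_equal_stub():
		# Stands in for test_utils.is_all_equal, which returns
		# (True, "All equal!") on success and (False, <reason>) otherwise.
		return True, "All equal!"

	class Demo(unittest.TestCase):
		def test_unpack(self):
			# Equivalent to: ok, msg = is_all_equal_stub(); self.assertTrue(ok, msg)
			self.assertTrue(*is_all_equal_stub())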

+ 6 - 2
tests/test_utils.py

@@ -27,11 +27,13 @@ def add_tests(func, model_list) -> None:
 		new_func.__name__ = name
 		setattr(cls, name, new_func)

-def is_all_equal(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False) -> bool:
+def is_all_equal(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False, exclude_clf: bool = False) -> bool:
 	params0 = dict(model0.namedparams())
 	params1 = dict(model1.namedparams())

 	for name in params0:
+		if exclude_clf and model0.clf_layer_name in name:
+			continue
 		param0, param1 = params0[name], params1[name]
 		if param0.shape != param1.shape:
 			if strict:
@@ -44,11 +46,13 @@ def is_all_equal(model0: chainer.Chain, model1: chainer.Chain, strict: bool = Fa

 	return True, "All equal!"

-def is_any_different(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False) -> bool:
+def is_any_different(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False, exclude_clf: bool = False) -> bool:
 	params0 = dict(model0.namedparams())
 	params1 = dict(model1.namedparams())

 	for name in params0:
+		if exclude_clf and model0.clf_layer_name in name:
+			continue
 		param0, param1 = params0[name], params1[name]
 		# print(name)
 		if param0.shape != param1.shape:
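The new exclude_clf flag makes both helpers skip every parameter whose name contains model0.clf_layer_name; the finetune tests rely on this because the classification layer is freshly re-initialized and cannot match. A condensed, self-contained variant of the comparison loop (assuming chainer chains that expose clf_layer_name; the repo helpers additionally return a diagnostic message):

	import numpy as np
	import chainer

	def params_equal(model0: chainer.Chain, model1: chainer.Chain,
			exclude_clf: bool = False) -> bool:
		params0 = dict(model0.namedparams())
		params1 = dict(model1.namedparams())
		for name, p0 in params0.items():
			if exclude_clf and model0.clf_layer_name in name:
				continue  # ignore the re-initialized classification layer
			p1 = params1[name]
			if p0.shape != p1.shape or not np.array_equal(p0.array, p1.array):
				return False
		return True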