浏览代码

added tests for MeanModelClassifier loading

Dimitri Korsch 4 年之前
父节点
当前提交
a68f82b9f2
共有 5 个文件被更改,包括 54 次插入47 次删除
  1. 2 0
      .coveragerc
  2. 19 19
      cvmodelz/classifiers/base.py
  3. 4 26
      cvmodelz/classifiers/separate_model_classifier.py
  4. 23 0
      tests/classifier_tests/loading.py
  5. 6 2
      tests/test_utils.py

+ 2 - 0
.coveragerc

@@ -19,3 +19,5 @@ exclude_lines =
 
 omit:
 	tests/*
+	cvmodelz/utils/*
+	cvmodelz/_version.py

+ 19 - 19
cvmodelz/classifiers/base.py

@@ -47,15 +47,16 @@ class Classifier(chainer.Chain):
 		""" Loading a classifier has following use cases:
 
 			(0) No loading.
-				Here the all weights are initilized randomly.
+				- All weights are initilized randomly.
 
 			(1) Loading from default pre-trained weights
-				Here, the weights are loaded directly to
-				the model. Any additional not model-related
+				- The weights are loaded directly to
+				the model.
+				- Any additional not model-related
 				layer will be initialized randomly.
 
 			(2) Loading from a saved classifier.
-				Here, all weights are loaded as-it-is from
+				- All weights are loaded as-it-is from
 				the given file.
 		"""
 
@@ -64,14 +65,15 @@ class Classifier(chainer.Chain):
 			self.load_classifier(weights_file)
 
 		except KeyError as e:
-			# Case (1)
-			self.load_model(weights_file, n_classes=n_classes, finetune=finetune)
-
-		else:
-			# Case (0)
 			pass
 
-	def load_classifier(self, weights_file):
+		# Case (1)
+		self.load_model(weights_file, n_classes=n_classes, finetune=finetune)
+
+		# else:
+		# 	# Case (0)
+		# 	pass
+
 	def load_classifier(self, weights_file: str):
 
 		if isinstance(weights_file, io.BufferedIOBase):
@@ -80,19 +82,17 @@ class Classifier(chainer.Chain):
 
 		npz.load_npz(weights_file, self, strict=True)
 
-	def load_model(self, weights_file, n_classes, *, finetune: bool = False):
+	def get_model_loader(self, finetune: bool = False, model: BaseModel = None):
+		model = model or self.model
 		if finetune:
-			model_loader = self.model.load_for_finetune
+			return model.load_for_finetune
 		else:
-			model_loader = self.model.load_for_inference
-
-		try:
-			model_loader(weights=weights_file, n_classes=n_classes, strict=True)
-		except KeyError as e:
+			return model.load_for_inference
 
-			breakpoint()
-			raise
 
+	def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False):
+		model_loader = self.get_model_loader(finetune=finetune, model=self.model)
+		model_loader(weights=weights_file, n_classes=n_classes, strict=True)
 
 	@property
 	def feat_size(self) -> int:

+ 4 - 26
cvmodelz/classifiers/separate_model_classifier.py

@@ -32,32 +32,10 @@ class SeparateModelClassifier(Classifier):
 	def load_for_inference(self, *args, **kwargs) -> None:
 		super().load_for_inference(*args, **kwargs)
 
-
-
-	def load(self, *args, **kwargs) -> None:
-		raise NotImplementedError
-
-	def loader(self, model_loader: Callable) -> Callable:
-		super_loader = super(SeparateModelClassifier).loader(model_loader)
-
-		def inner_loader(n_classes: int, feat_size: int) -> None:
-			# use the given feature size here
-			super_loader(n_classes=n_classes, feat_size=feat_size)
-
-			# use the given feature size first ...
-			self.separate_model.reinitialize_clf(
-				n_classes=n_classes,
-				feat_size=feat_size)
-
-			# then copy model params ...
-			self.separate_model.copyparams(self.model)
-
-			# now use the default feature size to re-init the classifier
-			self.separate_model.reinitialize_clf(
-				n_classes=n_classes,
-				feat_size=self.feat_size)
-
-		return inner_loader
+	def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False) -> None:
+		for model in [self.model, self.separate_model]:
+			model_loader = self.get_model_loader(finetune=finetune, model=model)
+			model_loader(weights=weights_file, n_classes=n_classes, strict=True)
 
 	def enable_only_head(self) -> None:
 		super().enable_only_head()

+ 23 - 0
tests/classifier_tests/loading.py

@@ -4,6 +4,7 @@ import test_utils
 import unittest
 
 from cvmodelz.classifiers import Classifier
+from cvmodelz.classifiers import MeanModelClassifier
 from cvmodelz.models import ModelFactory
 
 class ClassifierLoadingTests(unittest.TestCase):
@@ -28,6 +29,7 @@ class ClassifierLoadingTests(unittest.TestCase):
 		""" if finetune is True, then the shapes of the classification
 		layer differ, hence, strict should be False """
 		self.assertTrue(*test_utils.is_all_equal(model, clf.model, strict=not finetune))
+		return model, clf
 
 	def load_for_finetune(self, key):
 		self.load_model(key, finetune=True)
@@ -35,8 +37,29 @@ class ClassifierLoadingTests(unittest.TestCase):
 	def load_for_inference(self, key):
 		self.load_model(key, finetune=False)
 
+class MeanModelClassifierLoadingTests(ClassifierLoadingTests):
+
+	def new_clf(self, key, **kwargs):
+		model = ModelFactory.new(key, pretrained_model=None)
+		return model, MeanModelClassifier(model, **kwargs)
+
+	def load_for_finetune(self, key):
+		_, clf = self.load_model(key, finetune=True)
+		self.assertTrue(*test_utils.is_all_equal(clf.model, clf.separate_model, strict=True, exclude_clf=True))
+
+
+	def load_for_inference(self, key):
+		_, clf = self.load_model(key, finetune=False)
+		self.assertTrue(*test_utils.is_all_equal(clf.model, clf.separate_model, strict=True))
+
 test_utils.add_tests(ClassifierLoadingTests.load_for_finetune,
 	model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
 
+test_utils.add_tests(MeanModelClassifierLoadingTests.load_for_finetune,
+	model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
+
 test_utils.add_tests(ClassifierLoadingTests.load_for_inference,
 	model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))
+
+test_utils.add_tests(MeanModelClassifierLoadingTests.load_for_inference,
+	model_list=ModelFactory.get_models(["cvmodelz", "chainercv2"]))

+ 6 - 2
tests/test_utils.py

@@ -27,11 +27,13 @@ def add_tests(func, model_list) -> None:
 		new_func.__name__ = name
 		setattr(cls, name, new_func)
 
-def is_all_equal(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False) -> bool:
+def is_all_equal(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False, exclude_clf: bool = False) -> bool:
 	params0 = dict(model0.namedparams())
 	params1 = dict(model1.namedparams())
 
 	for name in params0:
+		if exclude_clf and model0.clf_layer_name in name:
+			continue
 		param0, param1 = params0[name], params1[name]
 		if param0.shape != param1.shape:
 			if strict:
@@ -44,11 +46,13 @@ def is_all_equal(model0: chainer.Chain, model1: chainer.Chain, strict: bool = Fa
 
 	return True, "All equal!"
 
-def is_any_different(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False) -> bool:
+def is_any_different(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False, exclude_clf: bool = False) -> bool:
 	params0 = dict(model0.namedparams())
 	params1 = dict(model1.namedparams())
 
 	for name in params0:
+		if exclude_clf and model0.clf_layer_name in name:
+			continue
 		param0, param1 = params0[name], params1[name]
 		# print(name)
 		if param0.shape != param1.shape: