Browse Source

added tests for weights loading

Dimitri Korsch 4 years ago
parent
commit
07d87ffd84

+ 3 - 0
cvmodelz/models/base.py

@@ -1,8 +1,11 @@
 import abc
 import chainer
+import numpy as np
 
 from chainer import functions as F
 from chainer import links as L
+from chainer.initializers import HeNormal
+from chainer.serializers import npz
 from chainer_addons.links.pooling import PoolingType # TODO: replace this!
 from collections import OrderedDict
 from typing import Callable

+ 5 - 5
cvmodelz/models/factory.py

@@ -17,8 +17,8 @@ class ModelFactory(abc.ABC):
 			L.ResNet50Layers,
 			L.ResNet101Layers,
 			L.ResNet152Layers,
-			L.VGG16Layers,
-			L.VGG19Layers,
+			# L.VGG16Layers,
+			# L.VGG19Layers,
 		),
 
 		chainercv=(
@@ -27,14 +27,14 @@ class ModelFactory(abc.ABC):
 
 		chainercv2=(
 			cv2resnet.resnet50,
-			cv2resnet.resnet50b,
+			# cv2resnet.resnet50b,
 
 			cv2inceptionv3.inceptionv3,
 		),
 
 		cvmodelz=(
-			pretrained.VGG16,
-			pretrained.VGG19,
+			# pretrained.VGG16,
+			# pretrained.VGG19,
 
 			pretrained.ResNet35,
 			pretrained.ResNet50,

+ 7 - 2
cvmodelz/models/pretrained/resnet.py

@@ -26,6 +26,11 @@ class BaseResNet(PretrainedModelMixin):
 			classifier_layers=["fc6"],
 		)
 
+	def init_extra_layers(self, n_classes, **kwargs):
+		if hasattr(self, "fc6"):
+			delattr(self, "fc6")
+		self.fc6 = L.Linear(2048, n_classes)
+
 	@property
 	def functions(self):
 		return super(BaseResNet, self).functions
@@ -34,14 +39,14 @@ class BaseResNet(PretrainedModelMixin):
 class ResNet35(BaseResNet, chainer.Chain):
 	n_layers = 35
 
-	def init_extra_layers(self, n_classes, **kwargs):
+	def init_extra_layers(self, *args, **kwargs):
+		super(ResNet35, self).init_extra_layers(*args, **kwargs)
 		self.conv1 = L.Convolution2D(3, 64, 7, 2, 3, **kwargs)
 		self.bn1 = L.BatchNormalization(64)
 		self.res2 = BuildingBlock(2, 64, 64, 256, 1, **kwargs)
 		self.res3 = BuildingBlock(3, 256, 128, 512, 2, **kwargs)
 		self.res4 = BuildingBlock(3, 512, 256, 1024, 2, **kwargs)
 		self.res5 = BuildingBlock(3, 1024, 512, 2048, 2, **kwargs)
-		self.fc6 = L.Linear(2048, n_classes)
 
 	@property
 	def functions(self):

+ 22 - 27
tests/model_tests/factory_tests.py

@@ -1,42 +1,37 @@
 import unittest
-
-from contextlib import contextmanager
+import test_utils
 
 from cvmodelz.models import ModelFactory
 
 
-@contextmanager
-def clear_print(msg):
-	print(msg)
-	yield
-	print("\033[A{}\033[A".format(" "*len(msg)))
 
 class FactoryTests(unittest.TestCase):
 
-	def test_model_creation(self):
-		for key in ModelFactory.get_all_models():
+	def model_creation(self, key):
+
+		model = ModelFactory.new(key)
+		self.assertIsNotNone(model)
 
-			with clear_print(f"Creating {key}..."):
-				model = ModelFactory.new(key)
+	def cv2model_load(self, key):
 
-			self.assertIsNotNone(model)
+		model_rnd = ModelFactory.new(key, pretrained=False)
+		model_loaded1 = ModelFactory.new(key, pretrained=True)
+		model_loaded2 = ModelFactory.new(key, pretrained=True)
 
-	def test_cv2model_load(self):
-		for key in ModelFactory.get_models(["chainercv2"]):
-			with clear_print(f"Loading default weights for {key}..."):
+		params_rnd = dict(model_rnd.namedparams())
+		params_loaded1 = dict(model_loaded1.namedparams())
+		params_loaded2 = dict(model_loaded2.namedparams())
 
-				model_rnd = ModelFactory.new(key, pretrained=False)
-				model_loaded1 = ModelFactory.new(key, pretrained=True)
-				model_loaded2 = ModelFactory.new(key, pretrained=True)
+		for name, param in params_rnd.items():
+			loaded1 = params_loaded1[name]
+			loaded2 = params_loaded2[name]
 
-				params_rnd = dict(model_rnd.namedparams())
-				params_loaded1 = dict(model_loaded1.namedparams())
-				params_loaded2 = dict(model_loaded2.namedparams())
+			self.assertTrue((  param.array != loaded1.array).any())
+			self.assertTrue((  param.array != loaded2.array).any())
+			self.assertTrue((loaded1.array == loaded2.array).all())
 
-				for name, param in params_rnd.items():
-					loaded1 = params_loaded1[name]
-					loaded2 = params_loaded2[name]
+test_utils.add_tests(FactoryTests.model_creation,
+	model_list=ModelFactory.get_all_models())
 
-					self.assertTrue((  param.array != loaded1.array).any())
-					self.assertTrue((  param.array != loaded2.array).any())
-					self.assertTrue((loaded1.array == loaded2.array).all())
+test_utils.add_tests(FactoryTests.cv2model_load,
+	model_list=ModelFactory.get_models(["chainercv2"]))

+ 46 - 2
tests/model_tests/pretrained_tests.py

@@ -1,8 +1,52 @@
+import io
 import unittest
+import test_utils
+
+from chainer.serializers import npz
+from contextlib import contextmanager
+
 
 from cvmodelz.models.pretrained import *
+from cvmodelz.models import ModelFactory
 
 class PretrainedTests(unittest.TestCase):
 
-	def test_foo(self):
-		self.assertTrue(1)
+	@contextmanager
+	def mem_file(self) -> io.BytesIO:
+		yield io.BytesIO()
+
+	def load_for_finetune(self, key):
+
+		model = ModelFactory.new(key)
+		new_model = ModelFactory.new(key)
+
+		self.assertTrue(test_utils.is_any_different(model, new_model))
+
+		with self.mem_file() as f:
+			npz.save_npz(f, model)
+			f.seek(0)
+			new_model.load_for_finetune(f, n_classes=200)
+
+		self.assertTrue(test_utils.is_all_equal(model, new_model))
+
+	def load_for_inference(self, key):
+
+		model = ModelFactory.new(key, n_classes=200)
+		new_model = ModelFactory.new(key, n_classes=1000)
+
+		self.assertTrue(test_utils.is_any_different(model, new_model))
+
+		with self.mem_file() as f:
+			npz.save_npz(f, model)
+			f.seek(0)
+			new_model.load_for_inference(f, n_classes=200)
+
+		self.assertTrue(test_utils.is_all_equal(model, new_model, strict=True))
+
+
+
+test_utils.add_tests(PretrainedTests.load_for_finetune,
+	model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))
+
+test_utils.add_tests(PretrainedTests.load_for_inference,
+	model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))

+ 67 - 0
tests/test_utils.py

@@ -0,0 +1,67 @@
+import inspect
+import chainer
+
+from contextlib import contextmanager
+from functools import partial
+
+def get_class_that_defined_method(meth):
+	if inspect.isfunction(meth):
+		cls_name = meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0]
+		return getattr(inspect.getmodule(meth), cls_name, None)
+
+def wrapper(func, key):
+
+	def inner(self):
+		return func(self, key)
+
+	return inner
+
+def add_tests(func, model_list) -> None:
+
+	cls = get_class_that_defined_method(func)
+
+	for key in model_list:
+		new_func = wrapper(func, key)
+		name = f"test_{key.replace('.', '__')}_{func.__name__}"
+		new_func.__name__ = name
+		setattr(cls, name, new_func)
+
+def is_all_equal(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False) -> bool:
+	params0 = dict(model0.namedparams())
+	params1 = dict(model1.namedparams())
+
+	for name in params0:
+		param0, param1 = params0[name], params1[name]
+		if param0.shape != param1.shape:
+			if strict:
+				return False
+			else:
+				continue
+
+		if not (param0.array == param1.array).all():
+			return False
+
+	return True
+
+def is_any_different(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False) -> bool:
+	params0 = dict(model0.namedparams())
+	params1 = dict(model1.namedparams())
+
+	for name in params0:
+		param0, param1 = params0[name], params1[name]
+		# print(name)
+		if param0.shape != param1.shape:
+			if strict:
+				return False
+			else:
+				continue
+		if (param0.array != param1.array).any():
+			return True
+
+	return False
+
+@contextmanager
+def clear_print(msg):
+	print(msg)
+	yield
+	print("\033[A{}\033[A".format(" "*len(msg)))