Forráskód Böngészése

added test for loading default pre-trained weights

Dimitri Korsch 4 éve
szülő
commit
a179c9f57d
2 módosított fájl, 42 hozzáadás és 11 törlés
  1. 10 7
      cvmodelz/models/factory.py
  2. 32 4
      tests/model_tests/factory_tests.py

+ 10 - 7
cvmodelz/models/factory.py

@@ -61,14 +61,11 @@ class ModelFactory(abc.ABC):
 			raise ValueError(f"Could not find {model_type}!")
 
 		if model_cls in cls.supported["chainer"]:
-			if "pretrained_model" not in kwargs:
-				kwargs["pretrained_model"] = None
+			kwargs["pretrained_model"] = kwargs.get("pretrained_model", None)
 			kwargs.pop("input_size", None)
 
 		elif model_cls in cls.supported["chainercv2"]:
-			if "pretrained" not in kwargs:
-				kwargs["pretrained"] = False
-
+			kwargs["pretrained"] = kwargs.get("pretrained", False)
 			input_size = kwargs.pop("input_size", None)
 			return ModelWrapper(model_cls(*args, **kwargs), input_size=input_size)
 
@@ -100,13 +97,19 @@ class ModelFactory(abc.ABC):
 		if key is not None:
 			return [f"{key}.{model_cls.__name__}" for model_cls in cls.supported[key]]
 
+		return cls.get_models(cls.supported.keys())
+
+	@classmethod
+	def get_models(cls, keys=None):
+		if keys is None:
+			keys = cls.supported.keys()
+
 		res = []
-		for key in cls.supported:
+		for key in keys:
 			res += cls.get_all_models(key)
 
 		return res
 
 
-
 if __name__ == '__main__':
 	print(pyaml.dump(dict(Models=ModelFactory.get_all_models()), indent=2))

+ 32 - 4
tests/model_tests/factory_tests.py

@@ -1,14 +1,42 @@
 import unittest
 
+from contextlib import contextmanager
+
 from cvmodelz.models import ModelFactory
 
 
+@contextmanager
+def clear_print(msg):
+	print(msg)
+	yield
+	print("\033[A{}\033[A".format(" "*len(msg)))
+
 class FactoryTests(unittest.TestCase):
 
 	def test_model_creation(self):
 		for key in ModelFactory.get_all_models():
-			msg = f"Creating {key}..."
-			print(msg)
-			model = ModelFactory.new(key)
-			print("\033[A{}\033[A".format(" "*len(msg)))
+
+			with clear_print(f"Creating {key}..."):
+				model = ModelFactory.new(key)
+
 			self.assertIsNotNone(model)
+
+	def test_cv2model_load(self):
+		for key in ModelFactory.get_models(["chainercv2"]):
+			with clear_print(f"Loading default weights for {key}..."):
+
+				model_rnd = ModelFactory.new(key, pretrained=False)
+				model_loaded1 = ModelFactory.new(key, pretrained=True)
+				model_loaded2 = ModelFactory.new(key, pretrained=True)
+
+				params_rnd = dict(model_rnd.namedparams())
+				params_loaded1 = dict(model_loaded1.namedparams())
+				params_loaded2 = dict(model_loaded2.namedparams())
+
+				for name, param in params_rnd.items():
+					loaded1 = params_loaded1[name]
+					loaded2 = params_loaded2[name]
+
+					self.assertTrue((  param.array != loaded1.array).any())
+					self.assertTrue((  param.array != loaded2.array).any())
+					self.assertTrue((loaded1.array == loaded2.array).all())