Forráskód Böngészése

refactored test modules

Dimitri Korsch 4 éve
szülő
commit
d47ca908c3

+ 2 - 3
tests/model_tests/__init__.py

@@ -1,4 +1,3 @@
-from model_tests.factory_tests import *
-from model_tests.pretrained_tests import *
-from model_tests.wrapper_tests import *
+from model_tests.creation import *
+from model_tests.loading import *
 

+ 29 - 0
tests/model_tests/creation.py

@@ -0,0 +1,29 @@
+import test_utils
+import unittest
+
+from cvmodelz.models import ModelFactory
+from cvmodelz.models.pretrained.base import PretrainedModelMixin
+from cvmodelz.models.wrapper import ModelWrapper
+
+
+class CreationsTests(unittest.TestCase):
+
+	def cv2model_creation(self, key):
+
+		model = ModelFactory.new(key)
+		self.assertIsNotNone(model)
+
+		self.assertIsInstance(model, ModelWrapper)
+
+	def pretrained_model_creation(self, key):
+		model = ModelFactory.new(key)
+		self.assertIsNotNone(model)
+
+		self.assertIsInstance(model, PretrainedModelMixin)
+
+test_utils.add_tests(CreationsTests.cv2model_creation,
+	model_list=ModelFactory.get_models(["chainercv2"]))
+
+test_utils.add_tests(CreationsTests.pretrained_model_creation,
+	model_list=ModelFactory.get_models(["cvmodelz"]))
+

+ 0 - 50
tests/model_tests/factory_tests.py

@@ -1,50 +0,0 @@
-import unittest
-import test_utils
-
-from cvmodelz.models import ModelFactory
-from cvmodelz.models.pretrained.base import PretrainedModelMixin
-from cvmodelz.models.wrapper import ModelWrapper
-
-
-class FactoryTests(unittest.TestCase):
-
-	def cv2model_creation(self, key):
-
-		model = ModelFactory.new(key)
-		self.assertIsNotNone(model)
-
-		self.assertIsInstance(model, ModelWrapper)
-
-	def pretrained_model_creation(self, key):
-		model = ModelFactory.new(key)
-		self.assertIsNotNone(model)
-
-		self.assertIsInstance(model, PretrainedModelMixin)
-
-	def cv2model_load(self, key):
-
-		model_rnd = ModelFactory.new(key, pretrained_model=None)
-		model_loaded1 = ModelFactory.new(key, pretrained_model="auto")
-		model_loaded2 = ModelFactory.new(key, pretrained_model="auto")
-
-		params_rnd = dict(model_rnd.namedparams())
-		params_loaded1 = dict(model_loaded1.namedparams())
-		params_loaded2 = dict(model_loaded2.namedparams())
-
-		for name, param in params_rnd.items():
-			loaded1 = params_loaded1[name]
-			loaded2 = params_loaded2[name]
-
-			self.assertTrue((  param.array != loaded1.array).any())
-			self.assertTrue((  param.array != loaded2.array).any())
-			self.assertTrue((loaded1.array == loaded2.array).all())
-
-
-test_utils.add_tests(FactoryTests.cv2model_creation,
-	model_list=ModelFactory.get_models(["chainercv2"]))
-
-test_utils.add_tests(FactoryTests.pretrained_model_creation,
-	model_list=ModelFactory.get_models(["cvmodelz"]))
-
-test_utils.add_tests(FactoryTests.cv2model_load,
-	model_list=ModelFactory.get_models(["chainercv2"]))

+ 26 - 8
tests/model_tests/pretrained_tests.py → tests/model_tests/loading.py

@@ -1,21 +1,36 @@
 import io
-import numpy as np
-import unittest
 import test_utils
+import unittest
 
-from chainer.serializers import npz
 from contextlib import contextmanager
 
-
-from cvmodelz.models.pretrained import *
 from cvmodelz.models import ModelFactory
 
-class PretrainedTests(unittest.TestCase):
+class LoadingTests(unittest.TestCase):
 
 	@contextmanager
 	def mem_file(self) -> io.BytesIO:
 		yield io.BytesIO()
 
+	def cv2model_load_pretrained(self, key):
+
+		model_rnd = ModelFactory.new(key, pretrained_model=None)
+		model_loaded1 = ModelFactory.new(key, pretrained_model="auto")
+		model_loaded2 = ModelFactory.new(key, pretrained_model="auto")
+
+		params_rnd = dict(model_rnd.namedparams())
+		params_loaded1 = dict(model_loaded1.namedparams())
+		params_loaded2 = dict(model_loaded2.namedparams())
+
+		for name, param in params_rnd.items():
+			loaded1 = params_loaded1[name]
+			loaded2 = params_loaded2[name]
+
+			self.assertTrue((  param.array != loaded1.array).any())
+			self.assertTrue((  param.array != loaded2.array).any())
+			self.assertTrue((loaded1.array == loaded2.array).all())
+
+
 	def load_for_finetune(self, key):
 
 		model = ModelFactory.new(key, n_classes=1000)
@@ -46,8 +61,11 @@ class PretrainedTests(unittest.TestCase):
 
 
 
-test_utils.add_tests(PretrainedTests.load_for_finetune,
+test_utils.add_tests(LoadingTests.load_for_finetune,
 	model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))
 
-test_utils.add_tests(PretrainedTests.load_for_inference,
+test_utils.add_tests(LoadingTests.load_for_inference,
 	model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))
+
+test_utils.add_tests(LoadingTests.cv2model_load_pretrained,
+	model_list=ModelFactory.get_models(["chainercv2"]))

+ 0 - 9
tests/model_tests/wrapper_tests.py

@@ -1,9 +0,0 @@
-import unittest
-
-from cvmodelz.models.wrapper import ModelWrapper
-
-
-class WrapperTests(unittest.TestCase):
-
-	def test_foo(self):
-		self.assertTrue(1)