|
@@ -6,11 +6,8 @@ from contextlib import contextmanager
|
|
|
|
|
|
from cvmodelz.models import ModelFactory
|
|
|
|
|
|
-class LoadingTests(unittest.TestCase):
|
|
|
+class ModelLoadingTests(unittest.TestCase):
|
|
|
|
|
|
- @contextmanager
|
|
|
- def mem_file(self) -> io.BytesIO:
|
|
|
- yield io.BytesIO()
|
|
|
|
|
|
def cv2model_load_pretrained(self, key):
|
|
|
|
|
@@ -36,36 +33,36 @@ class LoadingTests(unittest.TestCase):
|
|
|
model = ModelFactory.new(key, n_classes=1000)
|
|
|
new_model = ModelFactory.new(key, n_classes=1000)
|
|
|
|
|
|
- self.assertTrue(test_utils.is_any_different(model, new_model))
|
|
|
+ self.assertTrue(*test_utils.is_any_different(model, new_model))
|
|
|
|
|
|
- with self.mem_file() as f:
|
|
|
+ with test_utils.memory_file() as f:
|
|
|
model.save(f)
|
|
|
f.seek(0)
|
|
|
new_model.load_for_finetune(f, n_classes=200, strict=True)
|
|
|
|
|
|
- self.assertTrue(test_utils.is_all_equal(model, new_model))
|
|
|
+ self.assertTrue(*test_utils.is_all_equal(model, new_model))
|
|
|
|
|
|
def load_for_inference(self, key):
|
|
|
|
|
|
model = ModelFactory.new(key, n_classes=200)
|
|
|
new_model = ModelFactory.new(key, n_classes=1000)
|
|
|
|
|
|
- self.assertTrue(test_utils.is_any_different(model, new_model))
|
|
|
+ self.assertTrue(*test_utils.is_any_different(model, new_model))
|
|
|
|
|
|
- with self.mem_file() as f:
|
|
|
+ with test_utils.memory_file() as f:
|
|
|
model.save(f)
|
|
|
f.seek(0)
|
|
|
new_model.load_for_inference(f, n_classes=200, strict=True)
|
|
|
|
|
|
- self.assertTrue(test_utils.is_all_equal(model, new_model, strict=True))
|
|
|
+ self.assertTrue(*test_utils.is_all_equal(model, new_model, strict=True))
|
|
|
|
|
|
|
|
|
|
|
|
-test_utils.add_tests(LoadingTests.load_for_finetune,
|
|
|
+test_utils.add_tests(ModelLoadingTests.load_for_finetune,
|
|
|
model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))
|
|
|
|
|
|
-test_utils.add_tests(LoadingTests.load_for_inference,
|
|
|
+test_utils.add_tests(ModelLoadingTests.load_for_inference,
|
|
|
model_list=ModelFactory.get_models(["chainercv2", "cvmodelz"]))
|
|
|
|
|
|
-test_utils.add_tests(LoadingTests.cv2model_load_pretrained,
|
|
|
+test_utils.add_tests(ModelLoadingTests.cv2model_load_pretrained,
|
|
|
model_list=ModelFactory.get_models(["chainercv2"]))
|