Browse Source

moved reloading of file objects to the models and classifier loading methods

Dimitri Korsch 4 years ago
parent
commit
a9e5ef9ec2

+ 7 - 1
cvmodelz/classifiers/base.py

@@ -1,5 +1,6 @@
 import abc
 import chainer
+import io
 
 from chainer import functions as F
 from chainer.serializers import npz
@@ -71,6 +72,11 @@ class Classifier(chainer.Chain):
 			pass
 
-	def load_classifier(self, weights_file):
+	def load_classifier(self, weights_file: str):
+
+		if isinstance(weights_file, io.BufferedIOBase):
+			assert not weights_file.closed, "The weights file was already closed!"
+			weights_file.seek(0)
+
 		npz.load_npz(weights_file, self, strict=True)
 
 	def load_model(self, weights_file, n_classes, *, finetune: bool = False):

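The added isinstance check rewinds a freshly written file object before npz.load_npz reads from it; io.BytesIO derives from io.BufferedIOBase, so in-memory buffers are covered as well. A minimal, self-contained sketch of the behaviour this enables, using a plain Chainer link in place of the cvmodelz classes (illustrative only, not code from this repository):

import io
import chainer
from chainer.serializers import npz

model = chainer.links.Linear(3, 2)  # stand-in for a cvmodelz model

buf = io.BytesIO()
npz.save_npz(buf, model)  # writing leaves the stream position at EOF

# Loading here without a rewind would fail; load_classifier() now performs
# exactly this check-and-seek before delegating to npz.load_npz():
if isinstance(buf, io.BufferedIOBase):
    assert not buf.closed, "The weights file was already closed!"
    buf.seek(0)

npz.load_npz(buf, model, strict=True)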
+ 5 - 0
cvmodelz/models/base.py

@@ -1,5 +1,6 @@
 import abc
 import chainer
+import io
 import numpy as np
 
 from chainer import functions as F
@@ -99,6 +100,10 @@ class BaseModel(abc.ABC, chainer.Chain):
 		if headless:
 			ignore_names = lambda name: name.startswith(path + self.clf_layer_name)
 
+		if isinstance(weights, io.BufferedIOBase):
+			assert not weights.closed, "The weights file was already closed!"
+			weights.seek(0)
+
 		npz.load_npz(weights, self.model_instance,
 			path=path, strict=strict, ignore_names=ignore_names)
 
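Besides rewinding, the guard also asserts that the file object is still open, so a buffer closed too early fails with an explicit message instead of a low-level I/O error from numpy. A small illustration of the failure mode it reports (hypothetical usage, not taken from the tests):

import io

buf = io.BytesIO()
# ... model.save(buf) would normally write the NPZ weights here ...
buf.close()

# This mirrors the check added above; with a closed buffer the assertion
# fires with "The weights file was already closed!" before npz.load_npz runs.
if isinstance(buf, io.BufferedIOBase):
    assert not buf.closed, "The weights file was already closed!"
    buf.seek(0)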

+ 0 - 1
tests/classifier_tests/loading.py

@@ -23,7 +23,6 @@ class ClassifierLoadingTests(unittest.TestCase):
 
 		with test_utils.memory_file() as f:
 			model.save(f)
-			f.seek(0)
 			clf.load(f, n_classes=final_cls, finetune=finetune)
 
 		""" if finetune is True, then the shapes of the classification

+ 0 - 2
tests/model_tests/loading.py

@@ -37,7 +37,6 @@ class ModelLoadingTests(unittest.TestCase):
 
 		with test_utils.memory_file() as f:
 			model.save(f)
-			f.seek(0)
 			new_model.load_for_finetune(f, n_classes=200, strict=True)
 
 		self.assertTrue(*test_utils.is_all_equal(model, new_model))
@@ -51,7 +50,6 @@ class ModelLoadingTests(unittest.TestCase):
 
 		with test_utils.memory_file() as f:
 			model.save(f)
-			f.seek(0)
 			new_model.load_for_inference(f, n_classes=200, strict=True)
 
 		self.assertTrue(*test_utils.is_all_equal(model, new_model, strict=True))
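With the rewind handled inside the loading methods, the tests no longer have to reset the shared buffer between save() and load(). The memory_file() helper itself is not part of this diff; assuming it is essentially a BytesIO-backed context manager, a hypothetical reconstruction could look like this (names and structure are guesses, not the repository's code):

import io
from contextlib import contextmanager

@contextmanager
def memory_file():
    # Hypothetical stand-in for test_utils.memory_file(): yields an in-memory
    # binary buffer that save() writes to and load() later reads from.
    buf = io.BytesIO()
    try:
        yield buf
    finally:
        buf.close()

Since io.BytesIO is a subclass of io.BufferedIOBase, the new guard applies to it, which is why the explicit f.seek(0) calls could be dropped from the tests.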