|
@@ -1,5 +1,6 @@
|
|
import abc
|
|
import abc
|
|
import chainer
|
|
import chainer
|
|
|
|
+import io
|
|
|
|
|
|
from chainer import functions as F
|
|
from chainer import functions as F
|
|
from chainer.serializers import npz
|
|
from chainer.serializers import npz
|
|
@@ -71,6 +72,12 @@ class Classifier(chainer.Chain):
|
|
pass
|
|
pass
|
|
|
|
|
|
def load_classifier(self, weights_file):
|
|
def load_classifier(self, weights_file):
|
|
|
|
+ def load_classifier(self, weights_file: str):
|
|
|
|
+
|
|
|
|
+ if isinstance(weights_file, io.BufferedIOBase):
|
|
|
|
+ assert not weights_file.closed, "The weights file was already closed!"
|
|
|
|
+ weights_file.seek(0)
|
|
|
|
+
|
|
npz.load_npz(weights_file, self, strict=True)
|
|
npz.load_npz(weights_file, self, strict=True)
|
|
|
|
|
|
def load_model(self, weights_file, n_classes, *, finetune: bool = False):
|
|
def load_model(self, weights_file, n_classes, *, finetune: bool = False):
|