|
@@ -43,7 +43,7 @@ class Classifier(chainer.Chain):
|
|
|
def save(self, weights_file):
|
|
|
npz.save_npz(weights_file, self)
|
|
|
|
|
|
- def load(self, weights_file: str, n_classes: int, *, finetune: bool = False) -> None:
|
|
|
+ def load(self, weights_file: str, n_classes: int, *, finetune: bool = False, **kwargs) -> None:
|
|
|
""" Loading a classifier has following use cases:
|
|
|
|
|
|
(0) No loading.
|
|
@@ -68,7 +68,7 @@ class Classifier(chainer.Chain):
|
|
|
pass
|
|
|
|
|
|
# Case (1)
|
|
|
- self.load_model(weights_file, n_classes=n_classes, finetune=finetune)
|
|
|
+ self.load_model(weights_file, n_classes=n_classes, finetune=finetune, **kwargs)
|
|
|
|
|
|
# else:
|
|
|
# # Case (0)
|
|
@@ -90,9 +90,10 @@ class Classifier(chainer.Chain):
|
|
|
return model.load_for_inference
|
|
|
|
|
|
|
|
|
- def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False):
|
|
|
+ def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False, **kwargs):
|
|
|
model_loader = self.get_model_loader(finetune=finetune, model=self.model)
|
|
|
- model_loader(weights=weights_file, n_classes=n_classes, strict=True)
|
|
|
+ kwargs["strict"] = kwargs.get("strict", True)
|
|
|
+ model_loader(weights=weights_file, n_classes=n_classes, **kwargs)
|
|
|
|
|
|
@property
|
|
|
def feat_size(self) -> int:
|