|
@@ -26,10 +26,11 @@ class SeparateModelClassifier(Classifier):
|
|
|
|
|
|
self.separate_model = self.model.copy(mode="copy")
|
|
|
|
|
|
- def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False) -> None:
|
|
|
+ def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False, **kwargs) -> None:
|
|
|
for model in [self.model, self.separate_model]:
|
|
|
model_loader = self.get_model_loader(finetune=finetune, model=model)
|
|
|
- model_loader(weights=weights_file, n_classes=n_classes, strict=True)
|
|
|
+ kwargs["strict"] = kwargs.get("strict", True)
|
|
|
+ model_loader(weights=weights_file, n_classes=n_classes, **kwargs)
|
|
|
|
|
|
def enable_only_head(self) -> None:
|
|
|
super().enable_only_head()
|