|
@@ -17,6 +17,10 @@ class SeparateModelClassifier(Classifier):
|
|
|
input data. Hence, the forward method is abstract!
|
|
|
"""
|
|
|
|
|
|
+ def __init__(self, *args, copy_mode="copy", **kwargs):
|
|
|
+ self.copy_mode = copy_mode
|
|
|
+ super().__init__(*args, **kwargs)
|
|
|
+
|
|
|
@abc.abstractmethod
|
|
|
def forward(self, *args, **kwargs) -> chainer.Variable:
|
|
|
super().forward(*args, **kwargs)
|
|
@@ -24,7 +28,7 @@ class SeparateModelClassifier(Classifier):
|
|
|
def setup(self, model: BaseModel) -> None:
|
|
|
super().setup(model)
|
|
|
|
|
|
- self.separate_model = self.model.copy(mode="copy")
|
|
|
+ self.separate_model = self.model.copy(mode=self.copy_mode)
|
|
|
|
|
|
def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False, **kwargs) -> None:
|
|
|
for model in [self.model, self.separate_model]:
|