Explorar el Código

added copy mode option in the separate model classifier

Dimitri Korsch hace 4 años
padre
commit
cabc625d12
Se han modificado 1 ficheros con 5 adiciones y 1 borrados
  1. 5 1
      cvmodelz/classifiers/separate_model_classifier.py

+ 5 - 1
cvmodelz/classifiers/separate_model_classifier.py

@@ -17,6 +17,10 @@ class SeparateModelClassifier(Classifier):
 		input data. Hence, the forward method is abstract!
 	"""
 
+	def __init__(self, *args, copy_mode="copy", **kwargs):
+		self.copy_mode = copy_mode
+		super().__init__(*args, **kwargs)
+
 	@abc.abstractmethod
 	def forward(self, *args, **kwargs) -> chainer.Variable:
 		super().forward(*args, **kwargs)
@@ -24,7 +28,7 @@ class SeparateModelClassifier(Classifier):
 	def setup(self, model: BaseModel) -> None:
 		super().setup(model)
 
-		self.separate_model = self.model.copy(mode="copy")
+		self.separate_model = self.model.copy(mode=self.copy_mode)
 
 	def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False, **kwargs) -> None:
 		for model in [self.model, self.separate_model]: