소스 검색

fixed minor issues

Dimitri Korsch 4 년 전
부모
커밋
fa0d6dec8a
2개의 변경된 파일4개의 추가작업 그리고 2개의 파일을 삭제
  1. 3 2
      cvmodelz/classifiers/separate_model_classifier.py
  2. 1 0
      cvmodelz/models/base.py

+ 3 - 2
cvmodelz/classifiers/separate_model_classifier.py

@@ -26,10 +26,11 @@ class SeparateModelClassifier(Classifier):
 
 		self.separate_model = self.model.copy(mode="copy")
 
-	def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False) -> None:
+	def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False, **kwargs) -> None:
 		for model in [self.model, self.separate_model]:
 			model_loader = self.get_model_loader(finetune=finetune, model=model)
-			model_loader(weights=weights_file, n_classes=n_classes, strict=True)
+			kwargs["strict"] = kwargs.get("strict", True)
+			model_loader(weights=weights_file, n_classes=n_classes, **kwargs)
 
 	def enable_only_head(self) -> None:
 		super().enable_only_head()

+ 1 - 0
cvmodelz/models/base.py

@@ -72,6 +72,7 @@ class BaseModel(abc.ABC, chainer.Chain):
 		w_shape = (n_classes, feat_size or clf_layer.W.shape[1])
 		dtype = clf_layer.W.dtype
 
+		clf_layer.in_size, clf_layer.in_size = w_shape
 		clf_layer.W.data = np.zeros(w_shape, dtype=dtype)
 		clf_layer.b.data = np.zeros(w_shape[0], dtype=dtype)