Просмотр исходного кода

added new option for weights loading

Dimitri Korsch 6 лет назад
Родитель
Сommit
c35febcbeb
2 измененных файлов с 15 добавлено и 8 удалено
  1. 13 7
      cvfinetune/finetuner/base.py
  2. 2 1
      cvfinetune/parser.py

+ 13 - 7
cvfinetune/finetuner/base.py

@@ -132,13 +132,19 @@ class _ModelMixin(abc.ABC):
 				msg = "Loading already fine-tuned weights from \"{}\""
 				loader_func = self.model.load_for_inference
 			else:
-				self.weights = join(
-					self.data_info.BASE_DIR,
-					self.data_info.MODEL_DIR,
-					self.model_info.folder,
-					self.model_info.weights
-				)
-				msg = "Loading pre-trained weights \"{}\""
+				if args.weights:
+					msg = "Loading custom pre-trained weights \"{}\""
+					self.weights = args.weights
+
+				else:
+					msg = "Loading default pre-trained weights \"{}\""
+					self.weights = join(
+						self.data_info.BASE_DIR,
+						self.data_info.MODEL_DIR,
+						self.model_info.folder,
+						self.model_info.weights
+					)
+
 				loader_func = self.model.load_for_finetune
 
 			logging.info(msg.format(self.weights))

+ 2 - 1
cvfinetune/parser.py

@@ -33,7 +33,8 @@ def default_factory(extra_list=[]):
 			PoolingType.as_arg("pooling",
 				help_text="type of pre-classification pooling"),
 
-			Arg("--load", type=str, help="ignore weights and load already fine-tuned model"),
+			Arg("--load", type=str, help="ignore weights and load already fine-tuned model (classifier will NOT be re-initialized and number of classes will be unchanged)"),
+			Arg("--weights", type=str, help="ignore default weights and load already pre-trained model (classifier will be re-initialized and number of classes will be changed)"),
 			Arg("--headless", action="store_true", help="ignores classifier layer during loading"),
 
 			Arg("--n_jobs", "-j", type=int, default=0,