Przeglądaj źródła

added pre-training command line argument and its handling

Dimitri Korsch 4 lat temu
rodzic
commit
5eed48e144

+ 4 - 2
cvfinetune/finetuner/mixins/model.py

@@ -137,8 +137,10 @@ class _ModelMixin(abc.ABC):
 		weights_dir = base_dir / ds_info.MODEL_DIR / model_info.folder
 		weights_dir = base_dir / ds_info.MODEL_DIR / model_info.folder
 
 
 		weights = model_info.weights
 		weights = model_info.weights
-		### TODO: make pre-training command line argument!
-		return str(weights_dir / weights.get("inat", weights.get("imagenet")))
+		assert opts.pre_training in weights, \
+			f"Weights for \"{opts.pre_training}\" pre-training were not found!"
+
+		return str(weights_dir / weights[opts.pre_training])
 
 
 
 
 	def load_weights(self, opts) -> None:
 	def load_weights(self, opts) -> None:

+ 5 - 0
cvfinetune/parser/model_args.py

@@ -17,6 +17,11 @@ def add_model_args(parser: BaseParser) -> None:
 			choices=choices,
 			choices=choices,
 			help="type of the model"),
 			help="type of the model"),
 
 
+		Arg("--pre_training", "-pt",
+			default="imagenet",
+			choices=["imagenet", "inat"],
+			help="type of model pre-training"),
+
 		Arg("--input_size", type=int, nargs="+", default=0,
 		Arg("--input_size", type=int, nargs="+", default=0,
 			help="overrides default input size of the model, if greater than 0"),
 			help="overrides default input size of the model, if greater than 0"),