Эх сурвалжийг харах

changed default info file handling

Dimitri Korsch 6 жил өмнө
parent
commit
0fb2657128

+ 1 - 1
cvfinetune/__init__.py

@@ -1 +1 @@
-__version__ = "0.3.0"
+__version__ = "0.3.2"

+ 63 - 42
cvfinetune/parser.py

@@ -1,6 +1,7 @@
 import os
 import logging
 import platform
+import warnings
 
 from chainer_addons.training import OptimizerType
 from chainer_addons.models import PrepareType
@@ -9,75 +10,95 @@ from chainer_addons.links import PoolingType
 from cvargparse import GPUParser, Arg, ArgFactory
 from cvdatasets.utils import read_info_file
 
-DEFAULT_INFO_FILE=os.environ.get("DATA", "/home/korsch/Data/info.yml")
+DEFAULT_INFO_FILE = os.environ.get("DATA")
 
-info_file = read_info_file(DEFAULT_INFO_FILE)
+if DEFAULT_INFO_FILE is not None and os.path.isfile(DEFAULT_INFO_FILE):
+	info_file = read_info_file(DEFAULT_INFO_FILE)
+else:
+	info_file = None
+
+WARNING = """Could not find default info file \"{}\". """ + \
+"""Some arguments (dataset, parts etc.) are not restraint to certain choices! """ + \
+"""You can set <DATA> environment variable to change the default info file location."""
 
 def default_factory(extra_list=[]):
-	return ArgFactory(extra_list + [
+	if info_file is None:
+		warnings.warn(WARNING.format(DEFAULT_INFO_FILE))
+		arg_list0 = [
+			Arg("data"),
+			Arg("dataset"),
+			Arg("parts"),
 
+			Arg("--model_type", "-mt",
+				default="resnet",
+				help="type of the model"),
+		]
+	else:
+		arg_list0 = [
 			Arg("data", default=DEFAULT_INFO_FILE),
-
 			Arg("dataset", choices=info_file.DATASETS.keys()),
 			Arg("parts", choices=info_file.PARTS.keys()),
 			Arg("--model_type", "-mt",
 				default="resnet", choices=info_file.MODELS.keys(),
 				help="type of the model"),
+		]
 
-			Arg("--input_size", type=int, nargs="+", default=0,
-				help="overrides default input size of the model, if greater than 0"),
+	arg_list1 = [
+		Arg("--input_size", type=int, nargs="+", default=0,
+			help="overrides default input size of the model, if greater than 0"),
 
-			PrepareType.as_arg("prepare_type",
-				help_text="type of image preprocessing"),
+		PrepareType.as_arg("prepare_type",
+			help_text="type of image preprocessing"),
 
-			PoolingType.as_arg("pooling",
-				help_text="type of pre-classification pooling"),
+		PoolingType.as_arg("pooling",
+			help_text="type of pre-classification pooling"),
 
-			Arg("--load", type=str, help="ignore weights and load already fine-tuned model (classifier will NOT be re-initialized and number of classes will be unchanged)"),
-			Arg("--weights", type=str, help="ignore default weights and load already pre-trained model (classifier will be re-initialized and number of classes will be changed)"),
-			Arg("--headless", action="store_true", help="ignores classifier layer during loading"),
+		Arg("--load", type=str, help="ignore weights and load already fine-tuned model (classifier will NOT be re-initialized and number of classes will be unchanged)"),
+		Arg("--weights", type=str, help="ignore default weights and load already pre-trained model (classifier will be re-initialized and number of classes will be changed)"),
+		Arg("--headless", action="store_true", help="ignores classifier layer during loading"),
 
-			Arg("--n_jobs", "-j", type=int, default=0,
-				help="number of loading processes. If 0, then images are loaded in the same process"),
+		Arg("--n_jobs", "-j", type=int, default=0,
+			help="number of loading processes. If 0, then images are loaded in the same process"),
 
-			Arg("--warm_up", type=int, help="warm up epochs"),
+		Arg("--warm_up", type=int, help="warm up epochs"),
 
-			OptimizerType.as_arg("optimizer", "opt",
-				help_text="type of the optimizer"),
+		OptimizerType.as_arg("optimizer", "opt",
+			help_text="type of the optimizer"),
 
-			Arg("--cosine_schedule", type=int,
-				default=-1,
-				help="enable cosine annealing LR schedule. This parameter sets the number of schedule stages"),
+		Arg("--cosine_schedule", type=int,
+			default=-1,
+			help="enable cosine annealing LR schedule. This parameter sets the number of schedule stages"),
 
-			Arg("--l1_loss", action="store_true",
-				help="(only with \"--only_head\" option!) use L1 Hinge Loss instead of Softmax Cross-Entropy"),
+		Arg("--l1_loss", action="store_true",
+			help="(only with \"--only_head\" option!) use L1 Hinge Loss instead of Softmax Cross-Entropy"),
 
-			Arg("--from_scratch", action="store_true",
-				help="Do not load any weights. Train the model from scratch"),
+		Arg("--from_scratch", action="store_true",
+			help="Do not load any weights. Train the model from scratch"),
 
-			Arg("--label_shift", type=int, default=1,
-				help="label shift"),
+		Arg("--label_shift", type=int, default=1,
+			help="label shift"),
 
-			Arg("--swap_channels", action="store_true",
-				help="preprocessing option: swap channels from RGB to BGR"),
+		Arg("--swap_channels", action="store_true",
+			help="preprocessing option: swap channels from RGB to BGR"),
 
-			Arg("--label_smoothing", type=float, default=0,
-				help="Factor for label smoothing"),
+		Arg("--label_smoothing", type=float, default=0,
+			help="Factor for label smoothing"),
 
-			Arg("--no_center_crop_on_val", action="store_true",
-				help="do not center crop imaages in the validation step!"),
+		Arg("--no_center_crop_on_val", action="store_true",
+			help="do not center crop imaages in the validation step!"),
 
-			Arg("--only_head", action="store_true", help="fine-tune only last layer"),
-			Arg("--no_progress", action="store_true", help="dont show progress bar"),
-			Arg("--augment", action="store_true", help="do data augmentation (random croping and random hor. flipping)"),
-			Arg("--force_load", action="store_true", help="force loading from caffe model"),
-			Arg("--only_eval", action="store_true", help="evaluate the model only. do not train!"),
-			Arg("--init_eval", action="store_true", help="evaluate the model before training"),
-			Arg("--no_snapshot", action="store_true", help="do not save trained model"),
+		Arg("--only_head", action="store_true", help="fine-tune only last layer"),
+		Arg("--no_progress", action="store_true", help="dont show progress bar"),
+		Arg("--augment", action="store_true", help="do data augmentation (random croping and random hor. flipping)"),
+		Arg("--force_load", action="store_true", help="force loading from caffe model"),
+		Arg("--only_eval", action="store_true", help="evaluate the model only. do not train!"),
+		Arg("--init_eval", action="store_true", help="evaluate the model before training"),
+		Arg("--no_snapshot", action="store_true", help="do not save trained model"),
 
-			Arg("--output", "-o", type=str, default=".out", help="output folder"),
+		Arg("--output", "-o", type=str, default=".out", help="output folder"),
+	]
 
-		])\
+	return ArgFactory(extra_list + arg_list0 + arg_list1)\
 		.seed()\
 		.batch_size()\
 		.epochs()\