Эх сурвалжийг харах

some changes for in the argument creation and handling

Dimitri Korsch 5 жил өмнө
parent
commit
c2c4f52df5

+ 18 - 10
cvfinetune/finetuner/base.py

@@ -60,20 +60,23 @@ class _ModelMixin(abc.ABC):
 		]))
 
 	def _loss_func(self, opts):
-		if opts.l1_loss:
+		if getattr(opts, "l1_loss", False):
 			return F.hinge
 
-		elif opts.label_smoothing >= 0:
-			assert opts.label_smoothing < 1, \
+		elif getattr(opts, "label_smoothing", 0) >= 0:
+			assert getattr(opts, "label_smoothing", 0) < 1, \
 				"Label smoothing factor must be less than 1!"
 			return partial(smoothed_cross_entropy,
 				N=self.n_classes,
-				eps=opts.label_smoothing)
+				eps=getattr(opts, "label_smoothing", 0))
 		else:
 			return F.softmax_cross_entropy
 
 	def init_optimizer(self, opts):
 		"""Creates an optimizer for the classifier """
+		if not hasattr(opts, "optimizer"):
+			self.opt = None
+			return
 
 		opt_kwargs = {}
 		if opts.optimizer == "rmsprop":
@@ -100,7 +103,7 @@ class _ModelMixin(abc.ABC):
 			logging.info(f"Adding {reg_cls.__name__} ({opts.decay:e})")
 			self.opt.add_hook(reg_cls(opts.decay, **reg_kwargs))
 
-		if opts.only_head:
+		if getattr(opts, "only_head", False):
 			assert not opts.recurrent, "FIX ME! Not supported yet!"
 
 			logging.warning("========= Fine-tuning only classifier layer! =========")
@@ -123,7 +126,7 @@ class _ModelMixin(abc.ABC):
 		)
 
 	def load_model_weights(self, args):
-		if args.from_scratch:
+		if getattr(args, "from_scratch", False):
 			logging.info("Training a {0.__class__.__name__} model from scratch!".format(self.model))
 			loader = self.model.reinitialize_clf
 			self.weights = None
@@ -200,12 +203,12 @@ class _DatasetMixin(abc.ABC):
 		# 		no_glob=opts.no_global,
 		# 	))
 
-		if not opts.only_head:
+		if not getattr(opts, "only_head", False):
 			kwargs.update(dict(
 				preprocess=self.prepare,
 				augment=augment,
 				size=size,
-				center_crop_on_val=not opts.no_center_crop_on_val,
+				center_crop_on_val=not getattr(opts, "no_center_crop_on_val", False),
 
 			))
 
@@ -225,7 +228,7 @@ class _DatasetMixin(abc.ABC):
 		self.model_info = self.data_info.MODELS[opts.model_type]
 		self.part_info = self.data_info.PARTS[opts.parts]
 
-		if opts.only_head:
+		if getattr(opts, "only_head", False):
 			self.annot.feature_model = opts.model_type
 
 		self.dataset_cls.label_shift = opts.label_shift
@@ -237,7 +240,7 @@ class _DatasetMixin(abc.ABC):
 
 		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
 			swap_channels=opts.swap_channels,
-			keep_ratio=not opts.no_center_crop_on_val,
+			keep_ratio=not getattr(opts, "no_center_crop_on_val", False),
 		)
 
 		logging.info(" ".join([
@@ -281,6 +284,11 @@ class _TrainerMixin(abc.ABC):
 
 	def init_updater(self):
 		"""Creates an updater from training iterator and the optimizer."""
+
+		if self.opt is None:
+			self.updater = None
+			return
+
 		self.updater = self.updater_cls(
 			iterator=self.train_iter,
 			optimizer=self.opt,

+ 3 - 0
cvfinetune/parser/dataset_args.py

@@ -30,6 +30,9 @@ def add_dataset_args(parser):
 		Arg("--swap_channels", action="store_true",
 			help="preprocessing option: swap channels from RGB to BGR"),
 
+		Arg("--n_jobs", "-j", type=int, default=0,
+			help="number of loading processes. If 0, then images are loaded in the same process"),
+
 	])
 
 	parser.add_args(_args, group_name="Dataset arguments")

+ 0 - 3
cvfinetune/parser/training_args.py

@@ -11,9 +11,6 @@ def add_training_args(parser):
 
 	_args = ArgFactory([
 
-		Arg("--n_jobs", "-j", type=int, default=0,
-			help="number of loading processes. If 0, then images are loaded in the same process"),
-
 		Arg("--warm_up", type=int, help="warm up epochs"),
 
 		OptimizerType.as_arg("optimizer", "opt",