Browse Source

added some augmentation options

Dimitri Korsch 5 years ago
parent
commit
73267a952b
2 changed files with 34 additions and 30 deletions
  1. 12 24
      cvfinetune/finetuner/base.py
  2. 22 6
      cvfinetune/parser/training_args.py

+ 12 - 24
cvfinetune/finetuner/base.py

@@ -21,6 +21,7 @@ from chainer_addons.training import optimizer_hooks
 from cvdatasets import AnnotationType
 from cvdatasets.utils import new_iterator
 from cvdatasets.utils import pretty_print_dict
+from cvdatasets.dataset.image import Size
 
 from functools import partial
 from os.path import join
@@ -112,14 +113,8 @@ class _ModelMixin(abc.ABC):
 
 		self.model = ModelType.new(
 			model_type=self.model_info.class_key,
-			input_size=opts.input_size,
+			input_size=Size(opts.input_size),
 			**self.model_kwargs,
-			# pooling=opts.pooling,
-			# pooling_params=dict(
-			# 	init_alpha=opts.init_alpha,
-			# 	output_dim=8192,
-			# 	normalize=opts.normalize),
-			# aux_logits=False
 		)
 
 	def load_model_weights(self, args):
@@ -183,10 +178,10 @@ class _DatasetMixin(abc.ABC):
 	def n_classes(self):
 		return self.ds_info.n_classes + self.dataset_cls.label_shift
 
-	def new_dataset(self, opts, size, subset, augment):
+	def new_dataset(self, opts, size, subset):
 		"""Creates a dataset for a specific subset and certain options"""
 		if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
-			kwargs = self.dataset_kwargs_factory(opts, subset, augment)
+			kwargs = self.dataset_kwargs_factory(opts, subset)
 		else:
 			kwargs = dict()
 
@@ -195,25 +190,18 @@ class _DatasetMixin(abc.ABC):
 			dataset_cls=self.dataset_cls,
 		))
 
-		# if opts.use_parts:
-		# 	kwargs.update(dict(
-		# 		no_glob=opts.no_global,
-		# 	))
 
 		if not getattr(opts, "only_head", False):
 			kwargs.update(dict(
-				preprocess=self.prepare,
-				augment=augment,
+				prepare=self.prepare,
 				size=size,
-				center_crop_on_val=not getattr(opts, "no_center_crop_on_val", False),
+				center_crop_on_val=getattr(opts, "center_crop_on_val", False),
 
 			))
 
-		d = self.annot.new_dataset(**kwargs)
-		logging.info("Loaded {} images".format(len(d)))
-		logging.info("Data augmentation is {}abled".format("en" if augment else "dis"))
-		# logging.info("Global feature is {}used".format("not " if opts.no_global else ""))
-		return d
+		ds = self.annot.new_dataset(**kwargs)
+		logging.info("Loaded {} images".format(len(ds)))
+		return ds
 
 	def init_annotations(self, opts):
 		"""Reads annotations and creates annotation instance, which holds important infos about the dataset"""
@@ -237,7 +225,7 @@ class _DatasetMixin(abc.ABC):
 
 		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
 			swap_channels=opts.swap_channels,
-			keep_ratio=not getattr(opts, "no_center_crop_on_val", False),
+			keep_ratio=getattr(opts, "center_crop_on_val", False),
 		)
 
 		logging.info(" ".join([
@@ -246,8 +234,8 @@ class _DatasetMixin(abc.ABC):
 			f"Image input size: {size}",
 		]))
 
-		self.train_data = self.new_dataset(opts, size, "train", True)
-		self.val_data = self.new_dataset(opts, size, "test", False)
+		self.train_data = self.new_dataset(opts, size, "train")
+		self.val_data = self.new_dataset(opts, size, "test")
 
 	def init_iterators(self, opts):
 		"""Creates training and validation iterators from training and validation datasets"""

+ 22 - 6
cvfinetune/parser/training_args.py

@@ -29,11 +29,7 @@ def add_training_args(parser):
 		Arg("--label_smoothing", type=float, default=0,
 			help="Factor for label smoothing"),
 
-		Arg("--no_center_crop_on_val", action="store_true",
-			help="do not center crop images in the validation step!"),
-
 		Arg("--only_head", action="store_true", help="fine-tune only last layer"),
-		Arg("--augment", action="store_true", help="do data augmentation (random croping and random hor. flipping)"),
 
 	])\
 	.seed()\
@@ -46,8 +42,28 @@ def add_training_args(parser):
 	parser.add_args(_args, group_name="Training arguments")
 
 	_args = [
-			Arg("--only_eval", action="store_true", help="evaluate the model only. do not train!"),
-			Arg("--init_eval", action="store_true", help="evaluate the model before training"),
+		Arg("--augmentations",
+			choices=[
+				"random_crop",
+				"random_flip",
+				"random_rotation",
+				"center_crop",
+				"color_jitter"
+			],
+			default=["random_crop", "random_flip", "color_jitter"],
+			nargs="*"),
+
+		Arg("--center_crop_on_val", action="store_true"),
+		Arg("--brightness_jitter", type=int, default=0.3),
+		Arg("--contrast_jitter", type=int, default=0.3),
+		Arg("--saturation_jitter", type=int, default=0.3),
+
+	]
+	parser.add_args(_args, group_name="Augmentation arguments")
+
+	_args = [
+		Arg("--only_eval", action="store_true", help="evaluate the model only. do not train!"),
+		Arg("--init_eval", action="store_true", help="evaluate the model before training"),
 	]
 
 	parser.add_args(_args, group_name="Evaluation arguments")