Browse Source

added color jitter transformation

Dimitri Korsch 4 years ago
parent
commit
a10f1c35bb

+ 6 - 2
cvdatasets/dataset/mixins/transform.py

@@ -1,7 +1,8 @@
 import abc
+import chainer
 
-from cvdatasets.dataset.mixins.base import BaseMixin
 from cvdatasets.dataset.image.size import Size
+from cvdatasets.dataset.mixins.base import BaseMixin
 
 class TransformMixin(BaseMixin):
 
@@ -20,7 +21,10 @@ class TransformMixin(BaseMixin):
 
 	@property
 	def size(self):
-		return self._size
+		# While chainer's global ``train`` flag is set, the stored size is
+		# floor-divided by 0.875; otherwise it is returned unchanged.
+		# NOTE(review): presumably the enlarged size leaves room for a random
+		# crop back to the nominal size during training -- confirm against
+		# the augmentation code, and confirm the type of ``_size`` supports
+		# ``//`` with a float as intended.
+		if chainer.config.train:
+			return self._size // 0.875
+		else:
+			return self._size
 
 	@size.setter
 	def size(self, value):

+ 3 - 2
cvdatasets/utils/__init__.py

@@ -65,6 +65,7 @@ class _MetaInfo(object):
 
 from cvdatasets.utils.dataset import new_iterator
 from cvdatasets.utils.image import asarray
-from cvdatasets.utils.image import dimensions
-from cvdatasets.utils.image import rescale
 from cvdatasets.utils.image import read_image
+from cvdatasets.utils.transforms import color_jitter
+from cvdatasets.utils.transforms import dimensions
+from cvdatasets.utils.transforms import rescale

+ 2 - 29
cvdatasets/utils/image.py

@@ -20,39 +20,12 @@ def read_image(im_path, n_retries=5):
 		raise RuntimeError("Reading image \"{}\" failed after {} n_retries! ({})".format(im_path, n_retries, error))
 
 
-def rescale(im, coords, rescale_size, center_cropped=True, no_offset=False):
-	h, w, c = dimensions(im)
-
-	offset = 0
-	if center_cropped:
-		_min_val = min(w, h)
-		wh = np.array([_min_val, _min_val])
-		if not no_offset:
-			offset = (np.array([w, h]) - wh) / 2
-	else:
-		wh = np.array([w, h])
-
-	scale = wh / rescale_size
-	return coords * scale + offset
-
-def dimensions(im):
-	if isinstance(im, np.ndarray):
-		if im.ndim != 3:
-			import pdb; pdb.set_trace()
-		assert im.ndim == 3, "Only RGB images are currently supported!"
-		return im.shape
-	elif isinstance(im, Image.Image):
-		w, h = im.size
-		c = len(im.getbands())
-		# assert c == 3, "Only RGB images are currently supported!"
-		return h, w, c
-	else:
-		raise ValueError("Unknown image instance ({})!".format(type(im)))
-
 def asarray(im, dtype=np.uint8):
+	"""Return ``im`` as a numpy array with the requested ``dtype``.
+
+	Args:
+		im: a ``numpy.ndarray`` (converted with ``astype``) or a PIL image
+			(converted with ``np.asarray``).
+		dtype: target dtype of the result, ``np.uint8`` by default.
+
+	Raises:
+		ValueError: if ``im`` is neither a numpy array nor a PIL image.
+	"""
 	if isinstance(im, np.ndarray):
 		return im.astype(dtype)
+
 	elif isinstance(im, Image.Image):
 		return np.asarray(im, dtype=dtype)
+
 	else:
 		raise ValueError("Unknown image instance ({})!".format(type(im)))

+ 130 - 0
cvdatasets/utils/transforms.py

@@ -0,0 +1,130 @@
+import numpy as np
+import random
+
+from PIL import Image
+
+
+def dimensions(im):
+	"""Return the dimensions of an image as a 3-tuple.
+
+	Args:
+		im: a three-dimensional ``numpy.ndarray`` or a PIL image.
+
+	Returns:
+		For arrays, ``im.shape`` as it is. For PIL images, the tuple
+		``(height, width, number_of_bands)``.
+
+	Raises:
+		AssertionError: if an array is not three-dimensional.
+		ValueError: if ``im`` is neither a numpy array nor a PIL image.
+	"""
+	if isinstance(im, np.ndarray):
+		# Fail with the assertion right away; the previous debugger trap
+		# (pdb.set_trace) blocked or crashed non-interactive runs.
+		assert im.ndim == 3, "Only RGB images are currently supported!"
+		return im.shape
+
+	elif isinstance(im, Image.Image):
+		w, h = im.size
+		# the number of bands is reported, but not restricted to 3 here
+		c = len(im.getbands())
+		return h, w, c
+
+	else:
+		raise ValueError("Unknown image instance ({})!".format(type(im)))
+
+def rescale(im, coords, rescale_size, center_cropped=True, no_offset=False):
+	"""Map coordinates given relative to ``rescale_size`` onto ``im``.
+
+	Args:
+		im: image whose dimensions are read with ``dimensions``.
+		coords: coordinates to transform; multiplied by the scale and
+			shifted by the offset.
+		rescale_size: the size the incoming coordinates refer to.
+		center_cropped: if set, the scale is derived from the shorter
+			image side (used for both axes).
+		no_offset: if set, the centering offset of the crop is left out.
+			Only relevant when ``center_cropped`` is set.
+
+	Returns:
+		``coords * scale + offset``.
+	"""
+	height, width, _ = dimensions(im)
+	full_extent = np.array([width, height])
+
+	if center_cropped:
+		side = min(width, height)
+		extent = np.array([side, side])
+		# half of what the crop removes along each axis
+		shift = 0 if no_offset else (full_extent - extent) / 2
+
+	else:
+		extent = full_extent
+		shift = 0
+
+	return coords * (extent / rescale_size) + shift
+
+####################
+### Source: https://github.com/chainer/chainercv/blob/b52c71d9cd11dc9efdd5aaf327fed1a99df94d10/chainercv/transforms/image/color_jitter.py
+####################
+
+
+def _grayscale(img):
+	"""Return an array shaped and typed like ``img`` in which every entry
+	along the first axis holds the weighted sum of ``img[0]``, ``img[1]``
+	and ``img[2]`` (weights 0.299, 0.587 and 0.114)."""
+	luma = 0.299 * img[0] + 0.587 * img[1] + 0.114 * img[2]
+	# the assignment broadcasts the single plane over the first axis and
+	# casts it to the dtype of the input
+	gray = np.zeros_like(img)
+	gray[:] = luma
+	return gray
+
+
+def _blend(img_a, img_b, alpha):
+	"""Return ``alpha * img_a + (1 - alpha) * img_b``."""
+	weighted_a = alpha * img_a
+	weighted_b = (1 - alpha) * img_b
+	return weighted_a + weighted_b
+
+
+def _brightness(img, var):
+	"""Blend ``img`` with an all-zero image using a random factor.
+
+	The factor is ``1 + u`` with ``u`` drawn uniformly from
+	``[-var, var)``. Returns the blended image and the factor.
+	"""
+	factor = 1 + np.random.uniform(-var, var)
+	black = np.zeros_like(img)
+	blended = _blend(img, black, factor)
+	return blended, factor
+
+
+def _contrast(img, var):
+	"""Blend ``img`` with a constant image using a random factor.
+
+	The constant is the mean of the gray-scale version of ``img``. The
+	factor is ``1 + u`` with ``u`` drawn uniformly from ``[-var, var)``.
+	Returns the blended image and the factor.
+	"""
+	flat = _grayscale(img)
+	# all planes of the gray image are equal, so the first one suffices
+	mean_level = flat[0].mean()
+	flat.fill(mean_level)
+
+	factor = 1 + np.random.uniform(-var, var)
+	blended = _blend(img, flat, factor)
+	return blended, factor
+
+
+def _saturation(img, var):
+	"""Blend ``img`` with its gray-scale version using a random factor.
+
+	The factor is ``1 + u`` with ``u`` drawn uniformly from
+	``[-var, var)``. Returns the blended image and the factor.
+	"""
+	desaturated = _grayscale(img)
+
+	factor = 1 + np.random.uniform(-var, var)
+	blended = _blend(img, desaturated, factor)
+	return blended, factor
+
+
+def color_jitter(img, brightness=0.4, contrast=0.4,
+				 saturation=0.4, return_param=False, max_value=255):
+	"""Randomly change brightness, contrast and saturation of an image.
+
+	Every augmentation whose strength is greater than zero is applied
+	exactly once; the order of the augmentations is shuffled on each call.
+	Each one blends the image with a reference image using the factor
+	``alpha = 1 + u``, where ``u`` is drawn uniformly from
+	``[-strength, strength)``. Afterwards the values are limited to the
+	range ``[0, max_value]``.
+
+	Args:
+		img (~numpy.ndarray): The image to augment. The gray-scale
+			conversion reads the color planes from ``img[0]``, ``img[1]``
+			and ``img[2]``, i.e. a channel-first (CHW) RGB layout.
+		brightness (float): Strength of the brightness change. The
+			default value is 0.4.
+		contrast (float): Strength of the contrast change. The default
+			value is 0.4.
+		saturation (float): Strength of the saturation change. The
+			default value is 0.4.
+		return_param (bool): Also return the used parameters
+			if :obj:`True`.
+		max_value: Upper limit the result is clipped to. The default
+			value is 255.
+
+	Returns:
+		~numpy.ndarray or (~numpy.ndarray, dict):
+		The color-jittered image if :obj:`return_param = False`, otherwise
+		a tuple of the image and a dictionary with these entries:
+
+		* **order** (*list of strings*): Names of the applied
+			augmentations (:obj:`'brightness'`, :obj:`'contrast'`,
+			:obj:`'saturation'`) in the order of their application.
+			Disabled augmentations are not listed.
+		* **brightness_alpha** (*float*): Blending factor of the
+			brightness augmentation, 1 if it was not applied.
+		* **contrast_alpha** (*float*): Blending factor of the
+			contrast augmentation, 1 if it was not applied.
+		* **saturation_alpha** (*float*): Blending factor of the
+			saturation augmentation, 1 if it was not applied.
+	"""
+	candidates = [
+		('brightness', brightness, lambda x: _brightness(x, brightness)),
+		('contrast', contrast, lambda x: _contrast(x, contrast)),
+		('saturation', saturation, lambda x: _saturation(x, saturation)),
+	]
+	# only augmentations with a positive strength take part
+	steps = [(name, op) for name, strength, op in candidates if strength > 0]
+	random.shuffle(steps)
+
+	params = {'order': [name for name, _ in steps]}
+	# a factor of 1 stands for "image left unchanged"
+	params.update(brightness_alpha=1, contrast_alpha=1, saturation_alpha=1)
+
+	for name, op in steps:
+		img, alpha = op(img)
+		params[name + '_alpha'] = alpha
+
+	# limit the values to [0, max_value]
+	floored = np.maximum(img, 0)
+	img = np.minimum(floored, max_value)
+
+	return (img, params) if return_param else img