浏览代码

some additions and performance improvements

Dimitri Korsch 4 年之前
父节点
当前提交
c8f852d49b
共有 1 个文件被更改,包括 33 次插入和 15 次删除
  1. 33 15
      cvdatasets/utils/transforms.py

+ 33 - 15
cvdatasets/utils/transforms.py

@@ -2,7 +2,7 @@ import numpy as np
 import random
 
 from PIL import Image
-
+from functools import partial
 
 def dimensions(im):
 	if isinstance(im, np.ndarray):
@@ -41,10 +41,20 @@ def rescale(im, coords, rescale_size, center_cropped=True, no_offset=False):
 ####################
 
 
-def _grayscale(img):
-	out = np.zeros_like(img)
-	out[:] = 0.299 * img[0] + 0.587 * img[1] + 0.114 * img[2]
-	return out
+def _grayscale(img, channel_order="RGB"):
+	"""
+		from https://scikit-image.org/docs/dev/api/skimage.color.html#skimage.color.rgb2gray:
+			Y = 0.2125 R + 0.7154 G + 0.0721 B
+	"""
+
+	if channel_order == "RGB":
+		return 0.2125 * img[0] + 0.7154 * img[1] + 0.0721 * img[2]
+
+	elif channel_order == "BGR":
+		return 0.0721 * img[0] + 0.7154 * img[1] + 0.2125 * img[2]
+
+	else:
+		raise ValueError(f"Unknown channel order: {channel_order}")
 
 
 def _blend(img_a, img_b, alpha):
@@ -56,23 +66,25 @@ def _brightness(img, var):
 	return _blend(img, np.zeros_like(img), alpha), alpha
 
 
-def _contrast(img, var):
-	gray = _grayscale(img)
-	gray.fill(gray[0].mean())
+def _contrast(img, var, **kwargs):
+	gray = _grayscale(img, **kwargs).mean()  # _grayscale returns one 2-D luminance map; average all of it (a [0] index would use only the first row)
 
 	alpha = 1 + np.random.uniform(-var, var)
 	return _blend(img, gray, alpha), alpha
 
 
-def _saturation(img, var):
-	gray = _grayscale(img)
+def _saturation(img, var, **kwargs):
+	gray = _grayscale(img, **kwargs)
 
 	alpha = 1 + np.random.uniform(-var, var)
 	return _blend(img, gray, alpha), alpha
 
 
 def color_jitter(img, brightness=0.4, contrast=0.4,
-				 saturation=0.4, return_param=False, max_value=255):
+				 saturation=0.4, return_param=False,
+				 min_value=0,
+				 max_value=255,
+				 channel_order="RGB"):
 	"""Data augmentation on brightness, contrast and saturation.
 	Args:
 		img (~numpy.ndarray): An image array to be augmented. This is in
@@ -109,11 +121,11 @@ def color_jitter(img, brightness=0.4, contrast=0.4,
 	"""
 	funcs = list()
 	if brightness > 0:
-		funcs.append(('brightness', lambda x: _brightness(x, brightness)))
+		funcs.append(('brightness', partial(_brightness, var=brightness)))
 	if contrast > 0:
-		funcs.append(('contrast', lambda x: _contrast(x, contrast)))
+		funcs.append(('contrast', partial(_contrast, var=contrast, channel_order=channel_order)))
 	if saturation > 0:
-		funcs.append(('saturation', lambda x: _saturation(x, saturation)))
+		funcs.append(('saturation', partial(_saturation, var=saturation, channel_order=channel_order)))
 	random.shuffle(funcs)
 
 	params = {'order': [key for key, val in funcs],
@@ -123,7 +135,13 @@ def color_jitter(img, brightness=0.4, contrast=0.4,
 	for key, func in funcs:
 		img, alpha = func(img)
 		params[key + '_alpha'] = alpha
-	img = np.minimum(np.maximum(img, 0), max_value)
+
+	if min_value is not None:
+		img = np.maximum(img, min_value)
+
+	if max_value is not None:
+		img = np.minimum(img, max_value)
+
 	if return_param:
 		return img, params
 	else: