|
@@ -2,7 +2,7 @@ import numpy as np
|
|
|
import random
|
|
|
|
|
|
from PIL import Image
|
|
|
-
|
|
|
+from functools import partial
|
|
|
|
|
|
def dimensions(im):
|
|
|
if isinstance(im, np.ndarray):
|
|
@@ -41,10 +41,20 @@ def rescale(im, coords, rescale_size, center_cropped=True, no_offset=False):
|
|
|
####################
|
|
|
|
|
|
|
|
|
def _grayscale(img, channel_order="RGB"):
    """Collapse the three colour planes of ``img`` into one luminance plane.

    The weights are the ones documented for ``skimage.color.rgb2gray``
    (https://scikit-image.org/docs/dev/api/skimage.color.html#skimage.color.rgb2gray):

        Y = 0.2125 R + 0.7154 G + 0.0721 B

    Args:
        img: Array indexed channel-first; ``img[0]``, ``img[1]`` and
            ``img[2]`` are read as the colour planes, in the order given by
            ``channel_order``.
        channel_order: Either ``"RGB"`` or ``"BGR"``.

    Returns:
        The weighted sum of the three planes. The result has no channel
        axis (it has the shape of ``img[0]``).

    Raises:
        ValueError: If ``channel_order`` is neither ``"RGB"`` nor ``"BGR"``.
    """
    if channel_order not in ("RGB", "BGR"):
        raise ValueError(f"Unknown channel order: {channel_order}")

    # Green always sits in the middle; only the outer two weights swap
    # between the RGB and BGR layouts.
    if channel_order == "RGB":
        first_weight, last_weight = 0.2125, 0.0721
    else:
        first_weight, last_weight = 0.0721, 0.2125

    return first_weight * img[0] + 0.7154 * img[1] + last_weight * img[2]
|
|
|
|
|
|
|
|
|
def _blend(img_a, img_b, alpha):
|
|
@@ -56,23 +66,25 @@ def _brightness(img, var):
|
|
|
return _blend(img, np.zeros_like(img), alpha), alpha
|
|
|
|
|
|
|
|
|
def _contrast(img, var, **kwargs):
    """Randomly change the contrast of ``img`` around its mean luminance.

    Args:
        img: Image array, indexed channel-first as ``_grayscale`` expects.
        var: Half-width of the jitter range; the blend factor is drawn
            uniformly from ``[1 - var, 1 + var)``.
        **kwargs: Forwarded unchanged to ``_grayscale`` (for example
            ``channel_order``).

    Returns:
        A tuple ``(blended, alpha)``: the result of
        ``_blend(img, mean_gray, alpha)`` and the factor that was drawn.
    """
    # The pivot must be the mean over the WHOLE luminance result.
    # ``_grayscale`` may return a single plane without a channel axis, in
    # which case indexing ``[0]`` first would average only the top row.
    # Taking the mean of the full array is also correct if the luminance
    # is replicated across channels, since every channel is identical.
    mean_gray = _grayscale(img, **kwargs).mean()

    alpha = 1 + np.random.uniform(-var, var)
    # ``mean_gray`` is a scalar; presumably ``_blend`` broadcasts it against
    # ``img`` -- confirm against ``_blend``.
    return _blend(img, mean_gray, alpha), alpha
|
|
|
|
|
|
|
|
|
def _saturation(img, var, **kwargs):
    """Randomly change the saturation of ``img`` by blending with its luminance.

    Args:
        img: Image array, indexed channel-first as ``_grayscale`` expects.
        var: Half-width of the jitter range; the blend factor is drawn
            uniformly from ``[1 - var, 1 + var)``.
        **kwargs: Forwarded unchanged to ``_grayscale`` (for example
            ``channel_order``).

    Returns:
        A tuple ``(blended, alpha)``: the result of
        ``_blend(img, luminance, alpha)`` and the factor that was drawn.
    """
    luminance = _grayscale(img, **kwargs)
    jitter = np.random.uniform(-var, var)
    alpha = 1 + jitter
    blended = _blend(img, luminance, alpha)
    return blended, alpha
|
|
|
|
|
|
|
|
|
def color_jitter(img, brightness=0.4, contrast=0.4,
|
|
|
- saturation=0.4, return_param=False, max_value=255):
|
|
|
+ saturation=0.4, return_param=False,
|
|
|
+ min_value=0,
|
|
|
+ max_value=255,
|
|
|
+ channel_order="RGB"):
|
|
|
"""Data augmentation on brightness, contrast and saturation.
|
|
|
Args:
|
|
|
img (~numpy.ndarray): An image array to be augmented. This is in
|
|
@@ -109,11 +121,11 @@ def color_jitter(img, brightness=0.4, contrast=0.4,
|
|
|
"""
|
|
|
funcs = list()
|
|
|
if brightness > 0:
|
|
|
- funcs.append(('brightness', lambda x: _brightness(x, brightness)))
|
|
|
+ funcs.append(('brightness', partial(_brightness, var=brightness)))
|
|
|
if contrast > 0:
|
|
|
- funcs.append(('contrast', lambda x: _contrast(x, contrast)))
|
|
|
+ funcs.append(('contrast', partial(_contrast, var=contrast, channel_order=channel_order)))
|
|
|
if saturation > 0:
|
|
|
- funcs.append(('saturation', lambda x: _saturation(x, saturation)))
|
|
|
+ funcs.append(('saturation', partial(_saturation, var=saturation, channel_order=channel_order)))
|
|
|
random.shuffle(funcs)
|
|
|
|
|
|
params = {'order': [key for key, val in funcs],
|
|
@@ -123,7 +135,13 @@ def color_jitter(img, brightness=0.4, contrast=0.4,
|
|
|
for key, func in funcs:
|
|
|
img, alpha = func(img)
|
|
|
params[key + '_alpha'] = alpha
|
|
|
- img = np.minimum(np.maximum(img, 0), max_value)
|
|
|
+
|
|
|
+ if min_value is not None:
|
|
|
+ img = np.maximum(img, min_value)
|
|
|
+
|
|
|
+ if max_value is not None:
|
|
|
+ img = np.minimum(img, max_value)
|
|
|
+
|
|
|
if return_param:
|
|
|
return img, params
|
|
|
else:
|