Browse Source

enhanced the image transformations used by the augmentations (especially by color jittering)

Dimitri Korsch 2 years ago
parent
commit
100ee3b9a0
1 changed files with 84 additions and 10 deletions
  1. 84 10
      cvdatasets/utils/transforms.py

+ 84 - 10
cvdatasets/utils/transforms.py

@@ -1,6 +1,15 @@
 import numpy as np
 import random
 import typing as T
+import warnings
+
+# OpenCV is an optional dependency: `cv2_available` records whether the import
+# succeeded, so that functions taking a `use_cv` flag can check it before
+# calling into cv2 and otherwise use their NumPy implementation.
+try:
+	import cv2
+except ImportError:
+	warnings.warn("OpenCV was not installed! Some of the image operations won't use the 'use_cv' flag")
+	cv2_available = False
+else:
+	cv2_available = True
 
 from PIL import Image
 from functools import partial
@@ -62,25 +71,47 @@ def rescale(im, coords, rescale_size, center_cropped=True, no_offset=False):
 ####################
 ### Source: https://github.com/chainer/chainercv/blob/b52c71d9cd11dc9efdd5aaf327fed1a99df94d10/chainercv/transforms/image/color_jitter.py
 ####################
+# Grayscale conversion weights, given in (R, G, B) order.
+# from https://scikit-image.org/docs/dev/api/skimage.color.html#skimage.color.rgb2gray:
+#	Y = 0.2125 R + 0.7154 G + 0.0721 B
+SKIMAGE_GRAY_WEIGHTS = (0.2125, 0.7154, 0.0721)
 
+# from https://docs.opencv.org/4.7.0/de/d25/imgproc_color_conversions.html:
+#	Y = 0.299 R + 0.587 G + 0.114 B
+OPENCV_GRAY_WEIGHTS = (0.299, 0.587, 0.114)
+
+# Weights used for the manual weighted sum. The OpenCV values are selected
+# explicitly (instead of re-assigning the same name twice), presumably so that
+# the weighted sum uses the same formula as OpenCV's color conversion.
+GRAY_WEIGHTS = OPENCV_GRAY_WEIGHTS
 
-def _grayscale(img, channel_order="RGB"):
+def _grayscale(img, *, channel_order="RGB", use_cv=True, axis_order="CHW"):
 	"""
-		from https://scikit-image.org/docs/dev/api/skimage.color.html#skimage.color.rgb2gray:
-			Y = 0.2125 R + 0.7154 G + 0.0721 B
+		Converts a 3-channel image to a single-channel grayscale image
+		as a weighted sum of the channels (weights: GRAY_WEIGHTS).
+
+		channel_order can be either 'RGB' or 'BGR'
+		axis_order can be either 'CHW' or 'HWC'
+		use_cv selects the OpenCV implementation; it only has an effect
+		if OpenCV could be imported, otherwise NumPy is used.
+
+		Raises a ValueError for an unknown channel_order or axis_order.
+
+		NOTE: the OpenCV and the NumPy results are not guaranteed to be
+		identical (presumably rounding / output dtype differ) - TODO confirm
+		if exact equality is required by a caller.
 	"""
 
-	if channel_order == "RGB":
-		return 0.2125 * img[0] + 0.7154 * img[1] + 0.0721 * img[2]
+	# explicit exceptions instead of asserts: asserts are stripped with "python -O"
+	if channel_order not in ("RGB", "BGR"):
+		raise ValueError(f"Unknown channel order: {channel_order}")
+	if axis_order not in ("CHW", "HWC"):
+		raise ValueError(f"Unknown axis order: {axis_order}")
 
-	elif channel_order == "BGR":
-		return 0.0721 * img[0] + 0.7154 * img[1] + 0.2125 * img[2]
+	# GRAY_WEIGHTS is in RGB order. Slicing keeps a tuple that can be indexed;
+	# reversed() would return an iterator and w[0] would raise a TypeError.
+	w = GRAY_WEIGHTS if channel_order == "RGB" else GRAY_WEIGHTS[::-1]
 
-	else:
-		raise ValueError(f"Unknown channel order: {channel_order}")
+	if use_cv and cv2_available and axis_order == "HWC":
+		# channel-last images can be converted by OpenCV directly
+		mode = cv2.COLOR_RGB2GRAY if channel_order == "RGB" else cv2.COLOR_BGR2GRAY
+		return cv2.cvtColor(img, mode)
+
+	# c0, c1, c2 are the channels in storage order; w is ordered accordingly
+	if axis_order == "HWC":
+		c0, c1, c2 = img[..., 0], img[..., 1], img[..., 2]
+
+	else:
+		c0, c1, c2 = img[0], img[1], img[2]
 
+	if use_cv and cv2_available:
+		# addWeighted combines two arrays only, hence the nested call
+		return cv2.addWeighted(cv2.addWeighted(c0, w[0], c1, w[1], 0), 1, c2, w[2], 0)
 
-def _blend(img_a, img_b, alpha):
+	return w[0] * c0 + w[1] * c1 + w[2] * c2
+
+def _blend(img_a, img_b, alpha, *, use_cv=True):
+	"""
+		Blends two images: alpha * img_a + (1 - alpha) * img_b
+
+		use_cv selects cv2.addWeighted, but only if OpenCV could be
+		imported and both inputs are arrays of identical shape and dtype.
+		In every other case (broadcasting, scalars, mixed dtypes) the
+		NumPy expression is evaluated.
+	"""
+	# cv2.addWeighted needs two arrays of the same size and type
+	cv_compatible = (
+		isinstance(img_a, np.ndarray) and isinstance(img_b, np.ndarray)
+		and img_a.shape == img_b.shape and img_a.dtype == img_b.dtype
+	)
+	if use_cv and cv2_available and cv_compatible:
+		return cv2.addWeighted(img_a, alpha, img_b, 1 - alpha, 0)
+
 	return alpha * img_a + (1 - alpha) * img_b
 
 
@@ -169,3 +200,46 @@ def color_jitter(img, brightness=0.4, contrast=0.4,
 		return img, params
 	else:
 		return img
+
+
+
+if __name__ == '__main__':
+	# Sanity check and micro-benchmark of _grayscale and _blend:
+	# compares the OpenCV against the NumPy results and times both variants.
+	import time
+	from tqdm.auto import tqdm
+	from functools import partial
+
+	shape, order = (3, 600, 600), "CHW"
+	# shape, order = (600, 600, 3), "HWC"
+
+	# single-channel image that broadcasts against the full image in both layouts
+	small_shape = shape[1:] if order == "CHW" else shape[:2] + (1,)
+	iterations = 6_000
+
+	# the upper bound of randint is exclusive: 256 covers the whole uint8 range
+	im = np.random.randint(0, 256, size=shape, dtype=np.uint8)
+	im2 = np.random.randint(0, 256, size=shape, dtype=np.uint8)
+	im3 = np.random.randint(0, 256, size=small_shape, dtype=np.uint8)
+
+	# both implementations may differ by rounding only
+	res0 = _blend(im, im2, alpha=0.5, use_cv=True).astype(np.int32)
+	res1 = _blend(im, im2, alpha=0.5, use_cv=False).astype(np.int32)
+	diff = np.abs(res0-res1)
+	assert np.all(diff <= 2)
+
+
+	# axis_order has to match the layout selected above
+	res0 = _grayscale(im, use_cv=True, axis_order=order).astype(np.int32)
+	res1 = _grayscale(im, use_cv=False, axis_order=order).astype(np.int32)
+	diff = np.abs(res0-res1)
+	assert np.all(diff <= 2)
+
+	funcs = [
+		(partial(_grayscale, axis_order=order), f"Grayscale on {order}"),
+		(partial(_blend, img_b=im2, alpha=0.5), "Blending with same size"),
+		(partial(_blend, img_b=im3, alpha=0.5), "Blending with smaller size"),
+	]
+
+	for func, func_name in funcs:
+		for use_cv in [True, False]:
+			desc = f"{func_name} with{'' if use_cv else 'out'} OpenCV"
+			# perf_counter is a monotonic, high-resolution clock
+			t0 = time.perf_counter()
+			for n in tqdm(range(iterations), desc=desc):
+				func(im, use_cv=use_cv)
+			t = time.perf_counter() - t0
+
+			print(f"{iterations} iterations took {t:.3f}sec ({t/iterations * 1000:.3f} ms/iter)")