|
@@ -1,6 +1,15 @@
|
|
|
import numpy as np
|
|
|
import random
|
|
|
import typing as T
|
|
|
+import warnings
|
|
|
+
|
|
|
# OpenCV is an optional dependency: when it cannot be imported, a warning is
# emitted once at import time and `cv2_available` is set to False so that the
# image helpers can skip their OpenCV code paths.
try:
    import cv2
except ImportError:
    warnings.warn("OpenCV was not installed! Some of the image operations won't use the 'use_cv' flag")
    cv2_available = False
else:
    cv2_available = True
|
|
|
|
|
|
from PIL import Image
|
|
|
from functools import partial
|
|
@@ -62,25 +71,47 @@ def rescale(im, coords, rescale_size, center_cropped=True, no_offset=False):
|
|
|
####################
|
|
|
### Source: https://github.com/chainer/chainercv/blob/b52c71d9cd11dc9efdd5aaf327fed1a99df94d10/chainercv/transforms/image/color_jitter.py
|
|
|
####################
|
|
|
+# from https://scikit-image.org/docs/dev/api/skimage.color.html#skimage.color.rgb2gray:
|
|
|
+# Y = 0.2125 R + 0.7154 G + 0.0721 B
|
|
|
+GRAY_WEIGHTS = (0.2125, 0.7154, 0.0721)
|
|
|
|
|
|
+# from https://docs.opencv.org/4.7.0/de/d25/imgproc_color_conversions.html:
|
|
|
+# Y = 0.299 R + 0.587 G + 0.114 B
|
|
|
+GRAY_WEIGHTS = (0.299, 0.587, 0.114)
|
|
|
|
|
|
-def _grayscale(img, channel_order="RGB"):
|
|
|
+def _grayscale(img, *, channel_order="RGB", use_cv=True, axis_order="CHW"):
|
|
|
"""
|
|
|
- from https://scikit-image.org/docs/dev/api/skimage.color.html#skimage.color.rgb2gray:
|
|
|
- Y = 0.2125 R + 0.7154 G + 0.0721 B
|
|
|
+ channel_order can be either 'RGB' or 'BGR'
|
|
|
+ axis_order can be either 'CHW' or 'HWC'
|
|
|
"""
|
|
|
+ global GRAY_WEIGHTS
|
|
|
|
|
|
- if channel_order == "RGB":
|
|
|
- return 0.2125 * img[0] + 0.7154 * img[1] + 0.0721 * img[2]
|
|
|
+ assert channel_order in ["RGB", "BGR"], f"Unknown channel order: {channel_order}"
|
|
|
+ assert axis_order in ["CHW", "HWC"], f"Unknown axis order: {axis_order}"
|
|
|
|
|
|
- elif channel_order == "BGR":
|
|
|
- return 0.0721 * img[0] + 0.7154 * img[1] + 0.2125 * img[2]
|
|
|
+ w = GRAY_WEIGHTS if channel_order == "RGB" else reversed(GRAY_WEIGHTS)
|
|
|
|
|
|
- else:
|
|
|
- raise ValueError(f"Unknown channel order: {channel_order}")
|
|
|
+ if use_cv and cv2_available and axis_order == "HWC":
|
|
|
+ mode = cv2.COLOR_RGB2GRAY if channel_order == "RGB" else cv2.COLOR_BGR2GRAY
|
|
|
+ return cv2.cvtColor(img, mode)
|
|
|
+
|
|
|
+ if axis_order == "HWC":
|
|
|
+ R, G, B = img[..., 0], img[..., 1], img[..., 2]
|
|
|
+
|
|
|
+ elif axis_order == "CHW":
|
|
|
+ R, G, B = img[0], img[1], img[2]
|
|
|
|
|
|
+ if use_cv and cv2_available:
|
|
|
+ return cv2.addWeighted(cv2.addWeighted(R, w[0], G, w[1], 0), 1, B, w[2], 0)
|
|
|
|
|
|
-def _blend(img_a, img_b, alpha):
|
|
|
+ return w[0] * R + w[1] * G + w[2] * B
|
|
|
+
|
|
|
+def _blend(img_a, img_b, alpha, *, use_cv=True):
|
|
|
+ if use_cv and cv2_available:
|
|
|
+ if img_a.shape == img_b.shape:
|
|
|
+ return cv2.addWeighted(img_a,alpha, img_b,1-alpha,0)
|
|
|
+ # else:
|
|
|
+ # import pdb; pdb.set_trace()
|
|
|
return alpha * img_a + (1 - alpha) * img_b
|
|
|
|
|
|
|
|
@@ -169,3 +200,46 @@ def color_jitter(img, brightness=0.4, contrast=0.4,
|
|
|
return img, params
|
|
|
else:
|
|
|
return img
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
if __name__ == '__main__':
    # Sanity check and micro-benchmark: compares the OpenCV code paths of
    # `_blend` and `_grayscale` against the plain array arithmetic, first for
    # (approximately) equal results, then for run time.
    import time
    from tqdm.auto import tqdm
    from functools import partial

    shape, order = (3, 600, 600), "CHW"
    # shape, order = (600, 600, 3), "HWC"

    # Single-channel image that still broadcasts against the full image:
    # (H, W) for channel-first input, (H, W, 1) for channel-last input.
    small_shape = shape[1:] if order == "CHW" else shape[:2] + (1,)
    iterations = 6_000

    im = np.random.randint(0, 255, size=shape, dtype=np.uint8)
    im2 = np.random.randint(0, 255, size=shape, dtype=np.uint8)
    im3 = np.random.randint(0, 255, size=small_shape, dtype=np.uint8)

    # Both implementations may differ by a few intensity levels
    # (the original check tolerates a difference of up to 2).
    res0 = _blend(im, im2, alpha=0.5, use_cv=True).astype(np.int32)
    res1 = _blend(im, im2, alpha=0.5, use_cv=False).astype(np.int32)
    diff = np.abs(res0 - res1)
    assert np.all(diff <= 2)

    # Pass the configured layout explicitly so that the check stays valid
    # when the HWC configuration above is selected.
    res0 = _grayscale(im, use_cv=True, axis_order=order).astype(np.int32)
    res1 = _grayscale(im, use_cv=False, axis_order=order).astype(np.int32)
    diff = np.abs(res0 - res1)
    assert np.all(diff <= 2)

    funcs = [
        (partial(_grayscale, axis_order=order), f"Grayscale on {order}"),
        (partial(_blend, img_b=im2, alpha=0.5), "Blending with same size"),
        (partial(_blend, img_b=im3, alpha=0.5), "Blending with smaller size"),
    ]

    for func, func_name in funcs:
        for use_cv in [True, False]:
            desc = f"{func_name} with{'' if use_cv else 'out'} OpenCV"
            # perf_counter is a monotonic, high-resolution clock meant for
            # measuring intervals.
            t0 = time.perf_counter()
            for _ in tqdm(range(iterations), desc=desc):
                func(im, use_cv=use_cv)
            t = time.perf_counter() - t0

            print(f"{iterations} iterations took {t:.3f}sec ({t/iterations * 1000:.3f} ms/iter)")
|