utils.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import numpy as np
  2. from PIL.Image import Image as PIL_Image
  3. DEFAULT_RATIO = np.sqrt(49 / 400)
  4. def __expand_parts(p):
  5. return p[:, 0], p[:, 1:3], p[:, 3].astype(bool)
  6. def rescale_parts(im, parts, part_rescale_size):
  7. if part_rescale_size is None or part_rescale_size < 0:
  8. return parts
  9. h, w, c = dimensions(im)
  10. scale = np.array([w, h]) / part_rescale_size
  11. xy = parts[:, 1:3]
  12. xy = xy * scale
  13. parts[:, 1:3] = xy
  14. if parts.shape[1] == 5:
  15. wh = parts[:, 3:5]
  16. wh = wh * scale
  17. parts[:, 3:5] = wh
  18. return parts
  19. def dimensions(im):
  20. if isinstance(im, np.ndarray):
  21. if im.ndim != 3:
  22. import pdb; pdb.set_trace()
  23. assert im.ndim == 3, "Only RGB images are currently supported!"
  24. return im.shape
  25. elif isinstance(im, PIL_Image):
  26. w, h = im.size
  27. c = len(im.getbands())
  28. # assert c == 3, "Only RGB images are currently supported!"
  29. return h, w, c
  30. else:
  31. raise ValueError("Unknown image instance ({})!".format(type(im)))
  32. def asarray(im, dtype=np.uint8):
  33. if isinstance(im, np.ndarray):
  34. return im.astype(dtype)
  35. elif isinstance(im, PIL_Image):
  36. return np.asarray(im, dtype=dtype)
  37. else:
  38. raise ValueError("Unknown image instance ({})!".format(type(im)))
  39. def uniform_parts(im, ratio=DEFAULT_RATIO, round_op=np.floor):
  40. h, w, c = dimensions(im)
  41. part_w = round_op(w * ratio).astype(np.int32)
  42. part_h = round_op(h * ratio).astype(np.int32)
  43. n, m = w // part_w, h // part_h
  44. parts = np.ones((n*m, 4), dtype=int)
  45. parts[:, 0] = np.arange(n*m)
  46. for x in range(n):
  47. for y in range(m):
  48. i = y * n + x
  49. x0, y0 = x * part_w, y * part_h
  50. parts[i, 1:3] = [x0 + part_w // 2, y0 + part_h // 2]
  51. return parts
  52. def visible_part_locs(p):
  53. idxs, locs, vis = __expand_parts(p)
  54. return idxs[vis], locs[vis].T
  55. def crops(im, xy, ratio=DEFAULT_RATIO, padding_mode="edge"):
  56. h, w, c = dimensions(im)
  57. crop_h, crop_w = int(h * ratio), int(w * ratio)
  58. crops = np.zeros((xy.shape[1], crop_h, crop_w, c), dtype=np.uint8)
  59. pad_h, pad_w = crop_h // 2, crop_w // 2
  60. padded_im = np.pad(im, [(pad_h, pad_h), (pad_w, pad_w), [0,0]], mode=padding_mode)
  61. for i, (x, y) in enumerate(xy.T):
  62. x0, y0 = x - crop_w // 2 + pad_w, y - crop_h // 2 + pad_h
  63. crops[i] = padded_im[y0:y0+crop_h, x0:x0+crop_w]
  64. return crops
  65. def visible_crops(im, p, *args, **kw):
  66. idxs, locs, vis = __expand_parts(p)
  67. parts = crops(asarray(im), locs[vis].T, *args, **kw)
  68. res = np.zeros((len(idxs),) + parts.shape[1:], dtype=parts.dtype)
  69. res[vis] = parts
  70. return res
  71. def reveal_parts(im, xy, ratio=DEFAULT_RATIO):
  72. h, w, c = dimensions(im)
  73. crop_h, crop_w = int(h * ratio), int(w * ratio)
  74. im = asarray(im)
  75. res = np.zeros_like(im)
  76. for x, y in xy.T:
  77. x0, y0 = max(x - crop_w // 2, 0), max(y - crop_h // 2, 0)
  78. res[y0:y0+crop_h, x0:x0+crop_w] = im[y0:y0+crop_h, x0:x0+crop_w]
  79. return res
  80. def select(crops, mask):
  81. selected = np.zeros_like(crops)
  82. selected[mask] = crops[mask]
  83. return selected
  84. def selection_mask(idxs, n):
  85. return np.bincount(idxs, minlength=n).astype(bool)
  86. def random_select(idxs, xy, part_crops, *args, **kw):
  87. rnd_idxs = random_idxs(np.arange(len(idxs)), *args, **kw)
  88. idxs = idxs[rnd_idxs]
  89. xy = xy[:, rnd_idxs]
  90. mask = selection_mask(idxs, len(part_crops))
  91. selected_crops = select(part_crops, mask)
  92. return idxs, xy, selected_crops
  93. def random_idxs(idxs, rnd=None, n_parts=None):
  94. if rnd is None or isinstance(rnd, int):
  95. rnd = np.random.RandomState(rnd)
  96. else:
  97. assert isinstance(rnd, np.random.RandomState), \
  98. "'rnd' should be either a random seed or a RandomState instance!"
  99. n_parts = n_parts or rnd.randint(1, len(idxs))
  100. res = rnd.choice(idxs, n_parts, replace=False)
  101. res.sort()
  102. return res