dataset.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from imageio import imread
  2. import numpy as np
  3. class Dataset(object):
  4. def __init__(self, uuids, annotations):
  5. super(Dataset, self).__init__()
  6. self.uuids = uuids
  7. self._annot = annotations
  8. def __len__(self):
  9. return len(self.uuids)
  10. def _get(self, method, i):
  11. return getattr(self._annot, method)(self.uuids[i])
  12. def get_example(self, i, mode="RGB"):
  13. methods = ["image", "parts", "label"]
  14. im_path, parts, label = [self._get(m, i) for m in methods]
  15. return imread(im_path, pilmode=mode), parts, label
  16. __getitem__ = get_example
  17. # some convention functions
  18. DEFAULT_RATIO = np.sqrt(49 / 400)
  19. def __expand_parts(p):
  20. return p[:, 0], p[:, 1:3], p[:, 3].astype(bool)
  21. def visible_part_locs(p):
  22. idxs, locs, vis = __expand_parts(p)
  23. return idxs[vis], locs[vis].T
  24. def visible_crops(im, p, ratio=DEFAULT_RATIO, padding_mode="edge"):
  25. assert im.ndim == 3, "Only RGB images are currently supported!"
  26. idxs, locs, vis = __expand_parts(p)
  27. h, w, c = im.shape
  28. crop_h = crop_w = int(np.sqrt(h * w) * ratio)
  29. crops = np.zeros((len(idxs), crop_h, crop_w, c), dtype=im.dtype)
  30. padding = np.array([crop_h, crop_w]) // 2
  31. padded_im = np.pad(im, [padding, padding, [0,0]], mode=padding_mode)
  32. for i, loc, is_vis in zip(idxs, locs, vis):
  33. if not is_vis: continue
  34. x0, y0 = loc - crop_h // 2 + padding
  35. crops[i] = padded_im[y0:y0+crop_h, x0:x0+crop_w]
  36. return crops
  37. def reveal_parts(im, xy, ratio=DEFAULT_RATIO):
  38. h, w, c = im.shape
  39. crop_h = crop_w = int(np.sqrt(h * w) * ratio)
  40. x0y0 = xy - crop_h // 2
  41. res = np.zeros_like(im)
  42. for x0, y0 in x0y0.T:
  43. x1, y1 = x0 + crop_w, y0 + crop_w
  44. x0, y0 = max(x0, 0), max(y0, 0)
  45. res[y0:y0+crop_h, x0:x0+crop_w] = im[y0:y0+crop_h, x0:x0+crop_w]
  46. return res