dataset.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from imageio import imread
  2. import numpy as np
  3. class Dataset(object):
  4. def __init__(self, uuids, annotations):
  5. super(Dataset, self).__init__()
  6. self.uuids = uuids
  7. self._annot = annotations
  8. def __len__(self):
  9. return len(self.uuids)
  10. def _get(self, method, i):
  11. return getattr(self._annot, method)(self.uuids[i])
  12. def get_example(self, i, mode="RGB"):
  13. methods = ["image", "parts", "label"]
  14. im_path, parts, label = [self._get(m, i) for m in methods]
  15. return imread(im_path, pilmode=mode), parts, label
  16. __getitem__ = get_example
# some convenience functions
  18. def __expand_parts(p):
  19. return p[:, 0], p[:, 1:3], p[:, 3].astype(bool)
  20. def visible_part_locs(p):
  21. idxs, locs, vis = __expand_parts(p)
  22. return idxs[vis], locs[vis].T
  23. def visible_crops(im, p, ratio=np.sqrt(49 / 400), padding_mode="edge"):
  24. assert im.ndim == 3, "Only RGB images are currently supported!"
  25. idxs, locs, vis = __expand_parts(p)
  26. h, w, c = im.shape
  27. crop_h = crop_w = int(np.sqrt(h*w) * ratio)
  28. crops = np.zeros((len(idxs), crop_h, crop_w, c), dtype=im.dtype)
  29. padding = np.array([crop_h, crop_w]) // 2
  30. padded_im = np.pad(im, [padding, padding, [0,0]], mode=padding_mode)
  31. for i, loc, is_vis in zip(*__expand_parts(p)):
  32. if not is_vis: continue
  33. x0, y0 = loc - crop_h // 2 + padding
  34. crops[i] = padded_im[y0:y0+crop_h, x0:x0+crop_w]
  35. return crops