__init__.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import numpy as np
  2. from abc import ABC, abstractmethod
  3. from imageio import imread
  4. from . import utils
  5. class BaseMixin(ABC):
  6. @abstractmethod
  7. def get_example(self, i):
  8. pass
  9. def __getitem__(self, i):
  10. return self.get_example(i)
  11. class ImageReadMixin(BaseMixin):
  12. def __init__(self, uuids, annotations, mode="RGB"):
  13. super(BaseMixin, self).__init__()
  14. self.uuids = uuids
  15. self._annot = annotations
  16. self.mode = mode
  17. def __len__(self):
  18. return len(self.uuids)
  19. def _get(self, method, i):
  20. return getattr(self._annot, method)(self.uuids[i])
  21. def get_example(self, i):
  22. res = super(ImageReadMixin, self).get_example(i)
  23. # if the super class returns something, then the class inheritance is wrong
  24. assert res is None, "ImageReadMixin should be the last class in the hierarchy!"
  25. methods = ["image", "parts", "label"]
  26. im_path, parts, label = [self._get(m, i) for m in methods]
  27. im = imread(im_path, pilmode=self.mode)
  28. return im, parts, label
  29. class BBCropMixin(BaseMixin):
  30. def __init__(self, crop_to_bb=False, crop_uniform=False, *args, **kwargs):
  31. super(BBCropMixin, self).__init__(*args, **kwargs)
  32. self.crop_to_bb = crop_to_bb
  33. self.crop_uniform = crop_uniform
  34. def bounding_box(self, i):
  35. bbox = self._get("bounding_box", i)
  36. x,y,w,h = [bbox[attr] for attr in "xywh"]
  37. if self.crop_uniform:
  38. x0 = x + w//2
  39. y0 = y + h//2
  40. crop_size = max(w//2, h//2)
  41. x,y = max(x0 - crop_size, 0), max(y0 - crop_size, 0)
  42. w = h = crop_size * 2
  43. return x,y,w,h
  44. def get_example(self, i):
  45. im, parts, label = super(BBCropMixin, self).get_example(i)
  46. if self.crop_to_bb:
  47. x,y,w,h = self.bounding_box(i)
  48. im = im[y:y+h, x:x+w]
  49. parts[:, 1] -= x
  50. parts[:, 2] -= y
  51. return im, parts, label
  52. class UniformPartMixin(BaseMixin):
  53. def __init__(self, uniform_parts=False, ratio=utils.DEFAULT_RATIO, *args, **kwargs):
  54. super(UniformPartMixin, self).__init__(*args, **kwargs)
  55. self.uniform_parts = uniform_parts
  56. self.ratio = ratio
  57. def get_example(self, i):
  58. im, parts, label = super(UniformPartMixin, self).get_example(i)
  59. if self.uniform_parts:
  60. parts = utils.uniform_parts(im, ratio=self.ratio)
  61. return im, parts, label
  62. class RandomBlackOutMixin(BaseMixin):
  63. def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
  64. super(RandomBlackOutMixin, self).__init__(*args, **kwargs)
  65. self.rnd = np.random.RandomState(seed)
  66. self.rnd_select = rnd_select
  67. self.n_parts = n_parts
  68. def get_example(self, i):
  69. im, parts, lab = super(RandomBlackOutMixin, self).get_example(i)
  70. if self.rnd_select:
  71. idxs, xy = utils.visible_part_locs(parts)
  72. rnd_idxs = utils.random_idxs(idxs, rnd=self.rnd, n_parts=self.n_parts)
  73. parts[:, -1] = 0
  74. parts[rnd_idxs, -1] = 1
  75. return im, parts, lab
  76. class Dataset(RandomBlackOutMixin, UniformPartMixin, BBCropMixin, ImageReadMixin):
  77. pass