reading.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import numpy as np
  2. from os.path import join
  3. from . import BaseMixin
  4. from ..image import ImageWrapper
  5. class AnnotationsReadMixin(BaseMixin):
  6. def __init__(self, uuids, annotations, part_rescale_size=None, mode="RGB"):
  7. super(AnnotationsReadMixin, self).__init__()
  8. self.uuids = uuids
  9. self._annot = annotations
  10. self.mode = mode
  11. self.part_rescale_size = part_rescale_size
  12. def __len__(self):
  13. return len(self.uuids)
  14. def _get(self, method, i):
  15. return getattr(self._annot, method)(self.uuids[i])
  16. def get_example(self, i):
  17. # res = super(AnnotationsReadMixin, self).get_example(i)
  18. # # if the super class returns something, then the class inheritance is wrong
  19. # assert res is None, "AnnotationsReadMixin should be the last class in the hierarchy!"
  20. methods = ["image", "parts", "label"]
  21. im_path, parts, label = [self._get(m, i) for m in methods]
  22. return ImageWrapper(im_path, int(label), parts, mode=self.mode, part_rescale_size=self.part_rescale_size)
  23. @property
  24. def n_parts(self):
  25. return self._annot.part_locs.shape[1]
  26. @property
  27. def labels(self):
  28. return np.array([self._get("label", i) for i in range(len(self))])
  29. class ImageListReadingMixin(BaseMixin):
  30. def __init__(self, pairs, root="."):
  31. super(ImageListReadingMixin, self).__init__()
  32. with open(pairs) as f:
  33. self._pairs = [line.strip().split() for line in f]
  34. assert all([len(pair) == 2 for pair in self._pairs]), \
  35. "Invalid format of the pairs file!"
  36. self._root = root
  37. def __len__(self):
  38. return len(self._pairs)
  39. def get_example(self, i):
  40. im_file, label = self._pairs[i]
  41. im_path = join(self._root, im_file)
  42. return ImageWrapper(im_path, int(label))
  43. @property
  44. def labels(self):
  45. return np.array([label for (_, label) in self._pairs])