reading.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. from imageio import imread
  2. from os.path import join, isfile
  3. from .base import BaseMixin
  4. class AnnotationsReadMixin(BaseMixin):
  5. def __init__(self, uuids, annotations, mode="RGB"):
  6. super(AnnotationsReadMixin, self).__init__()
  7. self.uuids = uuids
  8. self._annot = annotations
  9. self.mode = mode
  10. def __len__(self):
  11. return len(self.uuids)
  12. def _get(self, method, i):
  13. return getattr(self._annot, method)(self.uuids[i])
  14. def get_example(self, i):
  15. res = super(AnnotationsReadMixin, self).get_example(i)
  16. # if the super class returns something, then the class inheritance is wrong
  17. assert res is None, "AnnotationsReadMixin should be the last class in the hierarchy!"
  18. methods = ["image", "parts", "label"]
  19. im_path, parts, label = [self._get(m, i) for m in methods]
  20. im = imread(im_path, pilmode=self.mode)
  21. return im, parts, label
  22. class ImageListReadingMixin(BaseMixin):
  23. def __init__(self, pairs, root="."):
  24. super(ImageListReadingMixin, self).__init__()
  25. with open(pairs) as f:
  26. self._pairs = [line.strip().split() for line in f]
  27. assert all([len(pair) == 2 for pair in self._pairs]), \
  28. "Invalid format of the pairs file!"
  29. self._root = root
  30. def __len__(self):
  31. return len(self._pairs)
  32. def get_example(self, i):
  33. im_file, label = self._pairs[i]
  34. im_path = join(self._root, im_file)
  35. assert isfile(im_path), "Image \"{}\" does not exist!".format(im_path)
  36. im = imread(im_path, pilmode="RGB")
  37. return im, int(label)