reading.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import numpy as np
  2. from os.path import join
  3. from . import BaseMixin
  4. from ..image import ImageWrapper
  class AnnotationsReadMixin(BaseMixin):
      """Dataset mixin that reads examples through an annotations object.

      Examples are addressed by their position in ``uuids``. For each uuid the
      annotation object's ``image``, ``parts`` and ``label`` methods are called
      and the results are packed into an ``ImageWrapper``.
      """

      def __init__(self, uuids, annotations, mode="RGB"):
          super(AnnotationsReadMixin, self).__init__()
          # identifiers handed to the annotation object's accessor methods
          self.uuids = uuids
          # object expected to expose image(uuid), parts(uuid), label(uuid)
          self._annot = annotations
          # forwarded unchanged to ImageWrapper as its ``mode`` keyword
          self.mode = mode

      def __len__(self):
          """Return how many uuids (examples) this dataset holds."""
          return len(self.uuids)

      def _get(self, method, i):
          """Call ``annotations.<method>`` with the uuid stored at index ``i``."""
          accessor = getattr(self._annot, method)
          uuid = self.uuids[i]
          return accessor(uuid)

      def get_example(self, i):
          """Return the ``ImageWrapper`` for example ``i``.

          Raises:
              AssertionError: if the next class in the MRO returns a non-None
                  value from ``get_example``.
          """
          parent_result = super(AnnotationsReadMixin, self).get_example(i)
          # A non-None result means some class behind this mixin in the MRO
          # also produces examples, i.e. the base-class order is wrong.
          # NOTE(review): presumably BaseMixin.get_example returns None — confirm.
          assert parent_result is None, "AnnotationsReadMixin should be the last class in the hierarchy!"
          im_path = self._get("image", i)
          parts = self._get("parts", i)
          label = self._get("label", i)
          return ImageWrapper(im_path, int(label), parts, mode=self.mode)

      @property
      def labels(self):
          """Return a numpy array of every example's ``label`` annotation, in order.

          NOTE(review): unlike ``get_example`` the values are not passed through
          ``int()``; the dtype follows whatever the annotation object returns.
          """
          raw_labels = [self._get("label", idx) for idx in range(len(self))]
          return np.array(raw_labels)
  class ImageListReadingMixin(BaseMixin):
      """Dataset mixin backed by a text file of ``<image path> <label>`` lines.

      Each non-blank line of the pairs file must hold exactly two
      whitespace-separated fields: an image path (joined onto ``root``) and an
      integer class label.
      """

      def __init__(self, pairs, root="."):
          """Read and validate the pairs file.

          Args:
              pairs: path of the text file listing ``<image path> <label>`` lines.
              root: directory the image paths are joined onto.

          Raises:
              ValueError: if a non-blank line does not contain exactly two fields.
          """
          super(ImageListReadingMixin, self).__init__()
          parsed = []
          with open(pairs) as f:
              for lineno, line in enumerate(f, 1):
                  # str.split() with no argument also strips surrounding whitespace
                  fields = line.split()
                  if not fields:
                      # tolerate blank lines (e.g. extra trailing newlines)
                      # instead of rejecting the whole file
                      continue
                  if len(fields) != 2:
                      # explicit raise: an assert would be stripped under -O
                      raise ValueError(
                          "Invalid format of the pairs file! "
                          "(line {}: {!r})".format(lineno, line.rstrip("\n")))
                  parsed.append(fields)
          # list of [image_file, label_string] pairs, in file order
          self._pairs = parsed
          self._root = root

      def __len__(self):
          """Return the number of (image, label) pairs read from the file."""
          return len(self._pairs)

      def get_example(self, i):
          """Return an ``ImageWrapper`` for pair ``i`` with its label cast to int."""
          im_file, label = self._pairs[i]
          im_path = join(self._root, im_file)
          return ImageWrapper(im_path, int(label))

      @property
      def labels(self):
          """Return the integer class label of every pair, in file order."""
          # Labels are stored as the raw strings read from the file; cast them
          # like get_example does so callers get a numeric (not string) array.
          return np.array([int(label) for _, label in self._pairs], dtype=int)