123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- import numpy as np
- from os.path import join
- from . import BaseMixin
- from ..image import ImageWrapper
class AnnotationsReadMixin(BaseMixin):
    """Dataset mixin that builds its examples from an annotations object.

    The examples are addressed through a sequence of UUIDs; for every UUID
    the annotations object is queried via its ``image``, ``parts`` and
    ``label`` accessors.
    """

    def __init__(self, uuids, annotations, part_rescale_size=None, mode="RGB"):
        """Store the UUIDs, the annotations object and the image options.

        ``mode`` and ``part_rescale_size`` are handed over unchanged to
        every ``ImageWrapper`` created in ``get_example``.
        """
        super(AnnotationsReadMixin, self).__init__()
        self.uuids = uuids
        self._annot = annotations
        self.mode = mode
        self.part_rescale_size = part_rescale_size

    def __len__(self):
        """Return the number of UUIDs, i.e. the number of examples."""
        return len(self.uuids)

    def _get(self, method, i):
        """Call the annotation accessor named ``method`` for the i-th UUID."""
        # Resolve the accessor before indexing the UUIDs, so that a missing
        # accessor is reported first.
        accessor = getattr(self._annot, method)
        uuid = self.uuids[i]
        return accessor(uuid)

    def get_example(self, i):
        """Return an ``ImageWrapper`` for the i-th example.

        NOTE: the result of the super class is deliberately not used here.
        A check (currently disabled) was meant to enforce that this mixin is
        the last class in the hierarchy.
        """
        im_path = self._get("image", i)
        parts = self._get("parts", i)
        label = self._get("label", i)

        return ImageWrapper(
            im_path,
            int(label),
            parts,
            mode=self.mode,
            part_rescale_size=self.part_rescale_size,
        )

    @property
    def n_parts(self):
        """Size of the second axis of the annotations' ``part_locs`` array."""
        part_locs_shape = self._annot.part_locs.shape
        return part_locs_shape[1]

    @property
    def labels(self):
        """Labels of all examples, collected into a numpy array."""
        indices = range(len(self))
        return np.array([self._get("label", idx) for idx in indices])
class ImageListReadingMixin(BaseMixin):
    """Dataset mixin that reads ``<image file> <label>`` pairs from a text file.

    Every non-blank line of the pairs file has to consist of exactly two
    whitespace-separated fields: the image file name (resolved relative to
    ``root``) and a label that can be converted with ``int``.
    """

    def __init__(self, pairs, root="."):
        """Read the pairs file.

        Args:
            pairs: path of the text file holding the (image file, label) pairs.
            root: directory that is prepended to every image file name.

        Raises:
            AssertionError: if a non-blank line does not contain exactly
                two fields.
        """
        super(ImageListReadingMixin, self).__init__()
        with open(pairs) as f:
            # str.split() without arguments ignores surrounding whitespace;
            # a blank line yields an empty list and is skipped, so that e.g.
            # a trailing empty line does not fail the format check below.
            rows = (line.split() for line in f)
            self._pairs = [fields for fields in rows if fields]

        assert all(len(pair) == 2 for pair in self._pairs), \
            "Invalid format of the pairs file!"

        self._root = root

    def __len__(self):
        """Return the number of (image file, label) pairs."""
        return len(self._pairs)

    def get_example(self, i):
        """Return an ``ImageWrapper`` for the i-th pair with an integer label."""
        im_file, label = self._pairs[i]
        im_path = join(self._root, im_file)
        return ImageWrapper(im_path, int(label))

    @property
    def labels(self):
        """Labels of all pairs as a numpy array of integers.

        The labels are converted with ``int`` exactly like in
        ``get_example``, so both views of the labels agree (the pairs file
        itself only provides strings).
        """
        return np.array([int(label) for _, label in self._pairs])
|