base.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. from os.path import join, isfile, isdir
  2. import numpy as np
  3. from collections import defaultdict, OrderedDict
  4. import abc
  5. import warnings
  6. try:
  7. from yaml import CLoader as Loader, CDumper as Dumper
  8. except ImportError:
  9. from yaml import Loader, Dumper
  10. import yaml
  11. import simplejson as json
  12. from nabirds.utils import attr_dict
  13. from nabirds.dataset import Dataset
  14. class BaseAnnotations(abc.ABC):
  15. def __init__(self, root_or_infofile, parts=None, feature_model=None):
  16. super(BaseAnnotations, self).__init__()
  17. self.part_type = parts
  18. self.feature_model = feature_model
  19. if isdir(root_or_infofile):
  20. self.info = None
  21. self.root = root_or_infofile
  22. elif isfile(root_or_infofile):
  23. self.root = self.root_from_infofile(root_or_infofile, parts, feature_model)
  24. else:
  25. raise ValueError("Root folder or info file does not exist: \"{}\"".format(
  26. root_or_infofile
  27. ))
  28. for fname, attr in self.meta.structure:
  29. self.read_content(fname, attr)
  30. self.labels = np.array([int(l) for l in self.labels], dtype=np.int32)
  31. self._load_uuids()
  32. self._load_parts()
  33. self._load_split()
  34. @property
  35. def data_root(self):
  36. if self.info is None: return None
  37. return join(self.info.BASE_DIR, self.info.DATA_DIR)
  38. @property
  39. def dataset_info(self):
  40. if self.info is None: return None
  41. if self.part_type is None:
  42. return self.info.DATASETS[self.__class__.name]
  43. else:
  44. return self.info.PARTS[self.part_type]
  45. def root_from_infofile(self, info_file, parts=None, feature_model=None):
  46. with open(info_file) as f:
  47. self.info = attr_dict(yaml.load(f, Loader=Loader))
  48. dataset_info = self.dataset_info
  49. # print(json.dumps(dataset_info, indent=2))
  50. annot_dir = join(self.data_root, dataset_info.folder, dataset_info.annotations)
  51. assert isdir(annot_dir), "Annotation folder does exist! \"{}\"".format(annot_dir)
  52. return annot_dir
  53. def new_dataset(self, subset=None, dataset_cls=Dataset, **kwargs):
  54. if subset is not None:
  55. uuids = getattr(self, "{}_uuids".format(subset))
  56. else:
  57. uuids = self.uuids
  58. kwargs = self.check_parts_and_features(subset, **kwargs)
  59. return dataset_cls(uuids=uuids, annotations=self, **kwargs)
  60. def check_parts_and_features(self, subset, **kwargs):
  61. dataset_info = self.dataset_info
  62. # TODO: pass all scales
  63. new_opts = {
  64. "ratio": dataset_info.scales[0],
  65. "uniform_parts": dataset_info.is_uniform
  66. }
  67. if self.part_type is not None:
  68. new_opts["part_rescale_size"] = dataset_info.rescale_size
  69. if None not in [subset, self.feature_model]:
  70. features = "{subset}_{suffix}.{model}.npz".format(
  71. subset=subset,
  72. suffix=dataset_info.feature_suffix,
  73. model=self.feature_model)
  74. feature_path = join(self.root, "features", features)
  75. assert isfile(feature_path), \
  76. "Features do not exist: \"{}\"".format(feature_path)
  77. new_opts["features"] = feature_path
  78. new_opts.update(kwargs)
  79. print(new_opts)
  80. return new_opts
  81. @property
  82. @abc.abstractmethod
  83. def meta(self):
  84. pass
  85. def _path(self, file):
  86. return join(self.root, file)
  87. def _open(self, file):
  88. return open(self._path(file))
  89. def read_content(self, file, attr):
  90. content = None
  91. fpath = self._path(file)
  92. if isfile(fpath):
  93. with self._open(file) as f:
  94. content = [line.strip() for line in f if line.strip()]
  95. else:
  96. warnings.warn("File \"{}\" was not found!".format(fpath))
  97. setattr(self, attr, content)
  98. def _load_uuids(self):
  99. assert self._images is not None, "Images were not loaded!"
  100. uuid_fnames = [i.split() for i in self._images]
  101. self.uuids, self.images = map(np.array, zip(*uuid_fnames))
  102. self.uuid_to_idx = {uuid: i for i, uuid in enumerate(self.uuids)}
  103. def _load_parts(self):
  104. assert self._part_locs is not None, "Part locations were not loaded!"
  105. # this part is quite slow... TODO: some runtime improvements?
  106. uuid_to_parts = defaultdict(list)
  107. for content in [i.split() for i in self._part_locs]:
  108. uuid_to_parts[content[0]].append([float(i) for i in content[1:]])
  109. self.part_locs = np.stack([uuid_to_parts[uuid] for uuid in self.uuids]).astype(int)
  110. if hasattr(self, "_part_names") and self._part_names is not None:
  111. self._load_part_names()
  112. def _load_part_names(self):
  113. self.part_names = OrderedDict()
  114. self.part_name_list = []
  115. for line in self._part_names:
  116. part_idx, _, name = line.partition(" ")
  117. self.part_names[int(part_idx)] = name
  118. self.part_name_list.append(name)
  119. def _load_split(self):
  120. assert self._split is not None, "Train-test split was not loaded!"
  121. uuid_to_split = {uuid: int(split) for uuid, split in [i.split() for i in self._split]}
  122. self.train_split = np.array([uuid_to_split[uuid] for uuid in self.uuids], dtype=bool)
  123. self.test_split = np.logical_not(self.train_split)
  124. def image_path(self, image):
  125. return join(self.root, self.meta.images_folder, image)
  126. def image(self, uuid):
  127. fname = self.images[self.uuid_to_idx[uuid]]
  128. return self.image_path(fname)
  129. def label(self, uuid):
  130. return self.labels[self.uuid_to_idx[uuid]].copy()
  131. def parts(self, uuid):
  132. return self.part_locs[self.uuid_to_idx[uuid]].copy()
  133. def _uuids(self, split):
  134. return self.uuids[split]
  135. @property
  136. def train_uuids(self):
  137. return self._uuids(self.train_split)
  138. @property
  139. def test_uuids(self):
  140. return self._uuids(self.test_split)