"""Base annotation classes.

Resolves an annotation folder (directly or via an info file) and reads
image lists, labels, part locations and the train/test split from it.
"""
  1. import numpy as np
  2. import abc
  3. import warnings
  4. import logging
  5. from os.path import join, isfile, isdir
  6. from collections import defaultdict, OrderedDict
  7. from nabirds.utils import read_info_file, feature_file_name
  8. from nabirds.dataset import Dataset
  9. def _parse_index(idx, offset):
  10. if idx.isdigit():
  11. idx = str(int(idx) - offset)
  12. return idx
  13. class BaseAnnotations(abc.ABC):
  14. FEATURE_PHONY = dict(train=["train"], test=["test", "val"])
  15. def __init__(self, root_or_infofile, parts=None, feature_model=None):
  16. super(BaseAnnotations, self).__init__()
  17. self.part_type = parts
  18. self.feature_model = feature_model
  19. if isdir(root_or_infofile):
  20. self.info = None
  21. self.root = root_or_infofile
  22. elif isfile(root_or_infofile):
  23. self.root = self.root_from_infofile(root_or_infofile, parts)
  24. else:
  25. raise ValueError("Root folder or info file does not exist: \"{}\"".format(
  26. root_or_infofile
  27. ))
  28. for fname, attr in self.meta.structure:
  29. self.read_content(fname, attr)
  30. self._load_uuids()
  31. self._load_labels()
  32. self._load_parts()
  33. self._load_split()
  34. @property
  35. def data_root(self):
  36. if self.info is None: return None
  37. return join(self.info.BASE_DIR, self.info.DATA_DIR)
  38. @property
  39. def dataset_info(self):
  40. if self.info is None: return None
  41. if self.part_type is None:
  42. return self.info.DATASETS[self.__class__.name]
  43. else:
  44. return self.info.PARTS[self.part_type]
  45. def root_from_infofile(self, info_file, parts=None):
  46. self.info = read_info_file(info_file)
  47. dataset_info = self.dataset_info
  48. annot_dir = join(self.data_root, dataset_info.folder, dataset_info.annotations)
  49. assert isdir(annot_dir), "Annotation folder does exist! \"{}\"".format(annot_dir)
  50. return annot_dir
  51. def new_dataset(self, subset=None, dataset_cls=Dataset, **kwargs):
  52. if subset is not None:
  53. uuids = getattr(self, "{}_uuids".format(subset))
  54. else:
  55. uuids = self.uuids
  56. kwargs = self.check_parts_and_features(subset, **kwargs)
  57. return dataset_cls(uuids=uuids, annotations=self, **kwargs)
  58. def check_parts_and_features(self, subset, **kwargs):
  59. dataset_info = self.dataset_info
  60. if dataset_info is None:
  61. return kwargs
  62. logging.debug("Dataset info: {}".format(dataset_info))
  63. # TODO: pass all scales
  64. new_opts = {}
  65. if "scales" in dataset_info:
  66. new_opts["ratio"] = dataset_info.scales[0]
  67. if "is_uniform" in dataset_info:
  68. new_opts["uniform_parts"] = dataset_info.is_uniform
  69. if self.part_type is not None:
  70. new_opts["part_rescale_size"] = dataset_info.rescale_size
  71. if None not in [subset, self.feature_model]:
  72. tried = []
  73. model_info = self.info.MODELS[self.feature_model]
  74. for subset_phony in BaseAnnotations.FEATURE_PHONY[subset]:
  75. features = feature_file_name(subset_phony, dataset_info, model_info)
  76. feature_path = join(self.root, "features", features)
  77. if isfile(feature_path): break
  78. tried.append(feature_path)
  79. else:
  80. raise ValueError(
  81. "Could not find any features in \"{}\" for {} subset. Tried features: {}".format(
  82. join(self.root, "features"), subset, tried))
  83. new_opts["features"] = feature_path
  84. new_opts.update(kwargs)
  85. logging.debug("Final kwargs: {}".format(new_opts))
  86. return new_opts
  87. @property
  88. def has_parts(self):
  89. return hasattr(self, "_part_locs") and self._part_locs is not None
  90. @property
  91. @abc.abstractmethod
  92. def meta(self):
  93. pass
  94. def _path(self, file):
  95. return join(self.root, file)
  96. def _open(self, file):
  97. return open(self._path(file))
  98. def read_content(self, file, attr):
  99. content = None
  100. fpath = self._path(file)
  101. if isfile(fpath):
  102. with self._open(file) as f:
  103. content = [line.strip() for line in f if line.strip()]
  104. else:
  105. warnings.warn("File \"{}\" was not found!".format(fpath))
  106. setattr(self, attr, content)
  107. def _load_labels(self):
  108. self.labels = np.array([int(l) for l in self.labels], dtype=np.int32)
  109. def _load_uuids(self):
  110. assert self._images is not None, "Images were not loaded!"
  111. uuid_fnames = [i.split() for i in self._images]
  112. self.uuids, self.images = map(np.array, zip(*uuid_fnames))
  113. self.uuid_to_idx = {uuid: i for i, uuid in enumerate(self.uuids)}
  114. def _load_parts(self, idx_offset=0):
  115. assert self.has_parts, "Part locations were not loaded!"
  116. # this part is quite slow... TODO: some runtime improvements?
  117. idx_to_parts = defaultdict(list)
  118. for content in [i.split() for i in self._part_locs]:
  119. uuid = _parse_index(content[0], idx_offset)
  120. idx_to_parts[uuid].append([float(c) for c in content[1:]])
  121. idx_to_parts = dict(idx_to_parts)
  122. self.part_locs = np.stack([
  123. idx_to_parts[uuid] for uuid in self.uuids]).astype(int)
  124. if hasattr(self, "_part_names") and self._part_names is not None:
  125. self._load_part_names()
  126. def _load_part_names(self):
  127. self.part_names = OrderedDict()
  128. self.part_name_list = []
  129. for line in self._part_names:
  130. part_idx, _, name = line.partition(" ")
  131. self.part_names[int(part_idx)] = name
  132. self.part_name_list.append(name)
  133. def _load_split(self):
  134. assert self._split is not None, "Train-test split was not loaded!"
  135. uuid_to_split = {uuid: int(split) for uuid, split in [i.split() for i in self._split]}
  136. self.train_split = np.array([uuid_to_split[uuid] for uuid in self.uuids], dtype=bool)
  137. self.test_split = np.logical_not(self.train_split)
  138. def image_path(self, image):
  139. return join(self.root, self.meta.images_folder, image)
  140. def image(self, uuid):
  141. fname = self.images[self.uuid_to_idx[uuid]]
  142. return self.image_path(fname)
  143. def label(self, uuid):
  144. return self.labels[self.uuid_to_idx[uuid]].copy()
  145. def parts(self, uuid):
  146. return self.part_locs[self.uuid_to_idx[uuid]].copy()
  147. def _uuids(self, split):
  148. return self.uuids[split]
  149. @property
  150. def train_uuids(self):
  151. return self._uuids(self.train_split)
  152. @property
  153. def test_uuids(self):
  154. return self._uuids(self.test_split)