base.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import numpy as np
  2. import abc
  3. import warnings
  4. import logging
  5. from os.path import join, isfile, isdir
  6. from collections import defaultdict, OrderedDict
  7. from nabirds.utils import read_info_file, feature_file_name
  8. from nabirds.dataset import Dataset
  9. class BaseAnnotations(abc.ABC):
  10. FEATURE_PHONY = dict(train=["train"], test=["test", "val"])
  11. def __init__(self, root_or_infofile, parts=None, feature_model=None):
  12. super(BaseAnnotations, self).__init__()
  13. self.part_type = parts
  14. self.feature_model = feature_model
  15. if isdir(root_or_infofile):
  16. self.info = None
  17. self.root = root_or_infofile
  18. elif isfile(root_or_infofile):
  19. self.root = self.root_from_infofile(root_or_infofile, parts)
  20. else:
  21. raise ValueError("Root folder or info file does not exist: \"{}\"".format(
  22. root_or_infofile
  23. ))
  24. for fname, attr in self.meta.structure:
  25. self.read_content(fname, attr)
  26. self._load_uuids()
  27. self._load_labels()
  28. self._load_parts()
  29. self._load_split()
  30. @property
  31. def data_root(self):
  32. if self.info is None: return None
  33. return join(self.info.BASE_DIR, self.info.DATA_DIR)
  34. @property
  35. def dataset_info(self):
  36. if self.info is None: return None
  37. if self.part_type is None:
  38. return self.info.DATASETS[self.__class__.name]
  39. else:
  40. return self.info.PARTS[self.part_type]
  41. def root_from_infofile(self, info_file, parts=None):
  42. self.info = read_info_file(info_file)
  43. dataset_info = self.dataset_info
  44. annot_dir = join(self.data_root, dataset_info.folder, dataset_info.annotations)
  45. assert isdir(annot_dir), "Annotation folder does exist! \"{}\"".format(annot_dir)
  46. return annot_dir
  47. def new_dataset(self, subset=None, dataset_cls=Dataset, **kwargs):
  48. if subset is not None:
  49. uuids = getattr(self, "{}_uuids".format(subset))
  50. else:
  51. uuids = self.uuids
  52. kwargs = self.check_parts_and_features(subset, **kwargs)
  53. return dataset_cls(uuids=uuids, annotations=self, **kwargs)
  54. def check_parts_and_features(self, subset, **kwargs):
  55. dataset_info = self.dataset_info
  56. if dataset_info is None:
  57. return kwargs
  58. logging.debug("Dataset info: {}".format(dataset_info))
  59. # TODO: pass all scales
  60. new_opts = {}
  61. if "scales" in dataset_info:
  62. new_opts["ratio"] = dataset_info.scales[0]
  63. if "is_uniform" in dataset_info:
  64. new_opts["uniform_parts"] = dataset_info.is_uniform
  65. if self.part_type is not None:
  66. new_opts["part_rescale_size"] = dataset_info.rescale_size
  67. if None not in [subset, self.feature_model]:
  68. tried = []
  69. model_info = self.info.MODELS[self.feature_model]
  70. for subset_phony in BaseAnnotations.FEATURE_PHONY[subset]:
  71. features = feature_file_name(subset_phony, dataset_info, model_info)
  72. feature_path = join(self.root, "features", features)
  73. if isfile(feature_path): break
  74. tried.append(feature_path)
  75. else:
  76. raise ValueError(
  77. "Could not find any features in \"{}\" for {} subset. Tried features: {}".format(
  78. join(self.root, "features"), subset, tried))
  79. new_opts["features"] = feature_path
  80. new_opts.update(kwargs)
  81. logging.debug("Final kwargs: {}".format(new_opts))
  82. return new_opts
  83. @property
  84. @abc.abstractmethod
  85. def meta(self):
  86. pass
  87. def _path(self, file):
  88. return join(self.root, file)
  89. def _open(self, file):
  90. return open(self._path(file))
  91. def read_content(self, file, attr):
  92. content = None
  93. fpath = self._path(file)
  94. if isfile(fpath):
  95. with self._open(file) as f:
  96. content = [line.strip() for line in f if line.strip()]
  97. else:
  98. warnings.warn("File \"{}\" was not found!".format(fpath))
  99. setattr(self, attr, content)
  100. def _load_labels(self):
  101. self.labels = np.array([int(l) for l in self.labels], dtype=np.int32)
  102. def _load_uuids(self):
  103. assert self._images is not None, "Images were not loaded!"
  104. uuid_fnames = [i.split() for i in self._images]
  105. self.uuids, self.images = map(np.array, zip(*uuid_fnames))
  106. self.uuid_to_idx = {uuid: i for i, uuid in enumerate(self.uuids)}
  107. def _load_parts(self):
  108. assert self._part_locs is not None, "Part locations were not loaded!"
  109. # this part is quite slow... TODO: some runtime improvements?
  110. uuid_to_parts = defaultdict(list)
  111. for content in [i.split() for i in self._part_locs]:
  112. uuid_to_parts[content[0]].append([float(i) for i in content[1:]])
  113. self.part_locs = np.stack([uuid_to_parts[uuid] for uuid in self.uuids]).astype(int)
  114. if hasattr(self, "_part_names") and self._part_names is not None:
  115. self._load_part_names()
  116. def _load_part_names(self):
  117. self.part_names = OrderedDict()
  118. self.part_name_list = []
  119. for line in self._part_names:
  120. part_idx, _, name = line.partition(" ")
  121. self.part_names[int(part_idx)] = name
  122. self.part_name_list.append(name)
  123. def _load_split(self):
  124. assert self._split is not None, "Train-test split was not loaded!"
  125. uuid_to_split = {uuid: int(split) for uuid, split in [i.split() for i in self._split]}
  126. self.train_split = np.array([uuid_to_split[uuid] for uuid in self.uuids], dtype=bool)
  127. self.test_split = np.logical_not(self.train_split)
  128. def image_path(self, image):
  129. return join(self.root, self.meta.images_folder, image)
  130. def image(self, uuid):
  131. fname = self.images[self.uuid_to_idx[uuid]]
  132. return self.image_path(fname)
  133. def label(self, uuid):
  134. return self.labels[self.uuid_to_idx[uuid]].copy()
  135. def parts(self, uuid):
  136. return self.part_locs[self.uuid_to_idx[uuid]].copy()
  137. def _uuids(self, split):
  138. return self.uuids[split]
  139. @property
  140. def train_uuids(self):
  141. return self._uuids(self.train_split)
  142. @property
  143. def test_uuids(self):
  144. return self._uuids(self.test_split)