base.py

import numpy as np
import abc
import warnings
import logging

from os.path import join, isfile, isdir
from collections import defaultdict, OrderedDict

try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper
import yaml

from nabirds.utils import attr_dict
from nabirds.dataset import Dataset


class BaseAnnotations(abc.ABC):

    FEATURE_PHONY = dict(train=["train"], test=["test", "val"])

    def __init__(self, root_or_infofile, parts=None, feature_model=None):
        super(BaseAnnotations, self).__init__()
        self.part_type = parts
        self.feature_model = feature_model

        if isdir(root_or_infofile):
            self.info = None
            self.root = root_or_infofile
        elif isfile(root_or_infofile):
            self.root = self.root_from_infofile(root_or_infofile, parts, feature_model)
        else:
            raise ValueError("Root folder or info file does not exist: \"{}\"".format(
                root_or_infofile
            ))

        for fname, attr in self.meta.structure:
            self.read_content(fname, attr)

        self.labels = np.array([int(l) for l in self.labels], dtype=np.int32)

        self._load_uuids()
        self._load_parts()
        self._load_split()

    @property
    def data_root(self):
        if self.info is None: return None
        return join(self.info.BASE_DIR, self.info.DATA_DIR)

    @property
    def dataset_info(self):
        if self.info is None: return None
        if self.part_type is None:
            return self.info.DATASETS[self.__class__.name]
        else:
            return self.info.PARTS[self.part_type]

    def root_from_infofile(self, info_file, parts=None, feature_model=None):
        with open(info_file) as f:
            self.info = attr_dict(yaml.load(f, Loader=Loader))

        dataset_info = self.dataset_info
        annot_dir = join(self.data_root, dataset_info.folder, dataset_info.annotations)
        assert isdir(annot_dir), "Annotation folder does not exist! \"{}\"".format(annot_dir)
        return annot_dir
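
    # Note (illustrative sketch, not part of the original file): judging from the
    # keys accessed in this class, the info file loaded above is a YAML document
    # roughly of the following shape; every concrete value below is a placeholder,
    # not real configuration:
    #
    #   BASE_DIR: <base directory>
    #   DATA_DIR: <sub-directory containing the datasets>
    #   DATASETS:
    #     <subclass "name" attribute>:
    #       folder: <dataset folder>
    #       annotations: <annotation sub-folder>
    #       scales: [<ratio>, ...]
    #       is_uniform: <bool>
    #       feature_suffix: <suffix used in feature file names>
    #   PARTS:
    #     <part type>:
    #       ...               # same keys as above, plus:
    #       rescale_size: <int>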
    def new_dataset(self, subset=None, dataset_cls=Dataset, **kwargs):
        if subset is not None:
            uuids = getattr(self, "{}_uuids".format(subset))
        else:
            uuids = self.uuids

        kwargs = self.check_parts_and_features(subset, **kwargs)
        return dataset_cls(uuids=uuids, annotations=self, **kwargs)

    def check_parts_and_features(self, subset, **kwargs):
        dataset_info = self.dataset_info
        if dataset_info is None:
            return kwargs

        # TODO: pass all scales
        new_opts = {
            "ratio": dataset_info.scales[0],
            "uniform_parts": dataset_info.is_uniform
        }

        if self.part_type is not None:
            new_opts["part_rescale_size"] = dataset_info.rescale_size

        if None not in [subset, self.feature_model]:
            tried = []
            for subset_phony in BaseAnnotations.FEATURE_PHONY[subset]:
                features = "{subset}_{suffix}.{model}.npz".format(
                    subset=subset_phony,
                    suffix=dataset_info.feature_suffix,
                    model=self.feature_model)
                feature_path = join(self.root, "features", features)
                if isfile(feature_path): break
                tried.append(feature_path)
            else:
                raise ValueError(
                    "Could not find any features in \"{}\" for {} subset. Tried features: {}".format(
                        join(self.root, "features"), subset, tried))

            new_opts["features"] = feature_path

        new_opts.update(kwargs)

        logging.debug(new_opts)
        return new_opts
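
    # Illustrative note (not from the original file): with the format string above,
    # pre-computed features are looked up as
    #   <root>/features/<subset phony>_<feature_suffix>.<feature_model>.npz
    # e.g. "features/val_features.resnet.npz" -- the suffix and model name here are
    # made-up placeholders; the real values come from the info file and constructor.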
    @property
    @abc.abstractmethod
    def meta(self):
        pass

    def _path(self, file):
        return join(self.root, file)

    def _open(self, file):
        return open(self._path(file))

    def read_content(self, file, attr):
        content = None
        fpath = self._path(file)
        if isfile(fpath):
            with self._open(file) as f:
                content = [line.strip() for line in f if line.strip()]
        else:
            warnings.warn("File \"{}\" was not found!".format(fpath))

        setattr(self, attr, content)

    def _load_uuids(self):
        assert self._images is not None, "Images were not loaded!"
        uuid_fnames = [i.split() for i in self._images]
        self.uuids, self.images = map(np.array, zip(*uuid_fnames))
        self.uuid_to_idx = {uuid: i for i, uuid in enumerate(self.uuids)}

    def _load_parts(self):
        assert self._part_locs is not None, "Part locations were not loaded!"
        # this part is quite slow... TODO: some runtime improvements?
        uuid_to_parts = defaultdict(list)
        for content in [i.split() for i in self._part_locs]:
            uuid_to_parts[content[0]].append([float(i) for i in content[1:]])

        self.part_locs = np.stack([uuid_to_parts[uuid] for uuid in self.uuids]).astype(int)

        if hasattr(self, "_part_names") and self._part_names is not None:
            self._load_part_names()

    def _load_part_names(self):
        self.part_names = OrderedDict()
        self.part_name_list = []
        for line in self._part_names:
            part_idx, _, name = line.partition(" ")
            self.part_names[int(part_idx)] = name
            self.part_name_list.append(name)

    def _load_split(self):
        assert self._split is not None, "Train-test split was not loaded!"
        uuid_to_split = {uuid: int(split) for uuid, split in [i.split() for i in self._split]}
        self.train_split = np.array([uuid_to_split[uuid] for uuid in self.uuids], dtype=bool)
        self.test_split = np.logical_not(self.train_split)

    def image_path(self, image):
        return join(self.root, self.meta.images_folder, image)

    def image(self, uuid):
        fname = self.images[self.uuid_to_idx[uuid]]
        return self.image_path(fname)

    def label(self, uuid):
        return self.labels[self.uuid_to_idx[uuid]].copy()

    def parts(self, uuid):
        return self.part_locs[self.uuid_to_idx[uuid]].copy()

    def _uuids(self, split):
        return self.uuids[split]

    @property
    def train_uuids(self):
        return self._uuids(self.train_split)

    @property
    def test_uuids(self):
        return self._uuids(self.test_split)
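
# Example usage (an illustrative sketch, not part of the original module;
# "CUB_Annotations" stands for any concrete subclass that provides the abstract
# `meta` property and a class-level `name` attribute):
#
#     annots = CUB_Annotations("/path/to/info.yml", parts=None, feature_model=None)
#     train_data = annots.new_dataset("train")
#     test_data = annots.new_dataset("test")
#     first_uuid = annots.train_uuids[0]
#     print(annots.image(first_uuid), annots.label(first_uuid))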