base.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import numpy as np
  2. import abc
  3. import warnings
  4. import logging
  5. from os.path import join, isfile, isdir
  6. from collections import defaultdict, OrderedDict
  7. try:
  8. from yaml import CLoader as Loader, CDumper as Dumper
  9. except ImportError:
  10. from yaml import Loader, Dumper
  11. import yaml
  12. from nabirds.utils import attr_dict
  13. from nabirds.dataset import Dataset
  14. class BaseAnnotations(abc.ABC):
  15. FEATURE_PHONY = dict(train=["train"], test=["test", "val"])
  16. def __init__(self, root_or_infofile, parts=None, feature_model=None):
  17. super(BaseAnnotations, self).__init__()
  18. self.part_type = parts
  19. self.feature_model = feature_model
  20. if isdir(root_or_infofile):
  21. self.info = None
  22. self.root = root_or_infofile
  23. elif isfile(root_or_infofile):
  24. self.root = self.root_from_infofile(root_or_infofile, parts, feature_model)
  25. else:
  26. raise ValueError("Root folder or info file does not exist: \"{}\"".format(
  27. root_or_infofile
  28. ))
  29. for fname, attr in self.meta.structure:
  30. self.read_content(fname, attr)
  31. self.labels = np.array([int(l) for l in self.labels], dtype=np.int32)
  32. self._load_uuids()
  33. self._load_parts()
  34. self._load_split()
  35. @property
  36. def data_root(self):
  37. if self.info is None: return None
  38. return join(self.info.BASE_DIR, self.info.DATA_DIR)
  39. @property
  40. def dataset_info(self):
  41. if self.info is None: return None
  42. if self.part_type is None:
  43. return self.info.DATASETS[self.__class__.name]
  44. else:
  45. return self.info.PARTS[self.part_type]
  46. def root_from_infofile(self, info_file, parts=None, feature_model=None):
  47. with open(info_file) as f:
  48. self.info = attr_dict(yaml.load(f, Loader=Loader))
  49. dataset_info = self.dataset_info
  50. annot_dir = join(self.data_root, dataset_info.folder, dataset_info.annotations)
  51. assert isdir(annot_dir), "Annotation folder does exist! \"{}\"".format(annot_dir)
  52. return annot_dir
  53. def new_dataset(self, subset=None, dataset_cls=Dataset, **kwargs):
  54. if subset is not None:
  55. uuids = getattr(self, "{}_uuids".format(subset))
  56. else:
  57. uuids = self.uuids
  58. kwargs = self.check_parts_and_features(subset, **kwargs)
  59. return dataset_cls(uuids=uuids, annotations=self, **kwargs)
  60. def check_parts_and_features(self, subset, **kwargs):
  61. dataset_info = self.dataset_info
  62. if dataset_info is None:
  63. return kwargs
  64. logging.debug("Dataset info: {}".format(dataset_info))
  65. # TODO: pass all scales
  66. new_opts = {}
  67. if "scales" in dataset_info:
  68. new_opts["ratio"] = dataset_info.scales[0]
  69. if "is_uniform" in dataset_info:
  70. new_opts["uniform_parts"] = dataset_info.is_uniform
  71. if self.part_type is not None:
  72. new_opts["part_rescale_size"] = dataset_info.rescale_size
  73. if None not in [subset, self.feature_model]:
  74. tried = []
  75. for subset_phony in BaseAnnotations.FEATURE_PHONY[subset]:
  76. features = "{subset}_{suffix}.{model}.npz".format(
  77. subset=subset_phony,
  78. suffix=dataset_info.feature_suffix,
  79. model=self.feature_model)
  80. feature_path = join(self.root, "features", features)
  81. if isfile(feature_path): break
  82. tried.append(feature_path)
  83. else:
  84. raise ValueError(
  85. "Could not find any features in \"{}\" for {} subset. Tried features: {}".format(
  86. join(self.root, "features"), subset, tried))
  87. new_opts["features"] = feature_path
  88. new_opts.update(kwargs)
  89. logging.debug("Final kwargs: {}".format(new_opts))
  90. return new_opts
  91. @property
  92. @abc.abstractmethod
  93. def meta(self):
  94. pass
  95. def _path(self, file):
  96. return join(self.root, file)
  97. def _open(self, file):
  98. return open(self._path(file))
  99. def read_content(self, file, attr):
  100. content = None
  101. fpath = self._path(file)
  102. if isfile(fpath):
  103. with self._open(file) as f:
  104. content = [line.strip() for line in f if line.strip()]
  105. else:
  106. warnings.warn("File \"{}\" was not found!".format(fpath))
  107. setattr(self, attr, content)
  108. def _load_uuids(self):
  109. assert self._images is not None, "Images were not loaded!"
  110. uuid_fnames = [i.split() for i in self._images]
  111. self.uuids, self.images = map(np.array, zip(*uuid_fnames))
  112. self.uuid_to_idx = {uuid: i for i, uuid in enumerate(self.uuids)}
  113. def _load_parts(self):
  114. assert self._part_locs is not None, "Part locations were not loaded!"
  115. # this part is quite slow... TODO: some runtime improvements?
  116. uuid_to_parts = defaultdict(list)
  117. for content in [i.split() for i in self._part_locs]:
  118. uuid_to_parts[content[0]].append([float(i) for i in content[1:]])
  119. self.part_locs = np.stack([uuid_to_parts[uuid] for uuid in self.uuids]).astype(int)
  120. if hasattr(self, "_part_names") and self._part_names is not None:
  121. self._load_part_names()
  122. def _load_part_names(self):
  123. self.part_names = OrderedDict()
  124. self.part_name_list = []
  125. for line in self._part_names:
  126. part_idx, _, name = line.partition(" ")
  127. self.part_names[int(part_idx)] = name
  128. self.part_name_list.append(name)
  129. def _load_split(self):
  130. assert self._split is not None, "Train-test split was not loaded!"
  131. uuid_to_split = {uuid: int(split) for uuid, split in [i.split() for i in self._split]}
  132. self.train_split = np.array([uuid_to_split[uuid] for uuid in self.uuids], dtype=bool)
  133. self.test_split = np.logical_not(self.train_split)
  134. def image_path(self, image):
  135. return join(self.root, self.meta.images_folder, image)
  136. def image(self, uuid):
  137. fname = self.images[self.uuid_to_idx[uuid]]
  138. return self.image_path(fname)
  139. def label(self, uuid):
  140. return self.labels[self.uuid_to_idx[uuid]].copy()
  141. def parts(self, uuid):
  142. return self.part_locs[self.uuid_to_idx[uuid]].copy()
  143. def _uuids(self, split):
  144. return self.uuids[split]
  145. @property
  146. def train_uuids(self):
  147. return self._uuids(self.train_split)
  148. @property
  149. def test_uuids(self):
  150. return self._uuids(self.test_split)