base.py

from os.path import join, isfile, isdir
import numpy as np
from collections import defaultdict, OrderedDict
import abc
import warnings

try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper
import yaml

import simplejson as json

from nabirds.utils import attr_dict
from nabirds.dataset import Dataset
class BaseAnnotations(abc.ABC):

    def __init__(self, root_or_infofile, parts=None, feature_model=None):
        super(BaseAnnotations, self).__init__()
        self.part_type = parts
        self.feature_model = feature_model

        # either a dataset root folder or an info YAML file can be passed
        if isdir(root_or_infofile):
            self.info = None
            self.root = root_or_infofile
        elif isfile(root_or_infofile):
            self.root = self.root_from_infofile(root_or_infofile, parts, feature_model)
        else:
            raise ValueError("Root folder or info file does not exist: \"{}\"".format(
                root_or_infofile
            ))

        # read the annotation files listed in the subclass' meta definition
        for fname, attr in self.meta.structure:
            self.read_content(fname, attr)

        self.labels = np.array([int(l) for l in self.labels], dtype=np.int32)

        self._load_uuids()
        self._load_parts()
        self._load_split()
    @property
    def data_root(self):
        if self.info is None: return None
        return join(self.info.BASE_DIR, self.info.DATA_DIR)

    @property
    def dataset_info(self):
        if self.info is None: return None
        if self.part_type is None:
            return self.info.DATASETS[self.__class__.name]
        else:
            return self.info.PARTS[self.part_type]
    def root_from_infofile(self, info_file, parts=None, feature_model=None):
        with open(info_file) as f:
            self.info = attr_dict(yaml.load(f, Loader=Loader))

        dataset_info = self.dataset_info
        # print(json.dumps(dataset_info, indent=2))

        annot_dir = join(self.data_root, dataset_info.folder, dataset_info.annotations)
        assert isdir(annot_dir), "Annotation folder does not exist! \"{}\"".format(annot_dir)
        return annot_dir
    def new_dataset(self, subset=None, dataset_cls=Dataset, **kwargs):
        if subset is not None:
            uuids = getattr(self, "{}_uuids".format(subset))
        else:
            uuids = self.uuids

        kwargs = self.check_parts_and_features(subset, **kwargs)
        return dataset_cls(uuids=uuids, annotations=self, **kwargs)
    def check_parts_and_features(self, subset, **kwargs):
        dataset_info = self.dataset_info
        if dataset_info is None:
            return kwargs

        # TODO: pass all scales
        new_opts = {
            "ratio": dataset_info.scales[0],
            "uniform_parts": dataset_info.is_uniform
        }

        if self.part_type is not None:
            new_opts["part_rescale_size"] = dataset_info.rescale_size

        if None not in [subset, self.feature_model]:
            features = "{subset}_{suffix}.{model}.npz".format(
                subset=subset,
                suffix=dataset_info.feature_suffix,
                model=self.feature_model)
            feature_path = join(self.root, "features", features)
            assert isfile(feature_path), \
                "Features do not exist: \"{}\"".format(feature_path)
            new_opts["features"] = feature_path

        new_opts.update(kwargs)
        print(new_opts)
        return new_opts
    @property
    @abc.abstractmethod
    def meta(self):
        pass

    def _path(self, file):
        return join(self.root, file)

    def _open(self, file):
        return open(self._path(file))
    def read_content(self, file, attr):
        content = None
        fpath = self._path(file)
        if isfile(fpath):
            with self._open(file) as f:
                content = [line.strip() for line in f if line.strip()]
        else:
            warnings.warn("File \"{}\" was not found!".format(fpath))

        setattr(self, attr, content)
    def _load_uuids(self):
        assert self._images is not None, "Images were not loaded!"
        uuid_fnames = [i.split() for i in self._images]
        self.uuids, self.images = map(np.array, zip(*uuid_fnames))
        self.uuid_to_idx = {uuid: i for i, uuid in enumerate(self.uuids)}

    def _load_parts(self):
        assert self._part_locs is not None, "Part locations were not loaded!"
        # this part is quite slow... TODO: some runtime improvements?
        uuid_to_parts = defaultdict(list)
        for content in [i.split() for i in self._part_locs]:
            uuid_to_parts[content[0]].append([float(i) for i in content[1:]])

        self.part_locs = np.stack([uuid_to_parts[uuid] for uuid in self.uuids]).astype(int)

        if hasattr(self, "_part_names") and self._part_names is not None:
            self._load_part_names()
    def _load_part_names(self):
        self.part_names = OrderedDict()
        self.part_name_list = []
        for line in self._part_names:
            part_idx, _, name = line.partition(" ")
            self.part_names[int(part_idx)] = name
            self.part_name_list.append(name)
    def _load_split(self):
        assert self._split is not None, "Train-test split was not loaded!"
        uuid_to_split = {uuid: int(split) for uuid, split in [i.split() for i in self._split]}
        self.train_split = np.array([uuid_to_split[uuid] for uuid in self.uuids], dtype=bool)
        self.test_split = np.logical_not(self.train_split)
    def image_path(self, image):
        return join(self.root, self.meta.images_folder, image)

    def image(self, uuid):
        fname = self.images[self.uuid_to_idx[uuid]]
        return self.image_path(fname)

    def label(self, uuid):
        return self.labels[self.uuid_to_idx[uuid]].copy()

    def parts(self, uuid):
        return self.part_locs[self.uuid_to_idx[uuid]].copy()

    def _uuids(self, split):
        return self.uuids[split]

    @property
    def train_uuids(self):
        return self._uuids(self.train_split)

    @property
    def test_uuids(self):
        return self._uuids(self.test_split)
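
# --- Usage sketch (illustrative only) -----------------------------------------
# A minimal sketch of how a concrete subclass could be wired up.  The class
# name "ExampleAnnotations", the file names in ``structure``, and the
# "/path/to/annotations" root below are hypothetical placeholders, not part of
# this module; a real subclass has to match the layout of its dataset.

if __name__ == "__main__":

    class ExampleAnnotations(BaseAnnotations):
        # name under which the dataset would be looked up in info.DATASETS
        # when an info YAML file is passed instead of a root folder
        name = "EXAMPLE"

        @property
        def meta(self):
            # attr_dict wraps a plain dict for attribute-style access,
            # the same way the loaded info file is wrapped above
            return attr_dict(dict(
                images_folder="images",
                # (file name, attribute) pairs consumed by read_content()
                structure=[
                    ("images.txt", "_images"),
                    ("labels.txt", "labels"),
                    ("parts/part_locs.txt", "_part_locs"),
                    ("parts/parts.txt", "_part_names"),
                    ("tr_ID.txt", "_split"),
                ],
            ))

    # passing a root folder skips the info-file logic (self.info stays None);
    # the folder is expected to contain the files listed in ``structure``
    annot = ExampleAnnotations("/path/to/annotations")
    print("{} training / {} test images".format(
        len(annot.train_uuids), len(annot.test_uuids)))

    # build dataset objects for the two splits
    train_data = annot.new_dataset("train")
    test_data = annot.new_dataset("test")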