|
@@ -1,6 +1,6 @@
|
|
from os.path import join, isfile
|
|
from os.path import join, isfile
|
|
import numpy as np
|
|
import numpy as np
|
|
-from collections import defaultdict
|
|
|
|
|
|
+from collections import defaultdict, OrderedDict
|
|
import abc
|
|
import abc
|
|
import warnings
|
|
import warnings
|
|
|
|
|
|
@@ -61,6 +61,17 @@ class BaseAnnotations(abc.ABC):
|
|
|
|
|
|
self.part_locs = np.stack([uuid_to_parts[uuid] for uuid in self.uuids]).astype(int)
|
|
self.part_locs = np.stack([uuid_to_parts[uuid] for uuid in self.uuids]).astype(int)
|
|
|
|
|
|
|
|
+ if hasattr(self, "_part_names") and self._part_names is not None:
|
|
|
|
+ self._load_part_names()
|
|
|
|
+
|
|
|
|
+ def _load_part_names(self):
|
|
|
|
+ self.part_names = OrderedDict()
|
|
|
|
+ self.part_name_list = []
|
|
|
|
+ for line in self._part_names:
|
|
|
|
+ part_idx, _, name = line.partition(" ")
|
|
|
|
+ self.part_names[int(part_idx)] = name
|
|
|
|
+ self.part_name_list.append(name)
|
|
|
|
+
|
|
def _load_split(self):
|
|
def _load_split(self):
|
|
assert self._split is not None, "Train-test split was not loaded!"
|
|
assert self._split is not None, "Train-test split was not loaded!"
|
|
uuid_to_split = {uuid: int(split) for uuid, split in [i.split() for i in self._split]}
|
|
uuid_to_split = {uuid: int(split) for uuid, split in [i.split() for i in self._split]}
|
|
@@ -102,6 +113,7 @@ class NAB_Annotations(BaseAnnotations):
|
|
hierarchy_file="hierarchy.txt",
|
|
hierarchy_file="hierarchy.txt",
|
|
split_file="train_test_split.txt",
|
|
split_file="train_test_split.txt",
|
|
parts_file=join("parts", "part_locs.txt"),
|
|
parts_file=join("parts", "part_locs.txt"),
|
|
|
|
+ part_names_file=join("parts", "parts.txt"),
|
|
)
|
|
)
|
|
|
|
|
|
info.structure = [
|
|
info.structure = [
|
|
@@ -110,6 +122,7 @@ class NAB_Annotations(BaseAnnotations):
|
|
[info.hierarchy_file, "hierarchy"],
|
|
[info.hierarchy_file, "hierarchy"],
|
|
[info.split_file, "_split"],
|
|
[info.split_file, "_split"],
|
|
[info.parts_file, "_part_locs"],
|
|
[info.parts_file, "_part_locs"],
|
|
|
|
+ [info.part_names_file, "_part_names"],
|
|
]
|
|
]
|
|
return info
|
|
return info
|
|
|
|
|
|
@@ -122,6 +135,7 @@ class CUB_Annotations(BaseAnnotations):
|
|
labels_file="labels.txt",
|
|
labels_file="labels.txt",
|
|
split_file="tr_ID.txt",
|
|
split_file="tr_ID.txt",
|
|
parts_file=join("parts", "part_locs.txt"),
|
|
parts_file=join("parts", "part_locs.txt"),
|
|
|
|
+ part_names_file=join("parts", "parts.txt"),
|
|
)
|
|
)
|
|
|
|
|
|
info.structure = [
|
|
info.structure = [
|
|
@@ -129,6 +143,7 @@ class CUB_Annotations(BaseAnnotations):
|
|
[info.labels_file, "labels"],
|
|
[info.labels_file, "labels"],
|
|
[info.split_file, "_split"],
|
|
[info.split_file, "_split"],
|
|
[info.parts_file, "_part_locs"],
|
|
[info.parts_file, "_part_locs"],
|
|
|
|
+ [info.part_names_file, "_part_names"],
|
|
]
|
|
]
|
|
return info
|
|
return info
|
|
|
|
|