Răsfoiți Sursa

added part name loading. some fixes in the part coordinates handling

Dimitri Korsch 6 ani în urmă
părinte
comite
42ef628cc3
4 a modificat fișierele cu 37 adăugiri și 11 ștergeri
  1. 16 1
      nabirds/annotations.py
  2. 6 1
      nabirds/dataset/__init__.py
  3. 11 4
      nabirds/dataset/utils.py
  4. 4 5
      nabirds/display.py

+ 16 - 1
nabirds/annotations.py

@@ -1,6 +1,6 @@
 from os.path import join, isfile
 import numpy as np
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 import abc
 import warnings
 
@@ -61,6 +61,17 @@ class BaseAnnotations(abc.ABC):
 
 		self.part_locs = np.stack([uuid_to_parts[uuid] for uuid in self.uuids]).astype(int)
 
+		if hasattr(self, "_part_names") and self._part_names is not None:
+			self._load_part_names()
+
+	def _load_part_names(self):
+		self.part_names = OrderedDict()
+		self.part_name_list = []
+		for line in self._part_names:
+			part_idx, _, name = line.partition(" ")
+			self.part_names[int(part_idx)] = name
+			self.part_name_list.append(name)
+
 	def _load_split(self):
 		assert self._split is not None, "Train-test split was not loaded!"
 		uuid_to_split = {uuid: int(split) for uuid, split in [i.split() for i in self._split]}
@@ -102,6 +113,7 @@ class NAB_Annotations(BaseAnnotations):
 			hierarchy_file="hierarchy.txt",
 			split_file="train_test_split.txt",
 			parts_file=join("parts", "part_locs.txt"),
+			part_names_file=join("parts", "parts.txt"),
 		)
 
 		info.structure = [
@@ -110,6 +122,7 @@ class NAB_Annotations(BaseAnnotations):
 			[info.hierarchy_file, "hierarchy"],
 			[info.split_file, "_split"],
 			[info.parts_file, "_part_locs"],
+			[info.part_names_file, "_part_names"],
 		]
 		return info
 
@@ -122,6 +135,7 @@ class CUB_Annotations(BaseAnnotations):
 			labels_file="labels.txt",
 			split_file="tr_ID.txt",
 			parts_file=join("parts", "part_locs.txt"),
+			part_names_file=join("parts", "parts.txt"),
 		)
 
 		info.structure = [
@@ -129,6 +143,7 @@ class CUB_Annotations(BaseAnnotations):
 			[info.labels_file, "labels"],
 			[info.split_file, "_split"],
 			[info.parts_file, "_part_locs"],
+			[info.part_names_file, "_part_names"],
 		]
 		return info
 

+ 6 - 1
nabirds/dataset/__init__.py

@@ -19,6 +19,11 @@ class Dataset(object):
 	def get_example(self, i, mode="RGB"):
 		methods = ["image", "parts", "label"]
 		im_path, parts, label = [self._get(m, i) for m in methods]
-		return imread(im_path, pilmode=mode), parts, label
+		im = imread(im_path, pilmode=mode)
+		h,w,c = im.shape
+		# fit to the dimensions of the image
+		parts[:, 1] = np.minimum(parts[:, 1], w - 1)
+		parts[:, 2] = np.minimum(parts[:, 2], h - 1)
+		return im, parts, label
 
 	__getitem__  = get_example

+ 11 - 4
nabirds/dataset/utils.py

@@ -65,16 +65,23 @@ def reveal_parts(im, xy, ratio=DEFAULT_RATIO):
 
 	return res
 
+def select(crops, mask):
+	selected = np.zeros_like(crops)
+	selected[mask] = crops[mask]
+	return selected
+
+def selection_mask(idxs, n):
+	return np.bincount(idxs, minlength=n).astype(bool)
+
 def random_select(idxs, xy, part_crops, *args, **kw):
 	rnd_idxs = random_idxs(np.arange(len(idxs)), *args, **kw)
 	idxs = idxs[rnd_idxs]
 	xy = xy[:, rnd_idxs]
 
-	selected_mask = np.bincount(idxs, minlength=len(part_crops)).astype(bool)
-	p_crops = part_crops.copy()
-	p_crops[np.logical_not(selected_mask)] = 0
+	mask = selection_mask(idxs, len(part_crops))
+	selected_crops = select(part_crops, mask)
 
-	return idxs, xy, p_crops
+	return idxs, xy, selected_crops
 
 def random_idxs(idxs, rnd=None, n_parts=None):
 

+ 4 - 5
nabirds/display.py

@@ -39,9 +39,9 @@ def main(args):
 
 	for i in range(n_images):
 		if i + 1 <= args.start: continue
-
 		im, parts, label = data[i]
 
+
 		if args.uniform_parts:
 			parts = uniform_parts(im, ratio=args.ratio)
 
@@ -68,14 +68,14 @@ def main(args):
 		ax.scatter(*xy, marker="x", c=idxs)
 		ax.axis("off")
 
-		fig2 = plt.figure(figsize=(16,9))
+		fig = plt.figure(figsize=(16,9))
 
 		n_crops = part_crops.shape[0]
 		rows = int(np.ceil(np.sqrt(n_crops)))
 		cols = int(np.ceil(n_crops / rows))
 
 		for j, crop in enumerate(part_crops, 1):
-			ax = fig2.add_subplot(rows, cols, j)
+			ax = fig.add_subplot(rows, cols, j)
 			ax.imshow(crop)
 			ax.axis("off")
 
@@ -83,8 +83,7 @@ def main(args):
 			ax.scatter(middle_w, middle_h, marker="x")
 
 		plt.show()
-		plt.close(fig1)
-		plt.close(fig2)
+		plt.close()
 
 		if i+1 >= args.start + args.n_images: break