소스 검색

added function to reveal certain parts

Dimitri Korsch 7 년 전
부모
커밋
5661e9f584
3개의 변경된 파일54개의 추가작업 그리고 21개의 파일을 삭제
  1. 19 11
      example_cub.py
  2. 16 7
      example_nab.py
  3. 19 3
      nabirds/dataset.py

+ 19 - 11
example_cub.py

@@ -2,7 +2,7 @@
 if __name__ != '__main__': raise Exception("Do not import me!")
 
 from nabirds import Dataset, CUB_Annotations
-from nabirds.dataset import visible_part_locs, visible_crops
+from nabirds.dataset import visible_part_locs, visible_crops, reveal_parts
 import matplotlib.pyplot as plt
 
 annot = CUB_Annotations(root="/home/korsch1/korsch/datasets/birds/cub200_11")
@@ -10,19 +10,27 @@ annot = CUB_Annotations(root="/home/korsch1/korsch/datasets/birds/cub200_11")
 print(annot.labels.shape)
 data = Dataset(annot.train_uuids, annot)
 
-for i, (im, parts, label) in enumerate(data, 1):
-	if i <= 15: continue
+start = 2000
+n_images = 5
 
-	idxs, (xs, ys) = visible_part_locs(parts)
+for i in range(len(data)):
+	if i+1 <= start: continue
 
-	print(label)
-	print(idxs)
+	im, parts, label = data[i]
+
+	idxs, xy = visible_part_locs(parts)
+
+	print(label, idxs)
 
 	fig1 = plt.figure(figsize=(16,9))
-	ax = fig1.add_subplot(111)
+	ax = fig1.add_subplot(2,1,1)
 
 	ax.imshow(im)
-	ax.scatter(xs, ys, marker="x", c=idxs)
+	ax.scatter(*xy, marker="x", c=idxs)
+
+	ax = fig1.add_subplot(2,1,2)
+	ax.imshow(reveal_parts(im, xy))
+	ax.scatter(*xy, marker="x", c=idxs)
 
 	fig2 = plt.figure(figsize=(16,9))
 	n_parts = parts.shape[0]
@@ -34,9 +42,9 @@ for i, (im, parts, label) in enumerate(data, 1):
 		middle = crop.shape[0] / 2
 		ax.scatter(middle, middle, marker="x")
 
+
 	plt.show()
-	plt.close(fig1)
-	plt.close(fig2)
+	plt.close()
 
-	if i >= 20: break
+	if i+1 >= start + n_images: break
 

+ 16 - 7
example_nab.py

@@ -2,7 +2,7 @@
 if __name__ != '__main__': raise Exception("Do not import me!")
 
 from nabirds import Dataset, NAB_Annotations
-from nabirds.dataset import visible_part_locs, visible_crops
+from nabirds.dataset import visible_part_locs, visible_crops, reveal_parts
 import matplotlib.pyplot as plt
 
 annot = NAB_Annotations("/home/korsch1/korsch/datasets/birds/nabirds")
@@ -10,19 +10,28 @@ annot = NAB_Annotations("/home/korsch1/korsch/datasets/birds/nabirds")
 print(annot.labels.shape)
 data = Dataset(annot.train_uuids, annot)
 
-for i, (im, parts, label) in enumerate(data, 1):
-	if i <= 15: continue
+start = 2000
+n_images = 5
 
-	idxs, (xs, ys) = visible_part_locs(parts)
+for i in range(len(data)):
+	if i+1 <= start: continue
+
+	im, parts, label = data[i]
+
+	idxs, xy = visible_part_locs(parts)
 
 	print(label)
 	print(idxs)
 
 	fig1 = plt.figure(figsize=(16,9))
-	ax = fig1.add_subplot(111)
+	ax = fig1.add_subplot(2,1,1)
 
 	ax.imshow(im)
-	ax.scatter(xs, ys, marker="x", c=idxs)
+	ax.scatter(*xy, marker="x", c=idxs)
+
+	ax = fig1.add_subplot(2,1,2)
+	ax.imshow(reveal_parts(im, xy))
+	ax.scatter(*xy, marker="x", c=idxs)
 
 	fig2 = plt.figure(figsize=(16,9))
 	n_parts = parts.shape[0]
@@ -38,5 +47,5 @@ for i, (im, parts, label) in enumerate(data, 1):
 	plt.close(fig1)
 	plt.close(fig2)
 
-	if i >= 20: break
+	if i+1 >= start + n_images: break
 

+ 19 - 3
nabirds/dataset.py

@@ -26,6 +26,8 @@ class Dataset(object):
 
 # some convention functions
 
+DEFAULT_RATIO = np.sqrt(49 / 400)
+
 def __expand_parts(p):
 	return p[:, 0], p[:, 1:3], p[:, 3].astype(bool)
 
@@ -33,21 +35,35 @@ def visible_part_locs(p):
 	idxs, locs, vis = __expand_parts(p)
 	return idxs[vis], locs[vis].T
 
-def visible_crops(im, p, ratio=np.sqrt(49 / 400), padding_mode="edge"):
+def visible_crops(im, p, ratio=DEFAULT_RATIO, padding_mode="edge"):
 	assert im.ndim == 3, "Only RGB images are currently supported!"
 	idxs, locs, vis = __expand_parts(p)
 	h, w, c = im.shape
-	crop_h = crop_w = int(np.sqrt(h*w) * ratio)
+	crop_h = crop_w = int(np.sqrt(h * w) * ratio)
 	crops = np.zeros((len(idxs), crop_h, crop_w, c), dtype=im.dtype)
 
 	padding = np.array([crop_h, crop_w]) // 2
 
 	padded_im = np.pad(im, [padding, padding, [0,0]], mode=padding_mode)
 
-	for i, loc, is_vis in zip(*__expand_parts(p)):
+	for i, loc, is_vis in zip(idxs, locs, vis):
 		if not is_vis: continue
 		x0, y0 = loc - crop_h // 2 + padding
 		crops[i] = padded_im[y0:y0+crop_h, x0:x0+crop_w]
 
 	return crops
 
+def reveal_parts(im, xy, ratio=DEFAULT_RATIO):
+	h, w, c = im.shape
+	crop_h = crop_w = int(np.sqrt(h * w) * ratio)
+
+	x0y0 = xy - crop_h // 2
+
+	res = np.zeros_like(im)
+	for x0, y0 in x0y0.T:
+		x1, y1 = x0 + crop_w, y0 + crop_w
+		x0, y0 = max(x0, 0), max(y0, 0)
+		res[y0:y0+crop_h, x0:x0+crop_w] = im[y0:y0+crop_h, x0:x0+crop_w]
+
+	return res
+