Ver código fonte

some changes in part selection methods

Dimitri Korsch 6 anos atrás
pai
commit
43a25788bc
3 arquivos alterados com 38 adições e 19 exclusões
  1. 1 1
      nabirds/__init__.py
  2. 9 5
      nabirds/dataset/image.py
  3. 28 13
      nabirds/display.py

+ 1 - 1
nabirds/__init__.py

@@ -1,4 +1,4 @@
 from .dataset import Dataset
 from .annotations import NAB_Annotations, CUB_Annotations
 
-__version__ = "0.1.6"
+__version__ = "0.1.7"

+ 9 - 5
nabirds/dataset/image.py

@@ -58,17 +58,21 @@ class ImageWrapper(object):
 		return result
 
 	@should_have_parts
-	def select_random_parts(self, rnd, n_parts):
-
-		idxs, xy = self.visible_part_locs()
-		rnd_idxs = utils.random_idxs(idxs, rnd=rnd, n_parts=n_parts)
+	def select_parts(self, idxs):
 		result = self.copy()
 
 		result.parts[:, -1] = 0
-		result.parts[rnd_idxs, -1] = 1
+		result.parts[idxs, -1] = 1
 
 		return result
 
+	@should_have_parts
+	def select_random_parts(self, rnd, n_parts):
+
+		idxs, xy = self.visible_part_locs()
+		rnd_idxs = utils.random_idxs(idxs, rnd=rnd, n_parts=n_parts)
+		return self.select_parts(rnd_idxs)
+
 	@should_have_parts
 	def visible_crops(self, ratio):
 		return utils.visible_crops(self.im, self.parts, ratio=ratio)

+ 28 - 13
nabirds/display.py

@@ -22,6 +22,25 @@ def init_logger(args):
 		filename=args.logfile or None,
 		filemode="w")
 
+def plot_crops(crops, title, scatter_mid=False):
+
+	fig = plt.figure(figsize=(16,9))
+	fig.suptitle(title, fontsize=16)
+
+	n_crops = crops.shape[0]
+	rows = int(np.ceil(np.sqrt(n_crops)))
+	cols = int(np.ceil(n_crops / rows))
+
+	for j, crop in enumerate(crops, 1):
+		ax = fig.add_subplot(rows, cols, j)
+		ax.imshow(crop)
+		ax.axis("off")
+		if scatter_mid:
+			middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
+			ax.scatter(middle_w, middle_h, marker="x")
+
+
+
 def main(args):
 	init_logger(args)
 
@@ -57,6 +76,11 @@ def main(args):
 
 		idxs, xy = visible_part_locs(parts)
 		part_crops = visible_crops(im, parts, ratio=args.ratio)
+		if args.rnd:
+			selected = parts[:, -1].astype(bool)
+			parts[selected, -1] = 0
+			parts[np.logical_not(selected), -1] = 1
+			action_crops = visible_crops(im, parts, ratio=args.ratio)
 
 		logging.debug(label)
 		logging.debug(idxs)
@@ -72,22 +96,13 @@ def main(args):
 		ax = fig1.add_subplot(2,1,2)
 		ax.set_title("{}selected parts".format("randomly " if args.rnd else ""))
 		ax.imshow(reveal_parts(im, xy, ratio=args.ratio))
-		ax.scatter(*xy, marker="x", c=idxs)
+		# ax.scatter(*xy, marker="x", c=idxs)
 		ax.axis("off")
 
-		fig = plt.figure(figsize=(16,9))
-
-		n_crops = part_crops.shape[0]
-		rows = int(np.ceil(np.sqrt(n_crops)))
-		cols = int(np.ceil(n_crops / rows))
+		plot_crops(part_crops, "Selected parts")
 
-		for j, crop in enumerate(part_crops, 1):
-			ax = fig.add_subplot(rows, cols, j)
-			ax.imshow(crop)
-			ax.axis("off")
-
-			middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
-			ax.scatter(middle_w, middle_h, marker="x")
+		if args.rnd:
+			plot_crops(action_crops, "Actions")
 
 		plt.show()
 		plt.close()