Преглед изворни кода

updated old display script

Dimitri Korsch пре 6 година
родитељ
комит
2b734c06bd
2 измењених фајлова са 43 додато и 46 уклоњено
  1. 3 0
      nabirds/annotations/base.py
  2. 40 46
      scripts/display.py

+ 3 - 0
nabirds/annotations/base.py

@@ -76,6 +76,9 @@ class BaseAnnotations(abc.ABC):
 
 	def check_parts_and_features(self, subset, **kwargs):
 		dataset_info = self.dataset_info
+		if dataset_info is None:
+			return kwargs
+
 		# TODO: pass all scales
 		new_opts = {
 			"ratio": dataset_info.scales[0],

+ 40 - 46
scripts/display.py

@@ -18,12 +18,7 @@ from argparse import ArgumentParser
 import logging
 import numpy as np
 
-from annotations import NAB_Annotations, CUB_Annotations
-from dataset import Dataset
-from dataset.utils import reveal_parts, uniform_parts, \
-	random_select, \
-	visible_part_locs, visible_crops
-
+from nabirds.annotations import NAB_Annotations, CUB_Annotations
 
 import matplotlib.pyplot as plt
 
@@ -37,25 +32,22 @@ def init_logger(args):
 
 def plot_crops(crops, title, scatter_mid=False, names=None):
 
-	fig = plt.figure(figsize=(16,9))
-	fig.suptitle(title, fontsize=16)
-
 	n_crops = crops.shape[0]
 	rows = int(np.ceil(np.sqrt(n_crops)))
 	cols = int(np.ceil(n_crops / rows))
 
-	for j, crop in enumerate(crops, 1):
-		ax = fig.add_subplot(rows, cols, j)
+	fig, axs = plt.subplots(rows, cols, figsize=(16,9))
+	fig.suptitle(title, fontsize=16)
+	for i, crop in enumerate(crops):
+		ax = axs[np.unravel_index(i, axs.shape)]
 		if names is not None:
-			ax.set_title(names[j-1])
+			ax.set_title(names[i])
 		ax.imshow(crop)
 		ax.axis("off")
 		if scatter_mid:
 			middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
 			ax.scatter(middle_w, middle_h, marker="x")
 
-
-
 def main(args):
 	init_logger(args)
 
@@ -71,10 +63,10 @@ def main(args):
 	uuids = getattr(annot, "{}_uuids".format(subset))
 	features = args.features[0 if subset == "train" else 1]
 
-	data = Dataset(
-		uuids=uuids, annotations=annot,
+	data = annot.new_dataset(
+		subset,
 		part_rescale_size=args.rescale_size,
-		features=features,
+		# features=features,
 
 		uniform_parts=args.uniform_parts,
 
@@ -91,45 +83,45 @@ def main(args):
 	n_images = len(data)
 	logging.info("Found {} images in the {} subset".format(n_images, subset))
 
-	for i in range(n_images):
-		if i + 1 <= args.start: continue
+	start = max(args.start, 0)
+	n_images = min(args.n_images, len(data) - start)
+
+	for i in range(start, max(start, start + n_images)):
+
 		im, parts, label = data[i]
 
-		idxs, xy = visible_part_locs(parts)
-		part_crops = visible_crops(im, parts, ratio=args.ratio)
+		fig1, axs = plt.subplots(2, 1, figsize=(16,9))
+
+		axs[0].axis("off")
+		axs[0].set_title("Visible Parts")
+		axs[0].imshow(im)
+
+		if not args.crop_to_bb:
+			data.plot_bounding_box(i, axs[0])
+		parts.plot(im=im, ax=axs[0], ratio=data.ratio)
+
+		axs[1].axis("off")
+		axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
+		axs[1].imshow(parts.reveal(im, ratio=data.ratio))
+
+		if data.uniform_parts:
+			crop_names = None
+		else:
+			crop_names = list(data._annot.part_names.values())
+
+		part_crops = parts.visible_crops(im, ratio=args.ratio)
 		if args.rnd:
-			selected = parts[:, -1].astype(bool)
-			parts[selected, -1] = 0
-			parts[np.logical_not(selected), -1] = 1
-			action_crops = visible_crops(im, parts, ratio=args.ratio)
-
-		logging.debug(label)
-		logging.debug(idxs)
-		logging.debug(xy)
-
-		fig1 = plt.figure(figsize=(16,9))
-		ax = fig1.add_subplot(2,1,1)
-		ax.imshow(im)
-		ax.set_title("Visible Parts")
-		ax.scatter(*xy, marker="x", c=idxs)
-		ax.axis("off")
+			parts.invert_selection()
+			action_crops = parts.visible_crops(im, ratio=args.ratio)
 
-		ax = fig1.add_subplot(2,1,2)
-		ax.set_title("{}selected parts".format("randomly " if args.rnd else ""))
-		ax.imshow(reveal_parts(im, xy, ratio=args.ratio))
-		# ax.scatter(*xy, marker="x", c=idxs)
-		ax.axis("off")
-		crop_names = list(data._annot.part_names.values())
 		plot_crops(part_crops, "Selected parts", names=crop_names)
 
 		if args.rnd:
-			plot_crops(action_crops, "Actions")
+			plot_crops(action_crops, "Actions", names=crop_names)
 
 		plt.show()
 		plt.close()
 
-		if i+1 >= args.start + args.n_images: break
-
 parser = ArgumentParser()
 
 parser.add_argument("data",
@@ -139,7 +131,7 @@ parser.add_argument("data",
 parser.add_argument("--dataset",
 	help="Possible datasets: NAB, CUB",
 	choices=["cub", "nab"],
-	default="nab", type=str)
+	default="cub", type=str)
 
 parser.add_argument("--features",
 	help="pre-extracted train and test features",
@@ -151,6 +143,7 @@ parser.add_argument("--subset",
 	choices=["train", "test"],
 	default="train", type=str)
 
+
 parser.add_argument("--start", "-s",
 	help="Image id to start with",
 	type=int, default=0)
@@ -159,6 +152,7 @@ parser.add_argument("--n_images", "-n",
 	help="Number of images to display",
 	type=int, default=10)
 
+
 parser.add_argument("--ratio",
 	help="Part extraction ratio",
 	type=float, default=.2)