瀏覽代碼

unified scripts

Dimitri Korsch 6 年之前
父節點
當前提交
867ed3c669
共有 2 個文件被更改,包括 33 次插入19 次删除
  1. 13 13
      scripts/display.py
  2. 20 6
      scripts/display_from_info.py

+ 13 - 13
scripts/display.py

@@ -58,7 +58,7 @@ def main(args):
 	logging.info("Loading \"{}\" annnotations from \"{}\"".format(args.dataset, args.data))
 	annot = annotation_cls.get(args.dataset.lower())(args.data)
 
-	subset = args.subset.lower()
+	subset = args.subset
 
 	uuids = getattr(annot, "{}_uuids".format(subset))
 	features = args.features[0 if subset == "train" else 1]
@@ -80,14 +80,13 @@ def main(args):
 		seed=args.seed
 
 	)
-	n_images = len(data)
-	logging.info("Found {} images in the {} subset".format(n_images, subset))
+
+	logging.info("Loaded {} {} images".format(len(data), subset))
 
 	start = max(args.start, 0)
 	n_images = min(args.n_images, len(data) - start)
 
 	for i in range(start, max(start, start + n_images)):
-
 		im, parts, label = data[i]
 
 		fig1, axs = plt.subplots(2, 1, figsize=(16,9))
@@ -95,7 +94,6 @@ def main(args):
 		axs[0].axis("off")
 		axs[0].set_title("Visible Parts")
 		axs[0].imshow(im)
-
 		if not args.crop_to_bb:
 			data.plot_bounding_box(i, axs[0])
 		parts.plot(im=im, ax=axs[0], ratio=data.ratio)
@@ -133,10 +131,6 @@ parser.add_argument("--dataset",
 	choices=["cub", "nab"],
 	default="cub", type=str)
 
-parser.add_argument("--features",
-	help="pre-extracted train and test features",
-	default=[None, None],
-	nargs=2, type=str)
 
 parser.add_argument("--subset",
 	help="Possible subsets: train, test",
@@ -153,6 +147,11 @@ parser.add_argument("--n_images", "-n",
 	type=int, default=10)
 
 
+parser.add_argument("--features",
+	help="pre-extracted train and test features",
+	default=[None, None],
+	nargs=2, type=str)
+
 parser.add_argument("--ratio",
 	help="Part extraction ratio",
 	type=float, default=.2)
@@ -161,14 +160,15 @@ parser.add_argument("--rescale_size",
 	help="rescales the part positions from this size to original image size",
 	type=int, default=-1)
 
-parser.add_argument("--rnd",
-	help="select random subset of present parts",
-	action="store_true")
-
 parser.add_argument("--uniform_parts", "-u",
 	help="Do not use GT parts, but sample parts uniformly from the image",
 	action="store_true")
 
+
+parser.add_argument("--rnd",
+	help="select random subset of present parts",
+	action="store_true")
+
 parser.add_argument("--crop_to_bb",
 	help="Crop image to the bounding box",
 	action="store_true")

+ 20 - 6
scripts/display_from_info.py

@@ -16,7 +16,7 @@ from matplotlib.patches import Rectangle
 
 from argparse import ArgumentParser
 
-from nabirds import CUB_Annotations
+from nabirds import NAB_Annotations, CUB_Annotations
 
 def init_logger(args):
 	fmt = "%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s"
@@ -49,15 +49,22 @@ def plot_crops(crops, title, scatter_mid=False, names=None):
 def main(args):
 	init_logger(args)
 
-	annot = CUB_Annotations(
+	annotation_cls = dict(
+		nab=NAB_Annotations,
+		cub=CUB_Annotations)
+
+	logging.info("Loading \"{}\" annnotations from \"{}\"".format(args.dataset, args.data))
+	annot = annotation_cls.get(args.dataset.lower())(
 		args.info, args.parts, args.feature_model)
 
 	logging.info("Loaded data from \"{}\"".format(annot.root))
 
-	uuids = getattr(annot, "{}_uuids".format(args.subset))
+	subset = args.subset
+
+	uuids = getattr(annot, "{}_uuids".format(subset))
 
 	data = annot.new_dataset(
-		args.subset,
+		subset,
 
 		crop_to_bb=args.crop_to_bb,
 		crop_uniform=args.crop_uniform,
@@ -68,7 +75,7 @@ def main(args):
 		seed=args.seed
 	)
 
-	logging.info("Loaded {} {} images".format(len(data), args.subset))
+	logging.info("Loaded {} {} images".format(len(data), subset))
 
 	start = max(args.start, 0)
 	n_images = min(args.n_images, len(data) - start)
@@ -106,11 +113,18 @@ def main(args):
 		plt.show()
 		plt.close()
 
-
 parser = ArgumentParser()
 
 parser.add_argument("info")
 
+
+parser.add_argument("--dataset",
+	help="Possible datasets: NAB, CUB",
+	choices=["cub", "nab"],
+	default="cub", type=str
+)
+
+
 parser.add_argument("--parts", "-p",
 	choices=["GT", "GT2", "NAC", "UNI", "L1_pred", "L1_full"]
 )