Browse Source

updated dataset display script

Dimitri Korsch 4 years ago
parent
commit
f1347dbb5f
3 changed files with 44 additions and 36 deletions
  1. 43 10
      scripts/display.py
  2. 1 1
      scripts/display.sh
  3. 0 25
      scripts/utils/__init__.py

+ 43 - 10
scripts/display.py

@@ -10,7 +10,31 @@ import matplotlib.pyplot as plt
 from argparse import ArgumentParser
 
 from cvdatasets import AnnotationType
-from utils import parser, plot_crops
+from utils import parser
+
+def plot_crops(crops, spec, spec_offset, scatter_mid=False, names=None):
+
+	n_crops = len(crops)
+	if n_crops == 0: return
+	rows = int(np.ceil(np.sqrt(n_crops)))
+	cols = int(np.ceil(n_crops / rows))
+	dx, dy = spec_offset
+
+	for i, crop in enumerate(crops):
+		x, y = np.unravel_index(i, (rows, cols))
+
+		ax = plt.subplot(spec[x+dx, y+dy])
+
+		if names is not None:
+			ax.set_title(names[i])
+
+		ax.imshow(crop)
+		ax.axis("off")
+
+		if scatter_mid:
+			middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
+			ax.scatter(middle_w, middle_h, marker="x")
+
 
 def main(args):
 	# assert args.dataset in AnnotationType, \
@@ -54,25 +78,34 @@ def main(args):
 		logging.info(f"Showing only images {start} - {end}")
 		idxs = range(start, end)
 
+
 	for i in idxs:
 		im, parts, label = data[i]
+		n_parts = len(parts)
+
+		assert n_parts != 0
+		rows = int(np.ceil(np.sqrt(n_parts)))
+		cols = int(np.ceil(n_parts / rows))
+		factor = 3 if args.rnd else 2
+		grid_spec = plt.GridSpec(rows, factor * cols)
+
+		fig = plt.figure()
 
-		fig1, axs = plt.subplots(1, 1, figsize=(16,9))
-		axs = [axs]
+		im_ax = plt.subplot(grid_spec[:, :cols])
 
-		axs[0].axis("off")
-		axs[0].set_title("Visible Parts")
-		axs[0].imshow(im)
+		im_ax.axis("off")
+		im_ax.set_title("Visible Parts")
+		im_ax.imshow(im)
 
 		if not args.crop_to_bb and not args.no_bboxes:
-			data.plot_bounding_box(i, axs[0])
+			data.plot_bounding_box(i, im_ax)
 
 		# axs[1].axis("off")
 		# axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
 		# axs[1].imshow(parts.reveal(im, ratio=data.ratio))
 
 		if not args.no_parts:
-			parts.plot(im=im, ax=axs[0], ratio=data.ratio, linewidth=3)
+			parts.plot(im=im, ax=im_ax, ratio=data.ratio, linewidth=3)
 			if data.uniform_parts:
 				crop_names = None
 			else:
@@ -83,10 +116,10 @@ def main(args):
 				parts.invert_selection()
 				action_crops = parts.visible_crops(im, ratio=data.ratio)
 
-			plot_crops(part_crops, f"{args.parts}: Selected parts", names=crop_names)
+			plot_crops(part_crops, grid_spec, spec_offset=(0, cols), names=crop_names)
 
 			if args.rnd:
-				plot_crops(action_crops, f"{args.parts}: Actions", names=crop_names)
+				plot_crops(action_crops, grid_spec, spec_offset=(0, 2*cols), names=crop_names)
 
 		plt.show()
 		plt.close()

+ 1 - 1
scripts/display.sh

@@ -1,6 +1,6 @@
 #!/usr/bin/env bash
 source ${HOME}/.miniconda3/etc/profile.d/conda.sh
-conda activate ${ENV:-chainer6}
+conda activate ${CONDA_ENV:-chainer7cu11}
 
 PYTHON="python"
 

+ 0 - 25
scripts/utils/__init__.py

@@ -1,25 +0,0 @@
-import numpy as np
-import matplotlib.pyplot as plt
-
-
-def plot_crops(crops, title, scatter_mid=False, names=None):
-
-	n_crops = len(crops)
-	if n_crops == 0: return
-	rows = int(np.ceil(np.sqrt(n_crops)))
-	cols = int(np.ceil(n_crops / rows))
-
-	fig, axs = plt.subplots(rows, cols, figsize=(16,9))
-	fig.suptitle(title, fontsize=16)
-	[axs[np.unravel_index(i, axs.shape)].axis("off") for i in range(cols*rows)]
-
-	for i, crop in enumerate(crops):
-		ax = axs[np.unravel_index(i, axs.shape)]
-		if names is not None:
-			ax.set_title(names[i])
-		ax.imshow(crop)
-		if scatter_mid:
-			middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
-			ax.scatter(middle_w, middle_h, marker="x")
-
-