|
@@ -83,10 +83,15 @@ def main(args):
|
|
|
im, parts, label = data[i]
|
|
|
n_parts = len(parts)
|
|
|
|
|
|
- assert n_parts != 0
|
|
|
- rows = int(np.ceil(np.sqrt(n_parts)))
|
|
|
- cols = int(np.ceil(n_parts / rows))
|
|
|
- factor = 3 if args.rnd else 2
|
|
|
+ if args.no_parts:
|
|
|
+ cols, rows = 1, 1
|
|
|
+ factor = 1
|
|
|
+ else:
|
|
|
+ assert n_parts != 0
|
|
|
+ rows = int(np.ceil(np.sqrt(n_parts)))
|
|
|
+ cols = int(np.ceil(n_parts / rows))
|
|
|
+ factor = 3 if args.rnd else 2
|
|
|
+
|
|
|
grid_spec = plt.GridSpec(rows, factor * cols)
|
|
|
|
|
|
fig = plt.figure()
|