display.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. #!/usr/bin/env python
  2. if __name__ != '__main__': raise Exception("Do not import me!")
  3. import sys
  4. sys.path.insert(0, "..")
  5. import logging
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. from argparse import ArgumentParser
  9. from cvdatasets import AnnotationType
  10. from utils import parser
  11. def plot_crops(crops, spec, spec_offset, scatter_mid=False, names=None):
  12. n_crops = len(crops)
  13. if n_crops == 0: return
  14. rows = int(np.ceil(np.sqrt(n_crops)))
  15. cols = int(np.ceil(n_crops / rows))
  16. dx, dy = spec_offset
  17. for i, crop in enumerate(crops):
  18. x, y = np.unravel_index(i, (rows, cols))
  19. ax = plt.subplot(spec[x+dx, y+dy])
  20. if names is not None:
  21. ax.set_title(names[i])
  22. ax.imshow(crop)
  23. ax.axis("off")
  24. if scatter_mid:
  25. middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
  26. ax.scatter(middle_w, middle_h, marker="x")
  27. def main(args):
  28. # assert args.dataset in AnnotationType, \
  29. # f"AnnotationType is not known: \"{args.dataset}\""
  30. annot = AnnotationType.new_annotation(args)
  31. kwargs = {}
  32. if annot.info is None:
  33. kwargs = dict(
  34. part_rescale_size=args.rescale_size,
  35. uniform_parts=args.uniform_parts,
  36. ratio=args.ratio,
  37. )
  38. data = annot.new_dataset(
  39. args.subset,
  40. center_cropped=not args.no_center_crop,
  41. crop_to_bb=args.crop_to_bb,
  42. crop_uniform=args.crop_uniform,
  43. parts_in_bb=args.parts_in_bb,
  44. rnd_select=args.rnd,
  45. seed=args.seed,
  46. **kwargs
  47. )
  48. logging.info(f"Loaded {len(data)} {args.subset} images")
  49. if args.only_class >= 0:
  50. mask = data.labels == args.only_class
  51. logging.info(f"Showing only {mask.sum()} images from class {args.only_class}")
  52. idxs = np.where(mask)[0]
  53. else:
  54. start = max(args.start, 0)
  55. n_images = min(args.n_images, len(data) - start)
  56. end = max(start, start + n_images)
  57. logging.info(f"Showing only images {start} - {end}")
  58. idxs = range(start, end)
  59. for i in idxs:
  60. im, parts, label = data[i]
  61. n_parts = len(parts)
  62. if args.no_parts:
  63. cols, rows = 1, 1
  64. factor = 1
  65. else:
  66. assert n_parts != 0
  67. rows = int(np.ceil(np.sqrt(n_parts)))
  68. cols = int(np.ceil(n_parts / rows))
  69. factor = 3 if args.rnd else 2
  70. grid_spec = plt.GridSpec(rows, factor * cols)
  71. fig = plt.figure()
  72. im_ax = plt.subplot(grid_spec[:, :cols])
  73. im_ax.axis("off")
  74. im_ax.set_title("Visible Parts")
  75. im_ax.imshow(im)
  76. if not args.crop_to_bb and not args.no_bboxes:
  77. data.plot_bounding_box(i, im_ax)
  78. # axs[1].axis("off")
  79. # axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
  80. # axs[1].imshow(parts.reveal(im, ratio=data.ratio))
  81. if not args.no_parts:
  82. parts.plot(im=im, ax=im_ax, ratio=data.ratio, linewidth=3)
  83. if data.uniform_parts:
  84. crop_names = None
  85. else:
  86. crop_names = list(data._annot.part_names.values())
  87. part_crops = parts.visible_crops(im, ratio=data.ratio)
  88. if args.rnd:
  89. parts.invert_selection()
  90. action_crops = parts.visible_crops(im, ratio=data.ratio)
  91. plot_crops(part_crops, grid_spec, spec_offset=(0, cols), names=crop_names)
  92. if args.rnd:
  93. plot_crops(action_crops, grid_spec, spec_offset=(0, 2*cols), names=crop_names)
  94. plt.show()
  95. plt.close()
  96. main(parser.parse_args())