display.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. #!/usr/bin/env python
  2. if __name__ != '__main__': raise Exception("Do not import me!")
  3. import sys
  4. sys.path.insert(0, "..")
  5. import logging
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. from argparse import ArgumentParser
  9. from nabirds import NAB_Annotations, CUB_Annotations
  10. def init_logger(args):
  11. fmt = "%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s"
  12. logging.basicConfig(
  13. format=fmt,
  14. level=getattr(logging, args.loglevel.upper(), logging.DEBUG),
  15. filename=args.logfile or None,
  16. filemode="w")
  17. def plot_crops(crops, title, scatter_mid=False, names=None):
  18. n_crops = crops.shape[0]
  19. rows = int(np.ceil(np.sqrt(n_crops)))
  20. cols = int(np.ceil(n_crops / rows))
  21. fig, axs = plt.subplots(rows, cols, figsize=(16,9))
  22. fig.suptitle(title, fontsize=16)
  23. for i, crop in enumerate(crops):
  24. ax = axs[np.unravel_index(i, axs.shape)]
  25. if names is not None:
  26. ax.set_title(names[i])
  27. ax.imshow(crop)
  28. ax.axis("off")
  29. if scatter_mid:
  30. middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
  31. ax.scatter(middle_w, middle_h, marker="x")
  32. def main(args):
  33. init_logger(args)
  34. annotation_cls = dict(
  35. nab=NAB_Annotations,
  36. cub=CUB_Annotations)
  37. logging.info("Loading \"{}\" annnotations from \"{}\"".format(args.dataset, args.data))
  38. annot = annotation_cls.get(args.dataset.lower())(args.data)
  39. subset = args.subset
  40. uuids = getattr(annot, "{}_uuids".format(subset))
  41. features = args.features[0 if subset == "train" else 1]
  42. data = annot.new_dataset(
  43. subset,
  44. part_rescale_size=args.rescale_size,
  45. # features=features,
  46. uniform_parts=args.uniform_parts,
  47. crop_to_bb=args.crop_to_bb,
  48. crop_uniform=args.crop_uniform,
  49. parts_in_bb=args.parts_in_bb,
  50. rnd_select=args.rnd,
  51. ratio=args.ratio,
  52. seed=args.seed
  53. )
  54. logging.info("Loaded {} {} images".format(len(data), subset))
  55. start = max(args.start, 0)
  56. n_images = min(args.n_images, len(data) - start)
  57. for i in range(start, max(start, start + n_images)):
  58. im, parts, label = data[i]
  59. fig1, axs = plt.subplots(2, 1, figsize=(16,9))
  60. axs[0].axis("off")
  61. axs[0].set_title("Visible Parts")
  62. axs[0].imshow(im)
  63. if not args.crop_to_bb:
  64. data.plot_bounding_box(i, axs[0])
  65. parts.plot(im=im, ax=axs[0], ratio=data.ratio)
  66. axs[1].axis("off")
  67. axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
  68. axs[1].imshow(parts.reveal(im, ratio=data.ratio))
  69. if data.uniform_parts:
  70. crop_names = None
  71. else:
  72. crop_names = list(data._annot.part_names.values())
  73. part_crops = parts.visible_crops(im, ratio=data.ratio)
  74. if args.rnd:
  75. parts.invert_selection()
  76. action_crops = parts.visible_crops(im, ratio=data.ratio)
  77. plot_crops(part_crops, "Selected parts", names=crop_names)
  78. if args.rnd:
  79. plot_crops(action_crops, "Actions", names=crop_names)
  80. plt.show()
  81. plt.close()
  82. parser = ArgumentParser()
  83. parser.add_argument("data",
  84. help="Folder containing the dataset with images and annotation files",
  85. type=str)
  86. parser.add_argument("--dataset",
  87. help="Possible datasets: NAB, CUB",
  88. choices=["cub", "nab"],
  89. default="cub", type=str
  90. )
  91. parser.add_argument("--subset", "-sub",
  92. help="Possible subsets: train, test",
  93. choices=["train", "test"],
  94. default="train", type=str)
  95. parser.add_argument("--start", "-s",
  96. help="Image id to start with",
  97. type=int, default=0)
  98. parser.add_argument("--n_images", "-n",
  99. help="Number of images to display",
  100. type=int, default=10)
  101. parser.add_argument("--features",
  102. help="pre-extracted train and test features",
  103. default=[None, None],
  104. nargs=2, type=str)
  105. parser.add_argument("--ratio",
  106. help="Part extraction ratio",
  107. type=float, default=.2)
  108. parser.add_argument("--rescale_size",
  109. help="rescales the part positions from this size to original image size",
  110. type=int, default=-1)
  111. parser.add_argument("--uniform_parts", "-u",
  112. help="Do not use GT parts, but sample parts uniformly from the image",
  113. action="store_true")
  114. parser.add_argument("--rnd",
  115. help="select random subset of present parts",
  116. action="store_true")
  117. parser.add_argument("--crop_to_bb",
  118. help="Crop image to the bounding box",
  119. action="store_true")
  120. parser.add_argument("--crop_uniform",
  121. help="Try to extend the bounding box to same height and width",
  122. action="store_true")
  123. parser.add_argument("--parts_in_bb",
  124. help="Only display parts, that are inside the bounding box",
  125. action="store_true")
  126. parser.add_argument(
  127. '--logfile', type=str, default='',
  128. help='File for logging output')
  129. parser.add_argument(
  130. '--loglevel', type=str, default='INFO',
  131. help='logging level. see logging module for more information')
  132. parser.add_argument(
  133. '--seed', type=int, default=12311123,
  134. help='random seed')
  135. main(parser.parse_args())