display_from_info.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. #!/usr/bin/env python
  2. if __name__ != '__main__': raise Exception("Do not import me!")
  3. import sys
  4. sys.path.insert(0, "..")
  5. import logging
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. from argparse import ArgumentParser
  9. from nabirds import NAB_Annotations, CUB_Annotations
  10. def init_logger(args):
  11. fmt = "%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s"
  12. logging.basicConfig(
  13. format=fmt,
  14. level=getattr(logging, args.loglevel.upper(), logging.DEBUG),
  15. filename=args.logfile or None,
  16. filemode="w")
  17. def plot_crops(crops, title, scatter_mid=False, names=None):
  18. n_crops = crops.shape[0]
  19. rows = int(np.ceil(np.sqrt(n_crops)))
  20. cols = int(np.ceil(n_crops / rows))
  21. fig, axs = plt.subplots(rows, cols, figsize=(16,9))
  22. fig.suptitle(title, fontsize=16)
  23. for i, crop in enumerate(crops):
  24. ax = axs[np.unravel_index(i, axs.shape)]
  25. if names is not None:
  26. ax.set_title(names[i])
  27. ax.imshow(crop)
  28. ax.axis("off")
  29. if scatter_mid:
  30. middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
  31. ax.scatter(middle_w, middle_h, marker="x")
  32. def main(args):
  33. init_logger(args)
  34. annotation_cls = dict(
  35. nab=NAB_Annotations,
  36. cub=CUB_Annotations)
  37. logging.info("Loading \"{}\" annnotations from \"{}\"".format(args.dataset, args.data))
  38. annot = annotation_cls.get(args.dataset.lower())(
  39. args.data, args.parts, args.feature_model)
  40. logging.info("Loaded data from \"{}\"".format(annot.root))
  41. subset = args.subset
  42. uuids = getattr(annot, "{}_uuids".format(subset))
  43. data = annot.new_dataset(
  44. subset,
  45. crop_to_bb=args.crop_to_bb,
  46. crop_uniform=args.crop_uniform,
  47. parts_in_bb=args.parts_in_bb,
  48. rnd_select=args.rnd,
  49. seed=args.seed
  50. )
  51. logging.info("Loaded {} {} images".format(len(data), subset))
  52. start = max(args.start, 0)
  53. n_images = min(args.n_images, len(data) - start)
  54. for i in range(start, max(start, start + n_images)):
  55. im, parts, label = data[i]
  56. fig1, axs = plt.subplots(2, 1, figsize=(16,9))
  57. axs[0].axis("off")
  58. axs[0].set_title("Visible Parts")
  59. axs[0].imshow(im)
  60. if not args.crop_to_bb:
  61. data.plot_bounding_box(i, axs[0])
  62. parts.plot(im=im, ax=axs[0], ratio=data.ratio)
  63. axs[1].axis("off")
  64. axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
  65. axs[1].imshow(parts.reveal(im, ratio=data.ratio))
  66. if data.uniform_parts:
  67. crop_names = None
  68. else:
  69. crop_names = list(data._annot.part_names.values())
  70. part_crops = parts.visible_crops(im, ratio=data.ratio)
  71. if args.rnd:
  72. parts.invert_selection()
  73. action_crops = parts.visible_crops(im, ratio=data.ratio)
  74. plot_crops(part_crops, "Selected parts", names=crop_names)
  75. if args.rnd:
  76. plot_crops(action_crops, "Actions", names=crop_names)
  77. plt.show()
  78. plt.close()
  79. parser = ArgumentParser()
  80. parser.add_argument("data",
  81. help="Folder containing the dataset with images and annotation files or dataset info file",
  82. type=str)
  83. parser.add_argument("--dataset",
  84. help="Possible datasets: NAB, CUB",
  85. choices=["cub", "nab"],
  86. default="cub", type=str
  87. )
  88. parser.add_argument("--parts", "-p",
  89. choices=["GT", "GT2", "NAC", "UNI", "L1_pred", "L1_full"]
  90. )
  91. parser.add_argument("--feature_model", "-fm",
  92. choices=["inception", "inception_tf", "resnet"]
  93. )
  94. parser.add_argument("--subset", "-sub",
  95. help="Possible subsets: train, test",
  96. choices=["train", "test"],
  97. default="train", type=str)
  98. parser.add_argument("--start", "-s",
  99. help="Image id to start with",
  100. type=int, default=0)
  101. parser.add_argument("--n_images", "-n",
  102. help="Number of images to display",
  103. type=int, default=10)
  104. parser.add_argument("--rnd",
  105. help="select random subset of present parts",
  106. action="store_true")
  107. parser.add_argument("--crop_to_bb",
  108. help="Crop image to the bounding box",
  109. action="store_true")
  110. parser.add_argument("--crop_uniform",
  111. help="Try to extend the bounding box to same height and width",
  112. action="store_true")
  113. parser.add_argument("--parts_in_bb",
  114. help="Only display parts, that are inside the bounding box",
  115. action="store_true")
  116. parser.add_argument(
  117. '--logfile', type=str, default='',
  118. help='File for logging output')
  119. parser.add_argument(
  120. '--loglevel', type=str, default='INFO',
  121. help='logging level. see logging module for more information')
  122. parser.add_argument(
  123. '--seed', type=int, default=12311123,
  124. help='random seed')
  125. main(parser.parse_args())