display_from_info.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
#!/usr/bin/env python
# Command-line viewer: loads an annotation set through nabirds.CUB_Annotations
# and, for a range of images, shows each image with its parts and part crops.
# This file is meant to be executed as a script only; importing it raises.
if __name__ != '__main__': raise Exception("Do not import me!")
import sys
# Put the parent directory first on the import path so that the `nabirds`
# package below can be imported without being installed.
# NOTE(review): ".." is resolved against the current working directory, not
# against this file's location, so the script presumably has to be started
# from its own directory — confirm.
sys.path.insert(0, "..")
try:
    # Use PyYAML's LibYAML-based (C) loader/dumper when available ...
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    # ... and fall back to the pure-Python implementations otherwise.
    from yaml import Loader, Dumper
# NOTE(review): yaml, Loader, Dumper and Rectangle appear unused in this
# script; kept in case other tooling relies on them.
import yaml
import logging
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from argparse import ArgumentParser
from nabirds import CUB_Annotations
  16. def init_logger(args):
  17. fmt = "%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s"
  18. logging.basicConfig(
  19. format=fmt,
  20. level=getattr(logging, args.loglevel.upper(), logging.DEBUG),
  21. filename=args.logfile or None,
  22. filemode="w")
  23. def plot_crops(crops, title, scatter_mid=False, names=None):
  24. n_crops = crops.shape[0]
  25. rows = int(np.ceil(np.sqrt(n_crops)))
  26. cols = int(np.ceil(n_crops / rows))
  27. fig, axs = plt.subplots(rows, cols, figsize=(16,9))
  28. fig.suptitle(title, fontsize=16)
  29. for i, crop in enumerate(crops):
  30. ax = axs[np.unravel_index(i, axs.shape)]
  31. if names is not None:
  32. ax.set_title(names[i])
  33. ax.imshow(crop)
  34. ax.axis("off")
  35. if scatter_mid:
  36. middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
  37. ax.scatter(middle_w, middle_h, marker="x")
  38. def main(args):
  39. init_logger(args)
  40. annot = CUB_Annotations(
  41. args.info, args.parts, args.feature_model)
  42. logging.info("Loaded data from \"{}\"".format(annot.root))
  43. uuids = getattr(annot, "{}_uuids".format(args.subset))
  44. data = annot.new_dataset(
  45. args.subset,
  46. crop_to_bb=args.crop_to_bb,
  47. crop_uniform=args.crop_uniform,
  48. parts_in_bb=args.parts_in_bb,
  49. rnd_select=args.rnd,
  50. seed=args.seed
  51. )
  52. logging.info("Loaded {} {} images".format(len(data), args.subset))
  53. start = max(args.start, 0)
  54. n_images = min(args.n_images, len(data) - start)
  55. for i in range(start, max(start, start + n_images)):
  56. im, parts, label = data[i]
  57. fig1, axs = plt.subplots(2, 1, figsize=(16,9))
  58. axs[0].axis("off")
  59. axs[0].set_title("Visible Parts")
  60. axs[0].imshow(im)
  61. if not args.crop_to_bb:
  62. data.plot_bounding_box(i, axs[0])
  63. parts.plot(im=im, ax=axs[0], ratio=data.ratio)
  64. axs[1].axis("off")
  65. axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
  66. axs[1].imshow(parts.reveal(im, ratio=data.ratio))
  67. if data.uniform_parts:
  68. crop_names = None
  69. else:
  70. crop_names = list(data._annot.part_names.values())
  71. part_crops = parts.visible_crops(im, ratio=data.ratio)
  72. if args.rnd:
  73. parts.invert_selection()
  74. action_crops = parts.visible_crops(im, ratio=data.ratio)
  75. plot_crops(part_crops, "Selected parts", names=crop_names)
  76. if args.rnd:
  77. plot_crops(action_crops, "Actions", names=crop_names)
  78. plt.show()
  79. plt.close()
  80. parser = ArgumentParser()
  81. parser.add_argument("info")
  82. parser.add_argument("--parts", "-p",
  83. choices=["GT", "GT2", "NAC", "UNI", "L1_pred", "L1_full"]
  84. )
  85. parser.add_argument("--feature_model", "-fm",
  86. choices=["inception", "inception_tf", "resnet"]
  87. )
  88. parser.add_argument("--subset", "-sub",
  89. help="Possible subsets: train, test",
  90. choices=["train", "test"],
  91. default="train", type=str)
  92. parser.add_argument("--start", "-s",
  93. help="Image id to start with",
  94. type=int, default=0)
  95. parser.add_argument("--n_images", "-n",
  96. help="Number of images to display",
  97. type=int, default=10)
  98. parser.add_argument("--rnd",
  99. help="select random subset of present parts",
  100. action="store_true")
  101. parser.add_argument("--crop_to_bb",
  102. help="Crop image to the bounding box",
  103. action="store_true")
  104. parser.add_argument("--crop_uniform",
  105. help="Try to extend the bounding box to same height and width",
  106. action="store_true")
  107. parser.add_argument("--parts_in_bb",
  108. help="Only display parts, that are inside the bounding box",
  109. action="store_true")
  110. parser.add_argument(
  111. '--logfile', type=str, default='',
  112. help='File for logging output')
  113. parser.add_argument(
  114. '--loglevel', type=str, default='INFO',
  115. help='logging level. see logging module for more information')
  116. parser.add_argument(
  117. '--seed', type=int, default=12311123,
  118. help='random seed')
  119. main(parser.parse_args())