display.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. #!/usr/bin/env python
  """Display images of a part-annotated dataset together with their parts.

  Possible calls:
  ./display.sh /home/korsch1/korsch/datasets/birds/cub200_11 --dataset cub -s600 -n5 --features /home/korsch1/korsch/datasets/birds/features/{train,val}_16parts_gt.npz --ratio 0.31
  > displays GT parts of CUB200
  ./display.sh /home/korsch1/korsch/datasets/birds/NAC/2017-bilinear/ --dataset cub -s600 -n5 --features /home/korsch1/korsch/datasets/birds/features/{train,val}_16parts_gt.npz --ratio 0.31 --rescale_size 227
  > displays NAC parts of CUB200
  """
  # The docstring has to be the first statement to become the module's
  # __doc__; placed after the guard it was only a discarded expression.

  # This file is a script: it parses the command line and runs at module
  # level, so importing it is refused outright.
  if __name__ != '__main__':
      raise Exception("Do not import me!")
  10. from argparse import ArgumentParser
  11. import logging
  12. import numpy as np
  13. from annotations import NAB_Annotations, CUB_Annotations
  14. from dataset import Dataset
  15. from dataset.utils import reveal_parts, uniform_parts, \
  16. random_select, \
  17. visible_part_locs, visible_crops
  18. import matplotlib.pyplot as plt
  def init_logger(args):
      """Set up the root logger from the parsed command-line options.

      Reads two attributes of ``args``:

      * ``loglevel`` -- name of a ``logging`` level, matched
        case-insensitively; a name the ``logging`` module does not define
        falls back to ``logging.DEBUG``.
      * ``logfile`` -- file to log to, opened with mode ``"w"``; an empty
        value is passed on as ``None``.
      """
      level_name = args.loglevel.upper()
      log_level = getattr(logging, level_name, logging.DEBUG)

      record_layout = "%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s"

      logging.basicConfig(
          filename=args.logfile or None,
          filemode="w",
          level=log_level,
          format=record_layout,
      )
  def plot_crops(crops, title, scatter_mid=False):
      """Show every crop in its own subplot of a new 16x9 figure.

      The crops are laid out on a near-square grid: the number of rows is
      ``ceil(sqrt(n))`` and the number of columns is the smallest count that
      still fits all ``n`` crops.

      Args:
          crops: array whose first axis enumerates the crops (the count is
              read from ``crops.shape[0]``); each item is handed to ``imshow``.
          title: text used as the figure's super-title.
          scatter_mid: if true, mark the centre of every crop with an "x".

      An empty ``crops`` array yields a titled figure without subplots
      instead of failing.
      """
      fig = plt.figure(figsize=(16, 9))
      fig.suptitle(title, fontsize=16)

      n_crops = crops.shape[0]
      if n_crops == 0:
          # With no crops ``rows`` would be 0 and the column computation
          # below would divide by zero.
          return

      rows = int(np.ceil(np.sqrt(n_crops)))
      # float() forces a true division regardless of the interpreter's
      # integer-division semantics; a truncated quotient would give too few
      # grid cells for the crops.
      cols = int(np.ceil(n_crops / float(rows)))

      for j, crop in enumerate(crops, 1):
          ax = fig.add_subplot(rows, cols, j)
          ax.imshow(crop)
          ax.axis("off")

          if scatter_mid:
              middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
              # scatter expects (x, y), i.e. (column, row)
              ax.scatter(middle_w, middle_h, marker="x")
  def main(args):
      """Load a dataset subset and display the selected images with their parts.

      ``args`` is the namespace produced by this script's argument parser.
      The images with index ``args.start`` up to (but excluding)
      ``args.start + args.n_images`` are processed one after another; a
      non-positive ``args.n_images`` displays nothing.

      For every image these figures are created and then shown:

      * the image with the visible part locations scattered on top and,
        below it, the output of ``reveal_parts`` for those locations,
      * the crops returned by ``visible_crops`` for the parts ("Selected
        parts"),
      * with ``args.rnd`` additionally the crops obtained after inverting
        the flag in the last column of ``parts`` ("Actions").

      Raises:
          ValueError: if ``args.dataset`` does not name a known annotation
              class.
      """
      init_logger(args)

      annotation_cls = dict(
          nab=NAB_Annotations,
          cub=CUB_Annotations)

      dataset_name = args.dataset.lower()
      if dataset_name not in annotation_cls:
          # Fail with a readable message instead of calling ``None``.
          raise ValueError("Unknown dataset \"{}\", expected one of: {}".format(
              args.dataset, ", ".join(sorted(annotation_cls))))

      logging.info("Loading \"{}\" annotations from \"{}\"".format(args.dataset, args.data))
      annot = annotation_cls[dataset_name](args.data)

      subset = args.subset.lower()
      uuids = getattr(annot, "{}_uuids".format(subset))
      # --features takes two files: the first is used for the train subset,
      # the second for any other subset.
      features = args.features[0 if subset == "train" else 1]

      data = Dataset(
          uuids=uuids, annotations=annot,
          part_rescale_size=args.rescale_size,
          features=features,
          uniform_parts=args.uniform_parts,
          crop_to_bb=args.crop_to_bb,
          crop_uniform=args.crop_uniform,
          parts_in_bb=args.parts_in_bb,
          rnd_select=args.rnd,
          ratio=args.ratio,
          seed=args.seed
      )
      n_images = len(data)
      logging.info("Found {} images in the {} subset".format(n_images, subset))

      # Iterate exactly over the requested window. The lower bound is clamped
      # so that a negative --start cannot turn into negative indexing.
      first = max(args.start, 0)
      stop = min(n_images, args.start + args.n_images)

      for i in range(first, stop):
          im, parts, label = data[i]

          idxs, xy = visible_part_locs(parts)
          part_crops = visible_crops(im, parts, ratio=args.ratio)

          if args.rnd:
              # Invert the flag stored in the last column (presumably the
              # visibility/selection flag -- confirm against Dataset) so that
              # the crops of the parts that were NOT selected are extracted.
              # Work on a copy: data[i] may hand out an array it still owns.
              selected = parts[:, -1].astype(bool)
              inverted = parts.copy()
              inverted[:, -1] = np.logical_not(selected)
              action_crops = visible_crops(im, inverted, ratio=args.ratio)

          logging.debug(label)
          logging.debug(idxs)
          logging.debug(xy)

          fig1 = plt.figure(figsize=(16, 9))

          ax = fig1.add_subplot(2, 1, 1)
          ax.imshow(im)
          ax.set_title("Visible Parts")
          ax.scatter(*xy, marker="x", c=idxs)
          ax.axis("off")

          ax = fig1.add_subplot(2, 1, 2)
          ax.set_title("{}selected parts".format("randomly " if args.rnd else ""))
          ax.imshow(reveal_parts(im, xy, ratio=args.ratio))
          ax.axis("off")

          plot_crops(part_crops, "Selected parts")
          if args.rnd:
              plot_crops(action_crops, "Actions")

          plt.show()
          # Several figures are created per image; a bare plt.close() would
          # release only the current one.
          plt.close("all")
  # Command-line interface of the script; the parsed options are handed to
  # main() at the very end.
  parser = ArgumentParser()

  # --- data source ---------------------------------------------------------
  parser.add_argument(
      "data", type=str,
      help="Folder containing the dataset with images and annotation files")

  parser.add_argument(
      "--dataset", type=str, default="nab",
      choices=["cub", "nab"],
      help="Possible datasets: NAB, CUB")

  parser.add_argument(
      "--features", type=str, nargs=2,
      default=[None, None],
      help="pre-extracted train and test features")

  parser.add_argument(
      "--subset", type=str, default="train",
      choices=["train", "test"],
      help="Possible subsets: train, test")

  # --- which images to show ------------------------------------------------
  parser.add_argument(
      "--start", "-s", type=int, default=0,
      help="Image id to start with")

  parser.add_argument(
      "--n_images", "-n", type=int, default=10,
      help="Number of images to display")

  # --- part extraction -----------------------------------------------------
  parser.add_argument(
      "--ratio", type=float, default=0.2,
      help="Part extraction ratio")

  parser.add_argument(
      "--rescale_size", type=int, default=-1,
      help="rescales the part positions from this size to original image size")

  parser.add_argument(
      "--rnd", action="store_true",
      help="select random subset of present parts")

  parser.add_argument(
      "--uniform_parts", "-u", action="store_true",
      help="Do not use GT parts, but sample parts uniformly from the image")

  # --- bounding box handling -----------------------------------------------
  parser.add_argument(
      "--crop_to_bb", action="store_true",
      help="Crop image to the bounding box")

  parser.add_argument(
      "--crop_uniform", action="store_true",
      help="Try to extend the bounding box to same height and width")

  parser.add_argument(
      "--parts_in_bb", action="store_true",
      help="Only display parts, that are inside the bounding box")

  # --- logging and reproducibility -----------------------------------------
  parser.add_argument(
      "--logfile", type=str, default="",
      help="File for logging output")

  parser.add_argument(
      "--loglevel", type=str, default="INFO",
      help="logging level. see logging module for more information")

  parser.add_argument(
      "--seed", type=int, default=12311123,
      help="random seed")

  main(parser.parse_args())