display.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. #!/usr/bin/env python
  2. if __name__ != '__main__': raise Exception("Do not import me!")
  3. from argparse import ArgumentParser
  4. import logging
  5. import numpy as np
  6. from annotations import NAB_Annotations, CUB_Annotations
  7. from dataset import Dataset
  8. from dataset.utils import reveal_parts, uniform_parts, \
  9. random_select, \
  10. visible_part_locs, visible_crops
  11. import matplotlib.pyplot as plt
  12. def init_logger(args):
  13. fmt = "%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s"
  14. logging.basicConfig(
  15. format=fmt,
  16. level=getattr(logging, args.loglevel.upper(), logging.DEBUG),
  17. filename=args.logfile or None,
  18. filemode="w")
  19. def main(args):
  20. init_logger(args)
  21. annotation_cls = dict(
  22. nab=NAB_Annotations,
  23. cub=CUB_Annotations)
  24. logging.info("Loading \"{}\" annnotations from \"{}\"".format(args.dataset, args.data))
  25. annot = annotation_cls.get(args.dataset.lower())(args.data)
  26. uuids = getattr(annot, "{}_uuids".format(args.subset.lower()))
  27. data = Dataset(
  28. uuids=uuids, annotations=annot,
  29. uniform_parts=args.uniform_parts,
  30. crop_to_bb=args.crop_to_bb,
  31. crop_uniform=args.crop_uniform,
  32. rnd_select=args.rnd,
  33. ratio=args.ratio,
  34. seed=args.seed
  35. )
  36. n_images = len(data)
  37. logging.info("Found {} images in the {} subset".format(n_images, args.subset))
  38. for i in range(n_images):
  39. if i + 1 <= args.start: continue
  40. im, parts, label = data[i]
  41. idxs, xy = visible_part_locs(parts)
  42. part_crops = visible_crops(im, parts, ratio=args.ratio)
  43. logging.debug(label)
  44. logging.debug(idxs)
  45. logging.debug(xy)
  46. fig1 = plt.figure(figsize=(16,9))
  47. ax = fig1.add_subplot(2,1,1)
  48. ax.imshow(im)
  49. ax.set_title("Visible Parts")
  50. ax.scatter(*xy, marker="x", c=idxs)
  51. ax.axis("off")
  52. ax = fig1.add_subplot(2,1,2)
  53. ax.set_title("{}selected parts".format("randomly " if args.rnd else ""))
  54. ax.imshow(reveal_parts(im, xy, ratio=args.ratio))
  55. ax.scatter(*xy, marker="x", c=idxs)
  56. ax.axis("off")
  57. fig = plt.figure(figsize=(16,9))
  58. n_crops = part_crops.shape[0]
  59. rows = int(np.ceil(np.sqrt(n_crops)))
  60. cols = int(np.ceil(n_crops / rows))
  61. for j, crop in enumerate(part_crops, 1):
  62. ax = fig.add_subplot(rows, cols, j)
  63. ax.imshow(crop)
  64. ax.axis("off")
  65. middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
  66. ax.scatter(middle_w, middle_h, marker="x")
  67. plt.show()
  68. plt.close()
  69. if i+1 >= args.start + args.n_images: break
  70. parser = ArgumentParser()
  71. parser.add_argument("data",
  72. help="Folder containing the dataset with images and annotation files",
  73. type=str)
  74. parser.add_argument("--dataset",
  75. help="Possible datasets: NAB, CUB",
  76. choices=["cub", "nab"],
  77. default="nab", type=str)
  78. parser.add_argument("--subset",
  79. help="Possible subsets: train, test",
  80. choices=["train", "test"],
  81. default="train", type=str)
  82. parser.add_argument("--start", "-s",
  83. help="Image id to start with",
  84. type=int, default=0)
  85. parser.add_argument("--n_images", "-n",
  86. help="Number of images to display",
  87. type=int, default=10)
  88. parser.add_argument("--ratio",
  89. help="Part extraction ratio",
  90. type=float, default=.2)
  91. parser.add_argument("--rnd",
  92. help="select random subset of present parts",
  93. action="store_true")
  94. parser.add_argument("--uniform_parts", "-u",
  95. help="Do not use GT parts, but sample parts uniformly from the image",
  96. action="store_true")
  97. parser.add_argument("--crop_to_bb",
  98. help="Crop image to the bounding box",
  99. action="store_true")
  100. parser.add_argument("--crop_uniform",
  101. help="Try to extend the bounding box to same height and width",
  102. action="store_true")
  103. parser.add_argument(
  104. '--logfile', type=str, default='',
  105. help='File for logging output')
  106. parser.add_argument(
  107. '--loglevel', type=str, default='INFO',
  108. help='logging level. see logging module for more information')
  109. parser.add_argument(
  110. '--seed', type=int, default=12311123,
  111. help='random seed')
  112. main(parser.parse_args())