display.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. #!/usr/bin/env python
  2. if __name__ != '__main__': raise Exception("Do not import me!")
  3. from argparse import ArgumentParser
  4. import logging
  5. import numpy as np
  6. from nabirds import Dataset, NAB_Annotations, CUB_Annotations
  7. from nabirds.dataset import visible_part_locs, visible_crops, reveal_parts
  8. import matplotlib.pyplot as plt
  9. def init_logger(args):
  10. fmt = "%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s"
  11. logging.basicConfig(
  12. format=fmt,
  13. level=getattr(logging, args.loglevel.upper(), logging.DEBUG),
  14. filename=args.logfile or None,
  15. filemode="w")
  16. def main(args):
  17. init_logger(args)
  18. annotation_cls = dict(
  19. nab=NAB_Annotations,
  20. cub=CUB_Annotations)
  21. logging.info("Loading \"{}\" annnotations from \"{}\"".format(args.dataset, args.data))
  22. annot = annotation_cls.get(args.dataset.lower())(args.data)
  23. uuids = getattr(annot, "{}_uuids".format(args.subset.lower()))
  24. data = Dataset(uuids, annot)
  25. n_images = len(data)
  26. logging.info("Found {} images in the {} subset".format(n_images, args.subset))
  27. for i in range(n_images):
  28. if i + 1 <= args.start: continue
  29. im, parts, label = data[i]
  30. idxs, xy = visible_part_locs(parts)
  31. logging.debug(label)
  32. logging.debug(idxs)
  33. fig1 = plt.figure(figsize=(16,9))
  34. ax = fig1.add_subplot(2,1,1)
  35. ax.imshow(im)
  36. ax.scatter(*xy, marker="x", c=idxs)
  37. ax = fig1.add_subplot(2,1,2)
  38. ax.imshow(reveal_parts(im, xy, ratio=args.ratio))
  39. ax.scatter(*xy, marker="x", c=idxs)
  40. fig2 = plt.figure(figsize=(16,9))
  41. n_parts = parts.shape[0]
  42. rows, cols = (2,6) if args.dataset.lower() == "nab" else (3,5)
  43. for j, crop in enumerate(visible_crops(im, parts, ratio=args.ratio), 1):
  44. ax = fig2.add_subplot(rows, cols, j)
  45. ax.imshow(crop)
  46. middle = crop.shape[0] / 2
  47. ax.scatter(middle, middle, marker="x")
  48. plt.show()
  49. plt.close(fig1)
  50. plt.close(fig2)
  51. if i+1 >= args.start + args.n_images: break
  52. parser = ArgumentParser()
  53. parser.add_argument("data",
  54. help="Folder containing the dataset with images and annotation files",
  55. type=str)
  56. parser.add_argument("--dataset",
  57. help="Possible datasets: NAB, CUB",
  58. choices=["cub", "nab"],
  59. default="nab", type=str)
  60. parser.add_argument("--subset",
  61. help="Possible subsets: train, test",
  62. choices=["train", "test"],
  63. default="train", type=str)
  64. parser.add_argument("--start", "-s",
  65. help="Image id to start with",
  66. type=int, default=0)
  67. parser.add_argument("--n_images", "-n",
  68. help="Number of images to display",
  69. type=int, default=10)
  70. parser.add_argument("--ratio",
  71. help="Part extraction ratio",
  72. type=float, default=.2)
  73. parser.add_argument(
  74. '--logfile', type=str, default='',
  75. help='File for logging output')
  76. parser.add_argument(
  77. '--loglevel', type=str, default='INFO',
  78. help='logging level. see logging module for more information')
  79. main(parser.parse_args())