display.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
#!/usr/bin/env python
"""Display dataset images together with their part annotations and part crops.

Command-line script: loads one split of a part-annotated dataset through
``nabirds.annotations.AnnotationType`` and, for each requested image, shows
the image (with bounding box and parts overlaid), the revealed selected parts,
and the crops of the selected parts. Importing this module raises an
Exception on purpose.
"""
# This file is meant to be executed directly, never imported as a library.
if __name__ != '__main__': raise Exception("Do not import me!")
import sys
# Put the parent directory first on the import path, presumably so that the
# ``nabirds`` package living there can be imported. NOTE(review): ".." is
# resolved against the current working directory, not this file's location,
# so the script presumably has to be run from its own directory -- confirm.
sys.path.insert(0, "..")
import logging
# NOTE(review): ``np`` and ``ArgumentParser`` appear unused in this script.
import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from nabirds.annotations import AnnotationType
# Local helpers: ``parser`` is the command-line parser used to build the
# ``args`` passed to main(); ``plot_crops`` presumably plots a list of crops
# under a title (called with optional per-crop ``names``).
from utils import parser, plot_crops
  11. def main(args):
  12. annotation_cls = AnnotationType.get(args.dataset).value
  13. logging.info("Loading \"{}\" annnotations from \"{}\"".format(args.dataset, args.data))
  14. annot = annotation_cls(args.data, args.parts, args.feature_model)
  15. kwargs = {}
  16. if annot.info is None:
  17. # features = args.features[0 if args.subset == "train" else 1]
  18. kwargs = dict(
  19. part_rescale_size=args.rescale_size,
  20. # features=features,
  21. uniform_parts=args.uniform_parts,
  22. ratio=args.ratio,
  23. )
  24. data = annot.new_dataset(
  25. args.subset,
  26. crop_to_bb=args.crop_to_bb,
  27. crop_uniform=args.crop_uniform,
  28. parts_in_bb=args.parts_in_bb,
  29. rnd_select=args.rnd,
  30. seed=args.seed,
  31. **kwargs
  32. )
  33. logging.info("Loaded {} {} images".format(len(data), args.subset))
  34. start = max(args.start, 0)
  35. n_images = min(args.n_images, len(data) - start)
  36. for i in range(start, max(start, start + n_images)):
  37. im, parts, label = data[i]
  38. fig1, axs = plt.subplots(2, 1, figsize=(16,9))
  39. axs[0].axis("off")
  40. axs[0].set_title("Visible Parts")
  41. axs[0].imshow(im)
  42. if not args.crop_to_bb:
  43. data.plot_bounding_box(i, axs[0])
  44. parts.plot(im=im, ax=axs[0], ratio=data.ratio)
  45. axs[1].axis("off")
  46. axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
  47. axs[1].imshow(parts.reveal(im, ratio=data.ratio))
  48. if data.uniform_parts:
  49. crop_names = None
  50. else:
  51. crop_names = list(data._annot.part_names.values())
  52. part_crops = parts.visible_crops(im, ratio=data.ratio)
  53. if args.rnd:
  54. parts.invert_selection()
  55. action_crops = parts.visible_crops(im, ratio=data.ratio)
  56. plot_crops(part_crops, "Selected parts", names=crop_names)
  57. if args.rnd:
  58. plot_crops(action_crops, "Actions", names=crop_names)
  59. plt.show()
  60. plt.close()
  61. main(parser.parse_args())