display.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. #!/usr/bin/env python
  2. if __name__ != '__main__': raise Exception("Do not import me!")
  3. import sys
  4. sys.path.insert(0, "..")
  5. import logging
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. from argparse import ArgumentParser
  9. from cvdatasets.annotation import AnnotationType
  10. from utils import parser, plot_crops
  11. def main(args):
  12. assert args.dataset in AnnotationType, \
  13. f"AnnotationType is not known: \"{args.dataset}\""
  14. annot = AnnotationType.new_annotation(args)
  15. # annotation_cls = AnnotationType[args.dataset].value
  16. # logging.info(f"Loading \"{args.dataset}\" annnotations from \"{args.data}\"")
  17. # annot = annotation_cls.new(args, )
  18. # annot = annotation_cls(root_or_infofile=args.data, parts=args.parts, load_strict=False)
  19. kwargs = {}
  20. if annot.info is None:
  21. # features = args.features[0 if args.subset == "train" else 1]
  22. kwargs = dict(
  23. part_rescale_size=args.rescale_size,
  24. # features=features,
  25. uniform_parts=args.uniform_parts,
  26. ratio=args.ratio,
  27. )
  28. data = annot.new_dataset(
  29. args.subset,
  30. center_cropped=not args.no_center_crop,
  31. crop_to_bb=args.crop_to_bb,
  32. crop_uniform=args.crop_uniform,
  33. parts_in_bb=args.parts_in_bb,
  34. rnd_select=args.rnd,
  35. seed=args.seed,
  36. **kwargs
  37. )
  38. logging.info(f"Loaded {len(data)} {args.subset} images")
  39. if args.only_class >= 0:
  40. mask = data.labels == args.only_class
  41. logging.info(f"Showing only {mask.sum()} images from class {args.only_class}")
  42. idxs = np.where(mask)[0]
  43. else:
  44. start = max(args.start, 0)
  45. n_images = min(args.n_images, len(data) - start)
  46. end = max(start, start + n_images)
  47. logging.info(f"Showing only images {start} - {end}")
  48. idxs = range(start, end)
  49. for i in idxs:
  50. im, parts, label = data[i]
  51. fig1, axs = plt.subplots(1, 1, figsize=(16,9))
  52. axs = [axs]
  53. axs[0].axis("off")
  54. axs[0].set_title("Visible Parts")
  55. axs[0].imshow(im)
  56. if not args.crop_to_bb and not args.no_bboxes:
  57. data.plot_bounding_box(i, axs[0])
  58. # axs[1].axis("off")
  59. # axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
  60. # axs[1].imshow(parts.reveal(im, ratio=data.ratio))
  61. if not args.no_parts:
  62. parts.plot(im=im, ax=axs[0], ratio=data.ratio, linewidth=3)
  63. if data.uniform_parts:
  64. crop_names = None
  65. else:
  66. crop_names = list(data._annot.part_names.values())
  67. part_crops = parts.visible_crops(im, ratio=data.ratio)
  68. if args.rnd:
  69. parts.invert_selection()
  70. action_crops = parts.visible_crops(im, ratio=data.ratio)
  71. plot_crops(part_crops, f"{args.parts}: Selected parts", names=crop_names)
  72. if args.rnd:
  73. plot_crops(action_crops, f"{args.parts}: Actions", names=crop_names)
  74. plt.show()
  75. plt.close()
  76. main(parser.parse_args())