12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- #!/usr/bin/env python
- if __name__ != '__main__': raise Exception("Do not import me!")
- import sys
- sys.path.insert(0, "..")
- import logging
- import numpy as np
- import matplotlib.pyplot as plt
- from argparse import ArgumentParser
- from nabirds.annotations import AnnotationType
- from utils import parser, plot_crops
def _load_data(args):
    """Build the dataset subset selected on the command line.

    Args:
        args: parsed command-line namespace; reads ``dataset``, ``data``,
            ``parts``, ``feature_model``, ``rescale_size``, ``uniform_parts``,
            ``ratio``, ``subset``, ``crop_to_bb``, ``crop_uniform``,
            ``parts_in_bb``, ``rnd`` and ``seed``.

    Returns:
        The object returned by ``annot.new_dataset(args.subset, ...)``.
    """
    annotation_cls = AnnotationType.get(args.dataset).value
    logging.info("Loading \"%s\" annotations from \"%s\"", args.dataset, args.data)
    annot = annotation_cls(args.data, args.parts, args.feature_model)

    # The part-related options are only forwarded when the annotation object
    # has no ``info``.
    # NOTE(review): presumably ``info`` already fixes these settings otherwise;
    # confirm against the annotation classes.
    kwargs = {}
    if annot.info is None:
        kwargs = dict(
            part_rescale_size=args.rescale_size,
            uniform_parts=args.uniform_parts,
            ratio=args.ratio,
        )

    data = annot.new_dataset(
        args.subset,
        crop_to_bb=args.crop_to_bb,
        crop_uniform=args.crop_uniform,
        parts_in_bb=args.parts_in_bb,
        rnd_select=args.rnd,
        seed=args.seed,
        **kwargs
    )
    logging.info("Loaded %d %s images", len(data), args.subset)
    return data


def _plot_sample(data, i, args):
    """Plot sample ``i`` of ``data`` with its parts and part crops.

    The first figure has two rows:
      * the image with the visible parts and, unless cropped to it, the bounding box;
      * the image with only the selected parts revealed.

    The selected part crops are then plotted via ``plot_crops``. With
    ``args.rnd`` set, the crops of the remaining ("action") parts are plotted
    as well.
    """
    im, parts, _label = data[i]

    _fig, axs = plt.subplots(2, 1, figsize=(16, 9))

    axs[0].axis("off")
    axs[0].set_title("Visible Parts")
    axs[0].imshow(im)
    # The bounding box is only drawn when the image was not cropped to it.
    if not args.crop_to_bb:
        data.plot_bounding_box(i, axs[0])
    parts.plot(im=im, ax=axs[0], ratio=data.ratio)

    axs[1].axis("off")
    axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
    axs[1].imshow(parts.reveal(im, ratio=data.ratio))

    # Crops are labelled with the annotation's part names only for non-uniform parts.
    crop_names = None if data.uniform_parts else list(data._annot.part_names.values())

    part_crops = parts.visible_crops(im, ratio=data.ratio)
    action_crops = None
    if args.rnd:
        # invert_selection() is expected to change which parts ``parts``
        # exposes, so the selected crops are taken before calling it.
        parts.invert_selection()
        action_crops = parts.visible_crops(im, ratio=data.ratio)

    plot_crops(part_crops, "Selected parts", names=crop_names)
    if args.rnd:
        plot_crops(action_crops, "Actions", names=crop_names)

    plt.show()
    # Close every open figure, not only the current one. plot_crops may open
    # figures of its own, and a bare plt.close() would then leave the subplot
    # figure above alive.
    plt.close("all")


def main(args):
    """Load the requested dataset subset and plot a range of its samples.

    Plots at most ``args.n_images`` samples starting at index ``args.start``.
    A negative start is clamped to 0. Nothing is plotted when the start lies
    beyond the dataset. ``plt.show()`` is called once per sample.

    Args:
        args: parsed command-line namespace (see ``_load_data`` and
            ``_plot_sample`` for the attributes read).
    """
    data = _load_data(args)

    start = max(args.start, 0)
    n_images = min(args.n_images, len(data) - start)
    # range() is already empty when n_images <= 0, so no extra clamp is needed.
    for i in range(start, start + n_images):
        _plot_sample(data, i, args)
if __name__ == "__main__":
    # Script entry point: parse the command line and run the viewer.
    main(parser.parse_args())
|