123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- #!/usr/bin/env python
- if __name__ != '__main__': raise Exception("Do not import me!")
- import sys
- sys.path.insert(0, "..")
- try:
- from yaml import CLoader as Loader, CDumper as Dumper
- except ImportError:
- from yaml import Loader, Dumper
- import yaml
- import logging
- import numpy as np
- import matplotlib.pyplot as plt
- from matplotlib.patches import Rectangle
- from argparse import ArgumentParser
- from nabirds import CUB_Annotations
def init_logger(args):
    """Configure root logging from the parsed command-line options.

    ``args.loglevel`` is upper-cased and looked up as an attribute of the
    ``logging`` module; if no such attribute exists, ``logging.DEBUG`` is used.
    An empty ``args.logfile`` means no log file (``basicConfig`` then logs to a
    stream); otherwise the file is opened in "w" mode, i.e. truncated.
    Note that ``logging.basicConfig`` is a no-op if the root logger already
    has handlers.
    """
    log_format = (
        "%(levelname)s - [%(asctime)s] "
        "%(filename)s:%(lineno)d [%(funcName)s]: %(message)s"
    )
    level_name = args.loglevel.upper()
    log_level = getattr(logging, level_name, logging.DEBUG)
    log_file = args.logfile if args.logfile else None

    logging.basicConfig(
        format=log_format,
        level=log_level,
        filename=log_file,
        filemode="w",
    )
def plot_crops(crops, title, scatter_mid=False, names=None):
    """Show a collection of image crops in a roughly square grid of subplots.

    Args:
        crops: sequence of crops (e.g. an array whose first axis indexes the
            crops); each element is drawn with ``Axes.imshow`` and must have a
            ``shape`` attribute when ``scatter_mid`` is set.
        title: figure-level title.
        scatter_mid: if True, mark the centre of every crop with an "x".
        names: optional per-crop subplot titles, indexed in parallel with
            ``crops``.

    Returns:
        The created matplotlib figure, or None when ``crops`` is empty.
    """
    n_crops = len(crops)
    if n_crops == 0:
        # the grid computation below would divide by zero
        logging.warning("No crops to plot for \"{}\"".format(title))
        return None

    rows = int(np.ceil(np.sqrt(n_crops)))
    # integer ceil-division; gives the same result under Python 2 and 3
    cols = (n_crops + rows - 1) // rows

    # squeeze=False keeps `axs` a 2-D array even for 1x1 / Nx1 grids
    fig, axs = plt.subplots(rows, cols, figsize=(16, 9), squeeze=False)
    fig.suptitle(title, fontsize=16)

    flat_axes = axs.ravel()  # row-major, same order as np.unravel_index
    for i, crop in enumerate(crops):
        ax = flat_axes[i]
        if names is not None:
            ax.set_title(names[i])
        ax.imshow(crop)
        ax.axis("off")
        if scatter_mid:
            middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
            ax.scatter(middle_w, middle_h, marker="x")

    # hide grid cells left over when n_crops does not fill the grid
    for ax in flat_axes[n_crops:]:
        ax.axis("off")

    return fig
def _plot_sample(data, i, args):
    """Create (but do not show) all figures for sample ``i`` of ``data``.

    One figure shows the image with its bounding box (unless the image was
    cropped to it) and parts, above the image with only the selected parts
    revealed. Then a crop grid of the selected parts is drawn and, when
    ``args.rnd`` is set, a second grid of the complementary ("action") parts.
    """
    im, parts, _label = data[i]
    _fig, axs = plt.subplots(2, 1, figsize=(16, 9))

    axs[0].axis("off")
    axs[0].set_title("Visible Parts")
    axs[0].imshow(im)
    if not args.crop_to_bb:
        # the bounding box is only drawn when the image was not cropped to it
        data.plot_bounding_box(i, axs[0])
    parts.plot(im=im, ax=axs[0], ratio=data.ratio)

    axs[1].axis("off")
    axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
    axs[1].imshow(parts.reveal(im, ratio=data.ratio))

    if data.uniform_parts:
        crop_names = None
    else:
        # NOTE(review): these are the names of all annotated parts; this
        # assumes visible_crops() yields crops in the same order/length —
        # confirm, otherwise subplot titles may not match the crops.
        crop_names = list(data._annot.part_names.values())

    part_crops = parts.visible_crops(im, ratio=data.ratio)
    if args.rnd:
        # invert_selection() is called for its effect on `parts` (its return
        # value is unused), so the selected crops must be taken before it
        parts.invert_selection()
        action_crops = parts.visible_crops(im, ratio=data.ratio)

    plot_crops(part_crops, "Selected parts", names=crop_names)
    if args.rnd:
        plot_crops(action_crops, "Actions", names=crop_names)


def main(args):
    """Load annotations and interactively display a range of images.

    Uses ``args.info``, ``args.parts`` and ``args.feature_model`` to build the
    annotations, ``args.subset`` plus the cropping/selection flags
    (``crop_to_bb``, ``crop_uniform``, ``parts_in_bb``, ``rnd``, ``seed``) to
    build the dataset. Then it shows images ``args.start`` ..
    ``args.start + args.n_images - 1``, clipped to the dataset size.
    """
    init_logger(args)

    annot = CUB_Annotations(args.info, args.parts, args.feature_model)
    logging.info("Loaded data from \"{}\"".format(annot.root))

    data = annot.new_dataset(
        args.subset,
        crop_to_bb=args.crop_to_bb,
        crop_uniform=args.crop_uniform,
        parts_in_bb=args.parts_in_bb,
        rnd_select=args.rnd,
        seed=args.seed
    )
    logging.info("Loaded {} {} images".format(len(data), args.subset))

    start = max(args.start, 0)
    # an empty range results if n_images <= 0 or start is past the end
    stop = min(start + args.n_images, len(data))
    for i in range(start, stop):
        _plot_sample(data, i, args)
        plt.show()
        # each sample creates up to three figures; close all of them so they
        # do not accumulate across iterations (plt.close() closes only one)
        plt.close("all")
def _build_parser():
    """Return the command-line parser for this visualisation script.

    Positional ``info`` plus options for the part/feature type, data subset,
    the range of images to display, cropping/selection switches, logging
    and the random seed.
    """
    cli = ArgumentParser()
    cli.add_argument("info")
    cli.add_argument(
        "--parts", "-p",
        choices=["GT", "GT2", "NAC", "UNI", "L1_pred", "L1_full"],
    )
    cli.add_argument(
        "--feature_model", "-fm",
        choices=["inception", "inception_tf", "resnet"],
    )
    cli.add_argument(
        "--subset", "-sub",
        help="Possible subsets: train, test",
        choices=["train", "test"],
        default="train", type=str,
    )
    cli.add_argument(
        "--start", "-s",
        help="Image id to start with",
        type=int, default=0,
    )
    cli.add_argument(
        "--n_images", "-n",
        help="Number of images to display",
        type=int, default=10,
    )

    # boolean switches; registration order determines the --help listing
    switches = (
        ("--rnd", "select random subset of present parts"),
        ("--crop_to_bb", "Crop image to the bounding box"),
        ("--crop_uniform", "Try to extend the bounding box to same height and width"),
        ("--parts_in_bb", "Only display parts, that are inside the bounding box"),
    )
    for flag, text in switches:
        cli.add_argument(flag, help=text, action="store_true")

    cli.add_argument(
        '--logfile', type=str, default='',
        help='File for logging output',
    )
    cli.add_argument(
        '--loglevel', type=str, default='INFO',
        help='logging level. see logging module for more information',
    )
    cli.add_argument(
        '--seed', type=int, default=12311123,
        help='random seed',
    )
    return cli


parser = _build_parser()
main(parser.parse_args())
|