#!/usr/bin/env python
"""Viewer for CUB annotations: shows images, bounding boxes and part crops.

This file is a command-line script and refuses to be imported as a module.
"""
if __name__ != '__main__':
    raise Exception("Do not import me!")

import sys
sys.path.insert(0, "..")

try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper
import yaml

import logging
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from argparse import ArgumentParser

from nabirds import CUB_Annotations


def init_logger(args):
    """Configure the root logger from the parsed command-line arguments.

    Args:
        args: namespace providing ``loglevel`` (a logging level name such as
            "INFO"; names unknown to the ``logging`` module fall back to
            DEBUG) and ``logfile`` (empty string -> default stream handler,
            otherwise the named file is opened in "w" mode, i.e. overwritten).
    """
    fmt = "%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s"
    logging.basicConfig(
        format=fmt,
        level=getattr(logging, args.loglevel.upper(), logging.DEBUG),
        filename=args.logfile or None,
        filemode="w")


def plot_crops(crops, title, scatter_mid=False, names=None):
    """Show the given crops in a roughly square grid inside a new figure.

    Args:
        crops: sequence of image arrays (anything ``Axes.imshow`` accepts,
            with a ``.shape`` when ``scatter_mid`` is set); may be empty.
        title: figure super-title.
        scatter_mid: if True, mark the centre of every crop with an "x".
        names: optional per-crop titles, in the same order as ``crops``.

    Returns:
        The created figure, or None when ``crops`` is empty (nothing is drawn).
    """
    n_crops = len(crops)
    if n_crops == 0:
        # plt.subplots(0, 0) raises ValueError; there is nothing to show anyway.
        logging.info("No crops to display for \"{}\"".format(title))
        return None

    rows = int(np.ceil(np.sqrt(n_crops)))
    # Integer ceil division: correct under both Python 2 and Python 3
    # (n_crops / rows would floor under Python 2 and make the grid too small).
    cols = (n_crops + rows - 1) // rows

    # squeeze=False keeps axs a 2-D array even for 1x1 or 1xN grids.
    fig, axs = plt.subplots(rows, cols, figsize=(16, 9), squeeze=False)
    fig.suptitle(title, fontsize=16)

    # Hide every cell first so surplus cells (rows * cols > n_crops) stay blank.
    for ax in axs.flat:
        ax.axis("off")

    # axs.flat walks the grid row-major, same order as np.unravel_index.
    for i, (ax, crop) in enumerate(zip(axs.flat, crops)):
        if names is not None:
            ax.set_title(names[i])
        ax.imshow(crop)

        if scatter_mid:
            middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
            ax.scatter(middle_w, middle_h, marker="x")

    return fig


def main(args):
    """Load the annotations and interactively display a range of images.

    For each image a figure with two stacked panels is shown: the image with
    its visible parts (plus the bounding box unless ``--crop_to_bb``), and
    the image with only the selected parts revealed. Separate figures then
    show the selected part crops and, with ``--rnd``, the crops obtained
    after inverting the part selection (titled "Actions").
    """
    init_logger(args)

    annot = CUB_Annotations(
        args.info, args.parts, args.feature_model)
    logging.info("Loaded data from \"{}\"".format(annot.root))

    data = annot.new_dataset(
        args.subset,
        crop_to_bb=args.crop_to_bb,
        crop_uniform=args.crop_uniform,
        parts_in_bb=args.parts_in_bb,
        rnd_select=args.rnd,
        seed=args.seed
    )
    logging.info("Loaded {} {} images".format(len(data), args.subset))

    start = max(args.start, 0)
    n_images = min(args.n_images, len(data) - start)

    # The range is empty when start is past the end or n_images <= 0.
    for i in range(start, start + n_images):
        im, parts, _label = data[i]

        _fig, axs = plt.subplots(2, 1, figsize=(16, 9))

        axs[0].axis("off")
        axs[0].set_title("Visible Parts")
        axs[0].imshow(im)
        if not args.crop_to_bb:
            # The bounding box is only drawn on images not cropped to it.
            data.plot_bounding_box(i, axs[0])
        parts.plot(im=im, ax=axs[0], ratio=data.ratio)

        axs[1].axis("off")
        axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
        axs[1].imshow(parts.reveal(im, ratio=data.ratio))

        if data.uniform_parts:
            crop_names = None
        else:
            crop_names = list(data._annot.part_names.values())

        part_crops = parts.visible_crops(im, ratio=data.ratio)
        if args.rnd:
            # NOTE(review): invert_selection mutates `parts` in place; presumably
            # the following visible_crops then returns the non-selected parts.
            parts.invert_selection()
            action_crops = parts.visible_crops(im, ratio=data.ratio)

        plot_crops(part_crops, "Selected parts", names=crop_names)
        if args.rnd:
            plot_crops(action_crops, "Actions", names=crop_names)

        plt.show()
        # Up to three figures are created per image; close all of them,
        # not just the current one, so they do not accumulate.
        plt.close("all")


parser = ArgumentParser()

parser.add_argument("info")

parser.add_argument("--parts", "-p",
    choices=["GT", "GT2", "NAC", "UNI", "L1_pred", "L1_full"]
)

parser.add_argument("--feature_model", "-fm",
    choices=["inception", "inception_tf", "resnet"]
)

parser.add_argument("--subset", "-sub",
    help="Possible subsets: train, test",
    choices=["train", "test"],
    default="train", type=str)

parser.add_argument("--start", "-s",
    help="Image id to start with",
    type=int, default=0)

parser.add_argument("--n_images", "-n",
    help="Number of images to display",
    type=int, default=10)

parser.add_argument("--rnd",
    help="select random subset of present parts",
    action="store_true")

parser.add_argument("--crop_to_bb",
    help="Crop image to the bounding box",
    action="store_true")

parser.add_argument("--crop_uniform",
    help="Try to extend the bounding box to same height and width",
    action="store_true")

parser.add_argument("--parts_in_bb",
    help="Only display parts, that are inside the bounding box",
    action="store_true")

parser.add_argument(
    '--logfile', type=str, default='',
    help='File for logging output')

parser.add_argument(
    '--loglevel', type=str, default='INFO',
    help='logging level. see logging module for more information')

parser.add_argument(
    '--seed', type=int, default=12311123,
    help='random seed')

main(parser.parse_args())