#!/usr/bin/env python
"""Interactive viewer for NABirds / CUB part annotations.

For every selected image this shows the image with its visible parts (and
bounding box, unless cropping to it), the revealed/selected parts, and grids
of the individual part crops. Intended to be run as a script only.
"""
if __name__ != '__main__':
    raise Exception("Do not import me!")

import sys
# presumably so that ``nabirds`` (imported below) resolves from the parent
# directory when the script is run from its own folder
sys.path.insert(0, "..")

import logging
from argparse import ArgumentParser

import numpy as np
import matplotlib.pyplot as plt

from nabirds import NAB_Annotations, CUB_Annotations


def init_logger(args):
    """Configure the root logger from the parsed command-line arguments.

    ``args.loglevel`` is looked up (upper-cased) as a ``logging`` level name,
    falling back to DEBUG if unknown. An empty ``args.logfile`` means no log
    file (``logging.basicConfig`` then writes to stderr); otherwise the file
    is overwritten (``filemode="w"``).
    """
    fmt = "%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s"
    logging.basicConfig(
        format=fmt,
        level=getattr(logging, args.loglevel.upper(), logging.DEBUG),
        filename=args.logfile or None,
        filemode="w")


def plot_crops(crops, title, scatter_mid=False, names=None):
    """Show *crops* in a new figure laid out as a roughly square grid.

    Args:
        crops: sequence of array-like images (passed to ``imshow``), one per
            grid cell, in row-major order.
        title: figure super-title.
        scatter_mid: if True, mark the centre of every crop with an ``x``.
        names: optional per-crop subplot titles, indexed like *crops*.

    Returns:
        The created figure, or None when *crops* is empty (nothing is drawn).
    """
    n_crops = len(crops)
    if n_crops == 0:
        # the grid computation below would divide by zero; nothing to show
        logging.debug("No crops to plot for \"%s\"", title)
        return None

    rows = int(np.ceil(np.sqrt(n_crops)))
    cols = int(np.ceil(n_crops / rows))
    # squeeze=False always yields a 2-D array of Axes; without it a single
    # crop produces a bare Axes object and array indexing fails
    fig, axs = plt.subplots(rows, cols, figsize=(16, 9), squeeze=False)
    fig.suptitle(title, fontsize=16)

    flat_axs = axs.ravel()  # row-major, same order as np.unravel_index
    for i, (ax, crop) in enumerate(zip(flat_axs, crops)):
        if names is not None:
            ax.set_title(names[i])
        ax.imshow(crop)
        ax.axis("off")

        if scatter_mid:
            middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
            ax.scatter(middle_w, middle_h, marker="x")

    # hide the grid cells left over when n_crops does not fill rows * cols
    for ax in flat_axs[n_crops:]:
        ax.axis("off")

    return fig


def _show_image(data, i, args):
    """Display image *i* of *data* with its parts and the part crop grids.

    Calls ``plt.show()`` and afterwards closes every open figure.
    """
    im, parts, _label = data[i]

    _, axs = plt.subplots(2, 1, figsize=(16, 9))
    axs[0].axis("off")
    axs[0].set_title("Visible Parts")
    axs[0].imshow(im)
    if not args.crop_to_bb:
        data.plot_bounding_box(i, axs[0])
    parts.plot(im=im, ax=axs[0], ratio=data.ratio)

    axs[1].axis("off")
    axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
    axs[1].imshow(parts.reveal(im, ratio=data.ratio))

    if data.uniform_parts:
        crop_names = None
    else:
        crop_names = list(data._annot.part_names.values())

    part_crops = parts.visible_crops(im, ratio=data.ratio)
    if args.rnd:
        # NOTE(review): presumably swaps selected/unselected parts; the crops
        # taken afterwards are displayed as "Actions"
        parts.invert_selection()
        action_crops = parts.visible_crops(im, ratio=data.ratio)

    plot_crops(part_crops, "Selected parts", names=crop_names)
    if args.rnd:
        plot_crops(action_crops, "Actions", names=crop_names)

    plt.show()
    # close every figure created for this image (not only the current one),
    # so they do not accumulate when plt.show() is non-blocking or a no-op
    plt.close("all")


def main(args):
    """Load the requested dataset split and display a range of its images.

    Images ``args.start`` .. ``args.start + args.n_images - 1`` are shown,
    clamped to the size of the dataset.
    """
    init_logger(args)

    annotation_cls = dict(
        nab=NAB_Annotations,
        cub=CUB_Annotations)

    logging.info("Loading \"{}\" annnotations from \"{}\"".format(args.dataset, args.data))
    # --dataset is restricted to the keys above via argparse choices
    annot = annotation_cls[args.dataset.lower()](
        args.data, args.parts, args.feature_model)

    logging.info("Loaded data from \"{}\"".format(annot.root))

    subset = args.subset
    data = annot.new_dataset(
        subset,
        crop_to_bb=args.crop_to_bb,
        crop_uniform=args.crop_uniform,
        parts_in_bb=args.parts_in_bb,
        rnd_select=args.rnd,
        seed=args.seed
    )
    logging.info("Loaded {} {} images".format(len(data), subset))

    # clamp the requested index range [start, start + n_images) to the data
    start = max(args.start, 0)
    n_images = min(args.n_images, len(data) - start)
    for i in range(start, start + max(n_images, 0)):
        _show_image(data, i, args)


parser = ArgumentParser()

parser.add_argument("data",
    help="Folder containing the dataset with images and annotation files or dataset info file",
    type=str)

parser.add_argument("--dataset",
    help="Possible datasets: NAB, CUB",
    choices=["cub", "nab"],
    default="cub", type=str)

parser.add_argument("--parts", "-p",
    choices=["GT", "GT2", "NAC", "UNI", "L1_pred", "L1_full"])

parser.add_argument("--feature_model", "-fm",
    choices=["inception", "inception_tf", "resnet"])

parser.add_argument("--subset", "-sub",
    help="Possible subsets: train, test",
    choices=["train", "test"],
    default="train", type=str)

parser.add_argument("--start", "-s",
    help="Image id to start with",
    type=int, default=0)

parser.add_argument("--n_images", "-n",
    help="Number of images to display",
    type=int, default=10)

parser.add_argument("--rnd",
    help="select random subset of present parts",
    action="store_true")

parser.add_argument("--crop_to_bb",
    help="Crop image to the bounding box",
    action="store_true")

parser.add_argument("--crop_uniform",
    help="Try to extend the bounding box to same height and width",
    action="store_true")

parser.add_argument("--parts_in_bb",
    help="Only display parts, that are inside the bounding box",
    action="store_true")

parser.add_argument(
    '--logfile', type=str, default='',
    help='File for logging output')

parser.add_argument(
    '--loglevel', type=str, default='INFO',
    help='logging level. see logging module for more information')

parser.add_argument(
    '--seed', type=int, default=12311123,
    help='random seed')

main(parser.parse_args())