#!/usr/bin/env python
"""Visualise part annotations of the NAB / CUB datasets with matplotlib.

For every selected image the script shows the image with its parts (and,
unless the image is cropped, its bounding box), the image with only the
selected parts revealed, and a grid of the individual part crops.

This file is a script only: importing it raises an exception.
"""
if __name__ != '__main__': raise Exception("Do not import me!")

import sys
sys.path.insert(0, "..")

import logging
import numpy as np
import matplotlib.pyplot as plt

from argparse import ArgumentParser

from nabirds import NAB_Annotations, CUB_Annotations


def init_logger(args):
    """Configure the root logger from the parsed command line arguments.

    Reads ``args.loglevel`` (unknown level names fall back to DEBUG) and
    ``args.logfile`` (an empty value means no log file is passed to
    ``logging.basicConfig``). A log file is opened with mode ``"w"``.
    """
    fmt = "%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s"
    logging.basicConfig(
        format=fmt,
        level=getattr(logging, args.loglevel.upper(), logging.DEBUG),
        filename=args.logfile or None,
        filemode="w")


def plot_crops(crops, title, scatter_mid=False, names=None):
    """Draw all crops into one figure, arranged in a near-square grid.

    Args:
        crops: sequence of images accepted by ``Axes.imshow``; it is
            iterated and its length determines the grid size.
        title: figure title.
        scatter_mid: if true, mark the centre of every crop with an "x".
        names: optional per-crop titles, indexed in parallel with ``crops``.

    Nothing is drawn (and a warning is logged) when ``crops`` is empty.
    """
    n_crops = len(crops)
    if n_crops == 0:
        # rows would be 0 below, which breaks the column computation
        logging.warning("No crops to plot for \"{}\"".format(title))
        return

    rows = int(np.ceil(np.sqrt(n_crops)))
    # integer ceil-division: same result regardless of "/" semantics
    cols = (n_crops + rows - 1) // rows

    # squeeze=False guarantees a 2D array of axes, even for a 1x1 grid
    fig, axs = plt.subplots(rows, cols, figsize=(16, 9), squeeze=False)
    fig.suptitle(title, fontsize=16)

    grid = axs.ravel()
    # switch every cell off first, so that unused cells stay blank
    for ax in grid:
        ax.axis("off")

    for i, crop in enumerate(crops):
        ax = grid[i]
        if names is not None:
            ax.set_title(names[i])
        ax.imshow(crop)

        if scatter_mid:
            middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
            ax.scatter(middle_w, middle_h, marker="x")


def main(args):
    """Load the requested dataset subset and display the selected images.

    Args:
        args: parsed command line arguments (see the parser definition
            below for the available options).
    """
    init_logger(args)

    annotation_cls = dict(
        nab=NAB_Annotations,
        cub=CUB_Annotations)

    logging.info("Loading \"{}\" annnotations from \"{}\"".format(args.dataset, args.data))
    # direct indexing: an unknown dataset fails with a KeyError naming it
    annot = annotation_cls[args.dataset.lower()](args.data)

    subset = args.subset

    # NOTE(review): the value is unused; the lookup is kept because the
    # attribute access may have side effects - confirm, then remove.
    uuids = getattr(annot, "{}_uuids".format(subset))

    # NOTE(review): --features is parsed, but not forwarded to the dataset.
    data = annot.new_dataset(
        subset,

        part_rescale_size=args.rescale_size,

        uniform_parts=args.uniform_parts,

        crop_to_bb=args.crop_to_bb,
        crop_uniform=args.crop_uniform,

        parts_in_bb=args.parts_in_bb,

        rnd_select=args.rnd,
        ratio=args.ratio,
        seed=args.seed
    )

    logging.info("Loaded {} {} images".format(len(data), subset))

    # clamp the requested range to the images that actually exist
    start = max(args.start, 0)
    n_images = min(args.n_images, len(data) - start)

    for i in range(start, max(start, start + n_images)):
        im, parts, label = data[i]

        fig1, (parts_ax, selected_ax) = plt.subplots(2, 1, figsize=(16, 9))

        parts_ax.axis("off")
        parts_ax.set_title("Visible Parts")
        parts_ax.imshow(im)
        if not args.crop_to_bb:
            data.plot_bounding_box(i, parts_ax)
        parts.plot(im=im, ax=parts_ax, ratio=data.ratio)

        selected_ax.axis("off")
        selected_ax.set_title("{}selected parts".format("randomly " if args.rnd else ""))
        selected_ax.imshow(parts.reveal(im, ratio=data.ratio))

        if data.uniform_parts:
            crop_names = None
        else:
            crop_names = list(data._annot.part_names.values())

        part_crops = parts.visible_crops(im, ratio=data.ratio)
        if args.rnd:
            # the crops of the inverted selection are shown as "Actions"
            parts.invert_selection()
            action_crops = parts.visible_crops(im, ratio=data.ratio)

        plot_crops(part_crops, "Selected parts", names=crop_names)

        if args.rnd:
            plot_crops(action_crops, "Actions", names=crop_names)

        plt.show()
        # several figures are opened per image; close all of them, not
        # only the current one
        plt.close("all")


parser = ArgumentParser()

parser.add_argument("data",
    help="Folder containing the dataset with images and annotation files",
    type=str)

parser.add_argument("--dataset",
    help="Possible datasets: NAB, CUB",
    choices=["cub", "nab"],
    default="cub", type=str
)

parser.add_argument("--subset", "-sub",
    help="Possible subsets: train, test",
    choices=["train", "test"],
    default="train", type=str)

parser.add_argument("--start", "-s",
    help="Image id to start with",
    type=int, default=0)

parser.add_argument("--n_images", "-n",
    help="Number of images to display",
    type=int, default=10)

parser.add_argument("--features",
    help="pre-extracted train and test features",
    default=[None, None],
    nargs=2, type=str)

parser.add_argument("--ratio",
    help="Part extraction ratio",
    type=float, default=.2)

parser.add_argument("--rescale_size",
    help="rescales the part positions from this size to original image size",
    type=int, default=-1)

parser.add_argument("--uniform_parts", "-u",
    help="Do not use GT parts, but sample parts uniformly from the image",
    action="store_true")

parser.add_argument("--rnd",
    help="select random subset of present parts",
    action="store_true")

parser.add_argument("--crop_to_bb",
    help="Crop image to the bounding box",
    action="store_true")

parser.add_argument("--crop_uniform",
    help="Try to extend the bounding box to same height and width",
    action="store_true")

parser.add_argument("--parts_in_bb",
    help="Only display parts, that are inside the bounding box",
    action="store_true")

parser.add_argument(
    '--logfile', type=str, default='',
    help='File for logging output')

parser.add_argument(
    '--loglevel', type=str, default='INFO',
    help='logging level. see logging module for more information')

parser.add_argument(
    '--seed', type=int, default=12311123,
    help='random seed')

main(parser.parse_args())