#!/usr/bin/env python
"""Display part annotations of dataset images.

For each selected image of the chosen subset, two matplotlib figures are
shown: one with the visible part locations and the (optionally randomly
selected) revealed parts, and one with the individual part crops.

This file is a script only; importing it raises an exception.
"""
if __name__ != '__main__': raise Exception("Do not import me!")

from argparse import ArgumentParser
import logging
import numpy as np

from annotations import NAB_Annotations, CUB_Annotations
from dataset import Dataset
from dataset.utils import reveal_parts, uniform_parts, \
    random_select, \
    visible_part_locs, visible_crops

import matplotlib.pyplot as plt


def init_logger(args):
    """Configure the root logger from the parsed command-line arguments.

    Reads ``args.loglevel`` (unknown level names fall back to DEBUG) and
    ``args.logfile`` (an empty value means no log file is passed to
    ``logging.basicConfig``). A given log file is opened in "w" mode.
    """
    fmt = "%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s"
    logging.basicConfig(
        format=fmt,
        level=getattr(logging, args.loglevel.upper(), logging.DEBUG),
        filename=args.logfile or None,
        filemode="w")


def _grid_shape(n_items):
    """Return (rows, cols) of a near-square grid with >= n_items cells.

    Expects n_items > 0. The column count uses integer ceiling division so
    the result does not depend on the semantics of the "/" operator.
    """
    rows = int(np.ceil(np.sqrt(n_items)))
    cols = -(-n_items // rows)
    return rows, cols


def _show_parts(args, im, idxs, xy, part_crops):
    """Create the figure with the part locations and the revealed parts.

    Returns the (idxs, xy, part_crops) triple that should be used for the
    crop figure; it differs from the input only if ``args.rnd`` is set.
    """
    fig = plt.figure(figsize=(16, 9))

    ax = fig.add_subplot(2, 1, 1)
    ax.imshow(im)
    ax.set_title("Visible Parts")
    # xy is unpacked into the positional (x, y) arguments of scatter
    ax.scatter(*xy, marker="x", c=idxs)
    ax.axis("off")

    if args.rnd:
        idxs, xy, part_crops = random_select(idxs, xy, part_crops)

    ax = fig.add_subplot(2, 1, 2)
    ax.set_title("{}selected parts".format("randomly " if args.rnd else ""))
    ax.imshow(reveal_parts(im, xy, ratio=args.ratio))
    ax.scatter(*xy, marker="x", c=idxs)
    ax.axis("off")

    return idxs, xy, part_crops


def _show_crops(part_crops):
    """Create the figure that shows every part crop in a near-square grid."""
    n_crops = part_crops.shape[0]
    if n_crops == 0:
        # a 0x0 grid cannot be built (and would divide by zero)
        logging.warning("No part crops to display")
        return

    fig = plt.figure(figsize=(16, 9))
    rows, cols = _grid_shape(n_crops)

    for j, crop in enumerate(part_crops, 1):
        ax = fig.add_subplot(rows, cols, j)
        ax.imshow(crop)
        ax.axis("off")
        # mark the center of the crop
        middle_h, middle_w = crop.shape[0] / 2.0, crop.shape[1] / 2.0
        ax.scatter(middle_w, middle_h, marker="x")


def main(args):
    """Load the annotations and display the requested range of images.

    Images with indices in ``[args.start, args.start + args.n_images)``,
    clipped to the size of the subset, are shown one after another.
    """
    init_logger(args)

    annotation_cls = dict(
        nab=NAB_Annotations,
        cub=CUB_Annotations)

    logging.info("Loading \"{}\" annnotations from \"{}\"".format(args.dataset, args.data))
    # direct indexing: an unknown dataset name fails with a KeyError naming it
    annot = annotation_cls[args.dataset.lower()](args.data)

    uuids = getattr(annot, "{}_uuids".format(args.subset.lower()))
    data = Dataset(uuids, annot)
    n_images = len(data)
    logging.info("Found {} images in the {} subset".format(n_images, args.subset))

    first = max(args.start, 0)
    last = min(n_images, args.start + args.n_images)

    for i in range(first, last):
        im, parts, label = data[i]

        if args.uniform_parts:
            parts = uniform_parts(im, ratio=args.ratio)

        idxs, xy = visible_part_locs(parts)
        part_crops = visible_crops(im, parts, ratio=args.ratio)

        logging.debug(label)
        logging.debug(idxs)
        logging.debug(xy)

        idxs, xy, part_crops = _show_parts(args, im, idxs, xy, part_crops)
        _show_crops(part_crops)

        plt.show()
        # up to two figures were created for this image; plt.close() without
        # arguments would close only the current one
        plt.close("all")


parser = ArgumentParser()

parser.add_argument("data",
    help="Folder containing the dataset with images and annotation files",
    type=str)

parser.add_argument("--dataset",
    help="Possible datasets: NAB, CUB",
    choices=["cub", "nab"],
    default="nab", type=str)

parser.add_argument("--subset",
    help="Possible subsets: train, test",
    choices=["train", "test"],
    default="train", type=str)

parser.add_argument("--start", "-s",
    help="Image id to start with",
    type=int, default=0)

parser.add_argument("--n_images", "-n",
    help="Number of images to display",
    type=int, default=10)

parser.add_argument("--ratio",
    help="Part extraction ratio",
    type=float, default=.2)

parser.add_argument("--rnd",
    help="select random subset of present parts",
    action="store_true")

parser.add_argument("--uniform_parts", "-u",
    help="Do not use GT parts, but sample parts uniformly from the image",
    action="store_true")

parser.add_argument(
    '--logfile', type=str, default='',
    help='File for logging output')

parser.add_argument(
    '--loglevel', type=str, default='INFO',
    help='logging level. see logging module for more information')

main(parser.parse_args())