123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
#!/usr/bin/env python
# This file is meant to be run as a script, not imported as a module.
# ImportError (rather than a bare Exception) is the conventional signal for
# "this module cannot be imported", and tools that scan packages handle it.
if __name__ != '__main__':
    raise ImportError("Do not import me!")
- from argparse import ArgumentParser
- import logging
- import numpy as np
- from annotations import NAB_Annotations, CUB_Annotations
- from dataset import Dataset, reveal_parts, \
- visible_part_locs, visible_crops, \
- uniform_part_locs, crops
- import matplotlib.pyplot as plt
def init_logger(args):
    """Configure the root logger from parsed command-line arguments.

    ``args.loglevel`` is upper-cased and looked up as an attribute of the
    ``logging`` module, falling back to ``logging.DEBUG`` when no such
    attribute exists. An empty ``args.logfile`` leaves output on the default
    stream; otherwise that file is opened in ``"w"`` mode (truncated).
    """
    level = getattr(logging, args.loglevel.upper(), logging.DEBUG)
    log_path = args.logfile if args.logfile else None
    logging.basicConfig(
        format="%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d "
               "[%(funcName)s]: %(message)s",
        level=level,
        filename=log_path,
        filemode="w",
    )
def _grid_shape(n_items):
    """Return ``(rows, cols)`` for a near-square grid holding ``n_items`` panels.

    Returns ``(0, 0)`` when there is nothing to lay out instead of dividing
    by zero.
    """
    if n_items <= 0:
        return 0, 0
    rows = int(np.ceil(np.sqrt(n_items)))
    # Integer ceiling division: plain ``/`` floors ints on Python 2, which
    # would undercount columns and overflow the subplot index.
    cols = -(-n_items // rows)
    return rows, cols


def _plot_part_locations(im, xy, idxs, ratio):
    """Plot ``im`` (top) and ``reveal_parts(im, xy, ratio=ratio)`` (bottom).

    Both panels mark the locations ``xy`` with ``x`` markers colored by
    ``idxs``. Returns the new figure so the caller can close it.
    """
    fig = plt.figure(figsize=(16, 9))
    ax = fig.add_subplot(2, 1, 1)
    ax.imshow(im)
    ax.scatter(*xy, marker="x", c=idxs)
    ax = fig.add_subplot(2, 1, 2)
    ax.imshow(reveal_parts(im, xy, ratio=ratio))
    ax.scatter(*xy, marker="x", c=idxs)
    return fig


def _plot_part_crops(part_crops):
    """Show every crop in a near-square grid with its center marked by ``x``.

    Returns the new figure; when there are no crops it is left empty.
    """
    fig = plt.figure(figsize=(16, 9))
    rows, cols = _grid_shape(len(part_crops))
    for j, crop in enumerate(part_crops, 1):
        ax = fig.add_subplot(rows, cols, j)
        ax.imshow(crop)
        middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
        ax.scatter(middle_w, middle_h, marker="x")
    return fig


def main(args):
    """Display part locations and part crops for a window of dataset images.

    For each selected image, two figures are built:
    - the image and its ``reveal_parts`` rendering, with part locations marked;
    - a grid of the part crops.
    ``plt.show()`` is then called and both figures are closed.

    Args:
        args: parsed command-line namespace. This function reads ``data``,
            ``dataset``, ``subset``, ``start`` (0-based index of the first
            image), ``n_images``, ``ratio`` and ``uniform_parts``.
            ``init_logger`` additionally reads ``logfile`` and ``loglevel``.

    Raises:
        ValueError: if ``args.dataset`` is not "nab" or "cub"
            (case-insensitive).
    """
    init_logger(args)

    annotation_cls = dict(
        nab=NAB_Annotations,
        cub=CUB_Annotations)

    logging.info("Loading \"{}\" annotations from \"{}\"".format(args.dataset, args.data))
    dataset_key = args.dataset.lower()
    if dataset_key not in annotation_cls:
        # Fail clearly instead of "'NoneType' object is not callable".
        raise ValueError("Unknown dataset \"{}\", expected one of: {}".format(
            args.dataset, ", ".join(sorted(annotation_cls))))
    annot = annotation_cls[dataset_key](args.data)

    uuids = getattr(annot, "{}_uuids".format(args.subset.lower()))
    data = Dataset(uuids, annot)
    n_images = len(data)
    logging.info("Found {} images in the {} subset".format(n_images, args.subset))

    # Show indices [start, start + n_images), clipped to the dataset; a
    # non-positive n_images shows nothing.
    first = max(args.start, 0)
    stop = min(n_images, first + max(args.n_images, 0))
    for i in range(first, stop):
        im, parts, label = data[i]

        if args.uniform_parts:
            idxs, xy = uniform_part_locs(im, ratio=args.ratio)
            part_crops = crops(im, xy, ratio=args.ratio)
        else:
            idxs, xy = visible_part_locs(parts)
            part_crops = visible_crops(im, parts, ratio=args.ratio)

        logging.debug(label)
        logging.debug(idxs)

        fig1 = _plot_part_locations(im, xy, idxs, args.ratio)
        fig2 = _plot_part_crops(part_crops)
        plt.show()
        plt.close(fig1)
        plt.close(fig2)
- parser = ArgumentParser()
- parser.add_argument("data",
- help="Folder containing the dataset with images and annotation files",
- type=str)
- parser.add_argument("--dataset",
- help="Possible datasets: NAB, CUB",
- choices=["cub", "nab"],
- default="nab", type=str)
- parser.add_argument("--subset",
- help="Possible subsets: train, test",
- choices=["train", "test"],
- default="train", type=str)
- parser.add_argument("--start", "-s",
- help="Image id to start with",
- type=int, default=0)
- parser.add_argument("--n_images", "-n",
- help="Number of images to display",
- type=int, default=10)
- parser.add_argument("--ratio",
- help="Part extraction ratio",
- type=float, default=.2)
- parser.add_argument("--uniform_parts", "-u",
- help="Do not use GT parts, but sample parts uniformly from the image",
- action="store_true")
- parser.add_argument(
- '--logfile', type=str, default='',
- help='File for logging output')
- parser.add_argument(
- '--loglevel', type=str, default='INFO',
- help='logging level. see logging module for more information')
- main(parser.parse_args())
|