#!/usr/bin/env python
import sys

# Stand-alone script: refuse to be imported as a module.
if __name__ != '__main__':
	raise Exception("Do not import me!")

# Put the parent directory first on the module search path
# (presumably where the "annotations" and "dataset" packages live).
# NOTE(review): ".." is a relative entry, so this presumably requires
# the script to be started from its own directory - confirm.
sys.path.insert(0, "..")

"""
	Possible calls:

	./display.sh /home/korsch1/korsch/datasets/birds/cub200_11 --dataset cub -s600 -n5 --features /home/korsch1/korsch/datasets/birds/features/{train,val}_16parts_gt.npz --ratio 0.31
	> displays GT parts of CUB200

	./display.sh /home/korsch1/korsch/datasets/birds/NAC/2017-bilinear/ --dataset cub -s600 -n5 --features /home/korsch1/korsch/datasets/birds/features/{train,val}_16parts_gt.npz --ratio 0.31 --rescale_size 227
	> displays NAC parts of CUB200

"""

from argparse import ArgumentParser
import logging
import numpy as np

from annotations import NAB_Annotations, CUB_Annotations
from dataset import Dataset
from dataset.utils import reveal_parts, uniform_parts, \
	random_select, \
	visible_part_locs, visible_crops


import matplotlib.pyplot as plt

def init_logger(args):
	"""Set up the root logger according to the command-line arguments.

	Args:
		args: object with the attributes
			``loglevel`` - name of a ``logging`` level (case-insensitive);
				names unknown to the ``logging`` module fall back to DEBUG.
			``logfile`` - path of the log file (opened in "w" mode, i.e.
				truncated); an empty value means no file is used.
	"""
	# Resolve the textual level name to the numeric constant.
	level = getattr(logging, args.loglevel.upper(), logging.DEBUG)

	# An empty string must become None, otherwise basicConfig would try
	# to open a file with an empty name.
	target_file = args.logfile or None

	logging.basicConfig(
		filename=target_file,
		filemode="w",
		level=level,
		format="%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s")

def plot_crops(crops, title, scatter_mid=False, names=None):
	"""Draw every crop in its own subplot of a new, titled 16x9 figure.

	The subplots are laid out on a near-square grid. The figure is only
	created here; showing it (``plt.show()``) is left to the caller.

	Args:
		crops: sequence (e.g. array) of images accepted by ``Axes.imshow``;
			it is iterated along its first axis. May be empty, in which
			case only the empty, titled figure is created.
		title: figure-level title.
		scatter_mid: if true, mark the center of each crop with an "x".
		names: optional sequence of subplot titles; ``names[k]`` is used
			for the k-th crop. Crops beyond the end of ``names`` stay
			untitled.
	"""
	fig = plt.figure(figsize=(16,9))
	fig.suptitle(title, fontsize=16)

	n_crops = len(crops)
	if n_crops == 0:
		# Nothing to draw; without this guard the grid computation below
		# would divide by zero (rows == 0).
		return

	rows = int(np.ceil(np.sqrt(n_crops)))
	# Integer ceil-division, so the result does not depend on whether
	# "/" is true or floor division.
	cols = (n_crops + rows - 1) // rows

	n_names = 0 if names is None else len(names)

	for k, crop in enumerate(crops):
		# subplot positions are 1-based
		ax = fig.add_subplot(rows, cols, k + 1)
		if k < n_names:
			ax.set_title(names[k])
		ax.imshow(crop)
		ax.axis("off")
		if scatter_mid:
			middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
			# scatter expects (x, y), i.e. (column, row)
			ax.scatter(middle_w, middle_h, marker="x")



def main(args):
	"""Load one subset of a dataset and display its images with their parts.

	For every displayed image a first figure shows the image with the
	locations returned by ``visible_part_locs`` marked on it (top) and the
	output of ``reveal_parts`` for those locations (bottom). A second
	figure shows the crops returned by ``visible_crops``. With ``--rnd`` a
	third figure ("Actions") shows the crops obtained after inverting the
	flag stored in the last column of ``parts``.

	The images with the indices ``args.start`` up to
	``args.start + args.n_images - 1`` (clipped to the size of the subset)
	are processed one after another; after each ``plt.show()`` all figures
	of that image are closed.

	Args:
		args: namespace produced by the argument parser defined at the
			bottom of this file.
	"""
	init_logger(args)

	annotation_cls = dict(
		nab=NAB_Annotations,
		cub=CUB_Annotations)

	logging.info("Loading \"{}\" annnotations from \"{}\"".format(args.dataset, args.data))
	annot = annotation_cls.get(args.dataset.lower())(args.data)

	subset = args.subset.lower()

	uuids = getattr(annot, "{}_uuids".format(subset))
	# --features holds two entries: index 0 is used for the train subset,
	# index 1 for any other subset.
	features = args.features[0 if subset == "train" else 1]

	data = Dataset(
		uuids=uuids, annotations=annot,
		part_rescale_size=args.rescale_size,
		features=features,

		uniform_parts=args.uniform_parts,

		crop_to_bb=args.crop_to_bb,
		crop_uniform=args.crop_uniform,

		parts_in_bb=args.parts_in_bb,

		rnd_select=args.rnd,
		ratio=args.ratio,
		seed=args.seed

	)
	n_total = len(data)
	logging.info("Found {} images in the {} subset".format(n_total, subset))

	# Explicit index window instead of "skip until start / break after n":
	# this also displays nothing (rather than one image) for n_images <= 0.
	first = max(args.start, 0)
	last = min(args.start + args.n_images, n_total)
	if first >= last:
		logging.warning("Nothing to display (start={}, n_images={}, subset size={})".format(
			args.start, args.n_images, n_total))

	for i in range(first, last):
		im, parts, label = data[i]

		idxs, xy = visible_part_locs(parts)
		part_crops = visible_crops(im, parts, ratio=args.ratio)
		if args.rnd:
			# Invert the flag in the last column to get the complementary
			# set of parts. Work on a copy so that the array handed out by
			# the dataset is not modified (it may be shared).
			selected = parts[:, -1].astype(bool)
			inverted = parts.copy()
			inverted[selected, -1] = 0
			inverted[np.logical_not(selected), -1] = 1
			action_crops = visible_crops(im, inverted, ratio=args.ratio)

		logging.debug(label)
		logging.debug(idxs)
		logging.debug(xy)

		fig1 = plt.figure(figsize=(16,9))
		ax = fig1.add_subplot(2,1,1)
		ax.imshow(im)
		ax.set_title("Visible Parts")
		ax.scatter(*xy, marker="x", c=idxs)
		ax.axis("off")

		ax = fig1.add_subplot(2,1,2)
		ax.set_title("{}selected parts".format("randomly " if args.rnd else ""))
		ax.imshow(reveal_parts(im, xy, ratio=args.ratio))
		ax.axis("off")

		# NOTE(review): crop_names lists every part name, while part_crops
		# presumably only contains the visible parts. Titles are assigned
		# by position, so they may not match the crops when parts are
		# missing - verify against visible_crops / part_names.
		crop_names = list(data._annot.part_names.values())
		plot_crops(part_crops, "Selected parts", names=crop_names)

		if args.rnd:
			plot_crops(action_crops, "Actions")

		plt.show()
		# Several figures are created per image; a bare plt.close() would
		# only close the current one.
		plt.close("all")

# Command-line interface; the parsed arguments are handed directly to main().
parser = ArgumentParser()

# --- what to load ---
parser.add_argument(
	"data", type=str,
	help="Folder containing the dataset with images and annotation files")

parser.add_argument(
	"--dataset", type=str, default="nab",
	choices=["cub", "nab"],
	help="Possible datasets: NAB, CUB")

parser.add_argument(
	"--features", type=str, nargs=2,
	default=[None, None],
	help="pre-extracted train and test features")

parser.add_argument(
	"--subset", type=str, default="train",
	choices=["train", "test"],
	help="Possible subsets: train, test")

# --- which images to show ---
parser.add_argument(
	"--start", "-s", type=int, default=0,
	help="Image id to start with")

parser.add_argument(
	"--n_images", "-n", type=int, default=10,
	help="Number of images to display")

# --- part extraction ---
parser.add_argument(
	"--ratio", type=float, default=.2,
	help="Part extraction ratio")

parser.add_argument(
	"--rescale_size", type=int, default=-1,
	help="rescales the part positions from this size to original image size")

# --- boolean switches (all off by default), registered in display order ---
for switch_flags, switch_help in (
	(("--rnd",),
		"select random subset of present parts"),
	(("--uniform_parts", "-u"),
		"Do not use GT parts, but sample parts uniformly from the image"),
	(("--crop_to_bb",),
		"Crop image to the bounding box"),
	(("--crop_uniform",),
		"Try to extend the bounding box to same height and width"),
	(("--parts_in_bb",),
		"Only display parts, that are inside the bounding box"),
):
	parser.add_argument(*switch_flags, help=switch_help, action="store_true")

# --- logging and reproducibility ---
parser.add_argument(
	'--logfile', default='', type=str,
	help='File for logging output')

parser.add_argument(
	'--loglevel', default='INFO', type=str,
	help='logging level. see logging module for more information')

parser.add_argument(
	'--seed', default=12311123, type=int,
	help='random seed')

main(parser.parse_args())