#!/usr/bin/env python
if __name__ != '__main__': raise Exception("Do not import me!")
import sys
sys.path.insert(0, "..")

import logging
import numpy as np
import matplotlib.pyplot as plt

from argparse import ArgumentParser

from cvdatasets.annotations import AnnotationType
from utils import parser, plot_crops

def _load_data(args):
	"""Build the dataset split described by the parsed command-line `args`.

	Raises:
		ValueError: if `args.dataset` is not the name of an `AnnotationType` member.
	"""
	# Validate with the same by-name lookup that is used to fetch the class.
	# A `str in AnnotationType` membership test raises TypeError on some
	# Python versions (and compares values, not names, on others). An
	# `assert` would also vanish under `python -O`.
	try:
		annotation_cls = AnnotationType[args.dataset].value
	except KeyError:
		raise ValueError(
			"AnnotationType is not known: \"{}\"".format(args.dataset)) from None

	# NOTE(review): nothing in this file configures logging, so these INFO
	# messages may be hidden at the default WARNING level. Confirm whether
	# `utils` sets up logging.
	logging.info("Loading \"{}\" annotations from \"{}\"".format(args.dataset, args.data))
	annot = annotation_cls(args.data, args.parts, args.feature_model)

	# The part options are only forwarded when the annotation object has no
	# `info`. Presumably `info` provides them otherwise; TODO confirm.
	kwargs = {}
	if annot.info is None:
		kwargs = dict(
			part_rescale_size=args.rescale_size,
			uniform_parts=args.uniform_parts,
			ratio=args.ratio,
		)

	data = annot.new_dataset(
		args.subset,

		center_cropped=not args.no_center_crop,
		crop_to_bb=args.crop_to_bb,
		crop_uniform=args.crop_uniform,

		parts_in_bb=args.parts_in_bb,

		rnd_select=args.rnd,
		seed=args.seed,

		**kwargs
	)

	logging.info("Loaded {} {} images".format(len(data), args.subset))
	return data


def _show_sample(data, i, args, crop_names):
	"""Plot image `i` of `data` with its parts and part crops, then call `plt.show()`."""
	im, parts, _ = data[i]

	_, axs = plt.subplots(2, 1, figsize=(16, 9))

	# Top panel: the full image with its parts (and bounding box) overlaid.
	axs[0].axis("off")
	axs[0].set_title("Visible Parts")
	axs[0].imshow(im)
	# The bounding box is only drawn when the image was not cropped to it
	# and the user did not disable boxes.
	if not args.crop_to_bb and not args.no_bboxes:
		data.plot_bounding_box(i, axs[0])
	parts.plot(im=im, ax=axs[0], ratio=data.ratio)

	# Bottom panel: only the regions revealed by the selected parts.
	axs[1].axis("off")
	axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
	axs[1].imshow(parts.reveal(im, ratio=data.ratio))

	part_crops = parts.visible_crops(im, ratio=data.ratio)
	if args.rnd:
		# `invert_selection()` is called for its effect on `parts`, so the
		# selected crops above must be collected before this point.
		parts.invert_selection()
		action_crops = parts.visible_crops(im, ratio=data.ratio)

	plot_crops(part_crops, "Selected parts", names=crop_names)

	if args.rnd:
		plot_crops(action_crops, "Actions", names=crop_names)

	plt.show()
	# Close every open figure, not only the current one. This covers the
	# subplot figure and any figures `plot_crops` may have opened. Otherwise
	# figures accumulate across images on non-interactive backends.
	plt.close("all")


def main(args):
	"""Load the requested dataset split and plot a range of its images.

	For each image in ``[args.start, args.start + args.n_images)``, clamped
	to the dataset size, this shows:
	- the image with its parts overlaid,
	- the regions revealed by the selected parts,
	- the selected part crops,
	- and, when `args.rnd` is set, the crops of the inverted selection
	  ("Actions").

	Raises:
		ValueError: if `args.dataset` is not a known `AnnotationType` name.
	"""
	data = _load_data(args)

	# These names do not depend on the individual image, so resolve them once.
	if data.uniform_parts:
		crop_names = None
	else:
		crop_names = list(data._annot.part_names.values())

	start = max(args.start, 0)
	# Clamp to the dataset size. This yields an empty range when `start` is
	# past the end or `args.n_images` is not positive.
	stop = min(start + args.n_images, len(data))

	for i in range(start, stop):
		_show_sample(data, i, args, crop_names)

# Script entry point: parse the command line with the shared parser from
# `utils` and hand the result to `main`.
cli_args = parser.parse_args()
main(cli_args)