#!/usr/bin/env python
# Interactive visual check of CUB part annotations.
#
# For a window of training samples this script prints the label and the
# indices of the visible parts. It then opens two figures per sample:
#   1. the full image with the visible part locations marked, and
#   2. a grid with the crops returned by ``visible_crops``, each with its
#      centre marked.
#
# Script only: importing this module is an error. The guard comes before the
# (heavy) imports on purpose.
if __name__ != '__main__':
    raise Exception("Do not import me!")

import argparse
import math
from itertools import islice

import matplotlib.pyplot as plt

from nabirds import Dataset, CUB_Annotations
from nabirds.dataset import visible_part_locs, visible_crops

DEFAULT_ROOT = "/home/korsch1/korsch/datasets/birds/cub200_11"
GRID_COLS = 5  # crops per row in the crop figure
# Third argument forwarded to visible_crops. NOTE(review): presumably the
# crop-size ratio; confirm against nabirds.dataset.visible_crops.
CROP_ARG = .5


def _parse_args():
    """Parse command-line options. The defaults reproduce the original hard-coded run."""
    parser = argparse.ArgumentParser(
        description="Show part locations and part crops for CUB training images.")
    parser.add_argument("--root", default=DEFAULT_ROOT,
                        help="root folder of the CUB annotations (default: %(default)s)")
    parser.add_argument("--start", type=int, default=16,
                        help="1-based index of the first sample to show (default: %(default)s)")
    parser.add_argument("--stop", type=int, default=20,
                        help="1-based index of the last sample to show, inclusive (default: %(default)s)")
    args = parser.parse_args()
    if args.start < 1:
        parser.error("--start must be >= 1")
    if args.stop < args.start:
        parser.error("--stop must be >= --start")
    return args


def _plot_part_locations(im, xs, ys, idxs):
    """Return a new figure showing ``im`` with an 'x' at every visible part, colored by part index."""
    fig = plt.figure(figsize=(16, 9))
    ax = fig.add_subplot(111)
    ax.imshow(im)
    ax.scatter(xs, ys, marker="x", c=idxs)
    return fig


def _plot_part_crops(im, parts):
    """Return a new figure with one subplot per crop from ``visible_crops``, each centre marked."""
    fig = plt.figure(figsize=(16, 9))
    crops = list(visible_crops(im, parts, CROP_ARG))

    # Size the grid from the number of parts so the layout is stable across
    # samples (3x5 for CUB's 15 parts). Never size it smaller than the number
    # of crops actually returned, or add_subplot would fail.
    n_parts = parts.shape[0]
    n_rows = max(1, math.ceil(max(n_parts, len(crops)) / GRID_COLS))

    for j, crop in enumerate(crops, 1):
        ax = fig.add_subplot(n_rows, GRID_COLS, j)
        ax.imshow(crop)
        # scatter takes (x, y) = (column, row). Use width and height separately
        # so the marker is centred on non-square crops too.
        h, w = crop.shape[:2]
        ax.scatter(w / 2, h / 2, marker="x")
    return fig


def main():
    """Load the CUB training split and show samples ``--start`` .. ``--stop`` one by one."""
    args = _parse_args()

    annot = CUB_Annotations(root=args.root)
    print(annot.labels.shape)
    data = Dataset(annot.train_uuids, annot)

    # islice still advances through (and thus loads) the skipped samples,
    # just like the original continue-based skipping did.
    for im, parts, label in islice(data, args.start - 1, args.stop):
        idxs, (xs, ys) = visible_part_locs(parts)

        print(label)
        print(idxs)

        fig1 = _plot_part_locations(im, xs, ys, idxs)
        fig2 = _plot_part_crops(im, parts)

        plt.show()  # blocks until the windows are closed (interactive backends)
        plt.close(fig1)
        plt.close(fig2)


main()