#!/usr/bin/env python
# This file is meant to be executed directly as a script; abort immediately
# if anything tries to import it as a module.
if __name__ != '__main__':
	raise Exception("Do not import me!")

from nabirds import Dataset, NAB_Annotations
from nabirds.dataset import visible_part_locs, visible_crops, reveal_parts
import matplotlib.pyplot as plt

# Window of samples to display: `n_images` consecutive items of `data`,
# beginning at 0-based index `start`.
start, n_images = 2000, 5

# Load the annotations from the (machine-specific) dataset directory and
# report the shape of the label array as a quick sanity check.
annot = NAB_Annotations("/home/korsch1/korsch/datasets/birds/nabirds")
print(annot.labels.shape)

# Dataset built from the UUIDs in `annot.train_uuids`.
data = Dataset(annot.train_uuids, annot)

# Display `n_images` consecutive samples beginning at index `start`.
# The window is clamped to the dataset bounds, so a window that reaches past
# either end simply shows fewer images instead of indexing out of range.
first = max(start, 0)
last = min(start + n_images, len(data))

for i in range(first, last):
	im, parts, label = data[i]

	# `idxs` is used below as the per-point colour value and `xy` is unpacked
	# into the x / y coordinate arguments of `scatter`.
	idxs, xy = visible_part_locs(parts)

	print(label)
	print(idxs)

	# Figure 1: the image (top) and the output of `reveal_parts` (bottom),
	# each overlaid with the same part markers.
	fig1 = plt.figure(figsize=(16,9))
	ax = fig1.add_subplot(2,1,1)

	ax.imshow(im)
	ax.scatter(*xy, marker="x", c=idxs)

	ax = fig1.add_subplot(2,1,2)
	ax.imshow(reveal_parts(im, xy))
	ax.scatter(*xy, marker="x", c=idxs)

	# Figure 2: one subplot per crop yielded by `visible_crops`.
	# NOTE(review): the fixed 2x6 grid holds at most 12 crops; a 13th crop
	# would make `add_subplot` raise -- confirm the maximum number of parts.
	fig2 = plt.figure(figsize=(16,9))

	for j, crop in enumerate(visible_crops(im, parts, .5), 1):
		ax = fig2.add_subplot(2, 6, j)
		ax.imshow(crop)

		# Mark the centre of the crop. `imshow` maps array rows to y and
		# columns to x, so the x coordinate must come from the width
		# (shape[1]) and the y coordinate from the height (shape[0]);
		# using the height for both is only right for square crops.
		height, width = crop.shape[:2]
		ax.scatter(width / 2, height / 2, marker="x")

	# `show` blocks until the windows are closed; release both figures
	# afterwards so they do not accumulate over the iterations.
	plt.show()
	plt.close(fig1)
	plt.close(fig2)