#!/usr/bin/env python
"""Visually inspect NABirds part annotations for a few training images.

For every selected training sample the script prints its label and the
indices of its visible parts, then opens two figures:

* the image with the visible part locations marked (top) and the result of
  ``reveal_parts`` for those locations with the same markers (bottom);
* one subplot per crop returned by ``visible_crops``, each with a marker at
  the crop's centre.

This file is a script only; importing it raises an error.
"""
if __name__ != '__main__':
    # Everything below has side effects (dataset loading, GUI windows),
    # so refuse to run when imported.
    raise ImportError("Do not import me!")

import math

from nabirds import Dataset, NAB_Annotations
from nabirds.dataset import visible_part_locs, visible_crops, reveal_parts

import matplotlib.pyplot as plt

# NOTE(review): machine-specific dataset location; adjust for your setup.
DATA_ROOT = "/home/korsch1/korsch/datasets/birds/nabirds"

# Index of the first training sample to display and how many to display.
START = 2000
N_IMAGES = 5

# Size (in inches) of every figure.
FIGSIZE = (16, 9)

# The crop figure uses CROP_ROWS rows and at least CROP_MIN_COLS columns.
# More columns are added when there are more crops than grid cells, instead
# of letting ``add_subplot`` fail with an out-of-range subplot index.
CROP_ROWS = 2
CROP_MIN_COLS = 6

# Third argument passed to ``visible_crops``.
# NOTE(review): presumably a crop size relative to the image -- confirm.
CROP_ARG = .5


def _show_part_locations(im, xy, idxs):
    """Open a figure with part markers on the raw and the revealed image."""
    fig = plt.figure(figsize=FIGSIZE)

    ax = fig.add_subplot(2, 1, 1)
    ax.imshow(im)
    ax.scatter(*xy, marker="x", c=idxs)

    ax = fig.add_subplot(2, 1, 2)
    ax.imshow(reveal_parts(im, xy))
    ax.scatter(*xy, marker="x", c=idxs)

    return fig


def _show_crops(im, parts):
    """Open a figure with one subplot per visible-part crop, centre marked."""
    # Materialise the crops first so the subplot grid can be sized to fit.
    crops = list(visible_crops(im, parts, CROP_ARG))
    n_cols = max(CROP_MIN_COLS, math.ceil(len(crops) / CROP_ROWS))

    fig = plt.figure(figsize=FIGSIZE)
    for j, crop in enumerate(crops, 1):
        ax = fig.add_subplot(CROP_ROWS, n_cols, j)
        ax.imshow(crop)

        # Mark the geometric centre: x is the column axis (width) and y the
        # row axis (height), so the marker is also right for non-square crops.
        height, width = crop.shape[:2]
        ax.scatter(width / 2, height / 2, marker="x")

    return fig


def main():
    """Load the annotations and show samples START .. START + N_IMAGES - 1."""
    annot = NAB_Annotations(DATA_ROOT)
    print(annot.labels.shape)

    data = Dataset(annot.train_uuids, annot)

    # Iterate only over the requested window instead of stepping through
    # (and skipping) every sample before START; clamp to the dataset size.
    stop = min(START + N_IMAGES, len(data))
    for i in range(START, stop):
        im, parts, label = data[i]
        idxs, xy = visible_part_locs(parts)

        print(label)
        print(idxs)

        fig1 = _show_part_locations(im, xy, idxs)
        fig2 = _show_crops(im, parts)

        # Blocks until the windows are closed, then frees the figures.
        plt.show()
        plt.close(fig1)
        plt.close(fig2)


if __name__ == '__main__':
    main()