123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
#!/usr/bin/env python
"""Visually inspect NABirds part annotations for a few training samples.

For each selected training sample two figures are shown:
  1. the image with its visible part locations marked (top), and the output
     of ``reveal_parts`` for the same locations with the same markers (bottom);
  2. one subplot per crop yielded by ``visible_crops``, with the crop centre
     marked.
The sample's label and visible part indices are printed to stdout.
"""
if __name__ != '__main__':
    raise Exception("Do not import me!")

import argparse
import math

from nabirds import Dataset, NAB_Annotations
from nabirds.dataset import visible_part_locs, visible_crops, reveal_parts
import matplotlib.pyplot as plt

DEFAULT_ROOT = "/home/korsch1/korsch/datasets/birds/nabirds"
DEFAULT_START = 2000
DEFAULT_N_IMAGES = 5
# Third argument forwarded to visible_crops; presumably a crop-size ratio
# relative to the image -- TODO confirm against nabirds.dataset.visible_crops.
CROP_RATIO = .5
# Minimum number of subplot columns for the crop figure (2 rows x 6 columns).
MIN_CROP_COLS = 6


def _parse_args():
    """Parse the dataset root and the window of training samples to show."""
    parser = argparse.ArgumentParser(
        description="Visually inspect NABirds part annotations.")
    parser.add_argument(
        "root", nargs="?", default=DEFAULT_ROOT,
        help="NABirds dataset root directory (default: %(default)s)")
    parser.add_argument(
        "--start", type=int, default=DEFAULT_START,
        help="index of the first training sample to show (default: %(default)s)")
    parser.add_argument(
        "--n_images", type=int, default=DEFAULT_N_IMAGES,
        help="number of consecutive samples to show (default: %(default)s)")
    return parser.parse_args()


def _show_sample(im, parts, xy, idxs):
    """Draw both inspection figures for one sample, block in plt.show(), then close them.

    Args:
        im: the sample image, as returned by ``Dataset.__getitem__``.
        parts: the sample's part annotations, forwarded to ``visible_crops``.
        xy: visible part coordinates from ``visible_part_locs``; unpacked
            into ``scatter(*xy)``, so presumably (xs, ys) -- TODO confirm.
        idxs: indices of the visible parts, used to colour the markers.
    """
    fig1 = plt.figure(figsize=(16, 9))
    ax = fig1.add_subplot(2, 1, 1)
    ax.imshow(im)
    ax.scatter(*xy, marker="x", c=idxs)
    ax = fig1.add_subplot(2, 1, 2)
    ax.imshow(reveal_parts(im, xy))
    ax.scatter(*xy, marker="x", c=idxs)

    fig2 = plt.figure(figsize=(16, 9))
    crops = list(visible_crops(im, parts, CROP_RATIO))
    # A fixed 2x6 grid makes add_subplot raise once there are more than 12
    # crops, so widen the grid when necessary.
    n_cols = max(MIN_CROP_COLS, math.ceil(len(crops) / 2))
    for j, crop in enumerate(crops, 1):
        ax = fig2.add_subplot(2, n_cols, j)
        ax.imshow(crop)
        # x comes from the width (axis 1) and y from the height (axis 0), so
        # the centre is correct for non-square crops too.
        height, width = crop.shape[:2]
        ax.scatter(width / 2, height / 2, marker="x")

    try:
        plt.show()
    finally:
        # Free the figures even if show() is interrupted.
        plt.close(fig1)
        plt.close(fig2)


def main():
    """Show ``n_images`` consecutive training samples beginning at ``start``."""
    args = _parse_args()

    annot = NAB_Annotations(args.root)
    print(annot.labels.shape)
    data = Dataset(annot.train_uuids, annot)

    # Jump straight to the requested window instead of scanning from index 0;
    # clamp both ends so the window never leaves the dataset.
    first = max(args.start, 0)
    stop = min(args.start + args.n_images, len(data))
    for i in range(first, stop):
        im, parts, label = data[i]
        idxs, xy = visible_part_locs(parts)
        print(label)
        print(idxs)
        _show_sample(im, parts, xy, idxs)


if __name__ == '__main__':
    main()
|