example_nab.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  #!/usr/bin/env python
"""Visual sanity check for the NABirds dataset wrapper.

For a small window of training images this script prints the image's label
and its visible part indices, then shows two figures per image:

* the image with the visible part locations marked, above the output of
  ``reveal_parts`` for the same image with the same markers;
* one subplot per crop yielded by ``visible_crops``, with the crop centre
  marked.

Run it directly; importing it is refused on purpose.
"""
if __name__ != '__main__':
    # This file is a script, not a library: refuse to be imported (e.g. by a
    # test collector). ImportError is a subclass of Exception, so any handler
    # written for the previous generic Exception still catches it.
    raise ImportError("Do not import me!")

import argparse
import math

import matplotlib.pyplot as plt

from nabirds import Dataset, NAB_Annotations
from nabirds.dataset import visible_part_locs, visible_crops, reveal_parts

# Defaults reproduce the settings that used to be hard-coded in this script.
DEFAULT_ROOT = "/home/korsch1/korsch/datasets/birds/nabirds"
DEFAULT_START = 2000
DEFAULT_N_IMAGES = 5
# Third argument passed to visible_crops (value unchanged from the original;
# its exact meaning is defined by nabirds.dataset — presumably a crop-size ratio).
CROP_RATIO = .5
# Number of columns in the crop grid (the original used a fixed 2x6 grid).
N_CROP_COLS = 6


def _parse_args():
    """Return the command-line options; each one defaults to the old constant."""
    parser = argparse.ArgumentParser(
        description="Visualise NABirds part annotations for a few training images.")
    parser.add_argument("--root", default=DEFAULT_ROOT,
                        help="NABirds dataset root directory (default: %(default)s)")
    parser.add_argument("--start", type=int, default=DEFAULT_START,
                        help="index of the first training image to show (default: %(default)s)")
    parser.add_argument("--n-images", type=int, default=DEFAULT_N_IMAGES,
                        help="number of consecutive images to show (default: %(default)s)")
    return parser.parse_args()


def _show_part_locations(im, idxs, xy):
    """Build a 2x1 figure: the image and ``reveal_parts(im, xy)``, parts marked on both."""
    fig = plt.figure(figsize=(16, 9))

    ax = fig.add_subplot(2, 1, 1)
    ax.imshow(im)
    # xy is unpacked into scatter's (x, y) arguments; presumably a pair of
    # coordinate arrays from visible_part_locs — confirm against nabirds.
    ax.scatter(*xy, marker="x", c=idxs)

    ax = fig.add_subplot(2, 1, 2)
    ax.imshow(reveal_parts(im, xy))
    ax.scatter(*xy, marker="x", c=idxs)
    return fig


def _show_part_crops(im, parts):
    """Build a figure with one subplot per crop yielded by ``visible_crops``."""
    crops = list(visible_crops(im, parts, CROP_RATIO))
    # A fixed 2x6 grid makes add_subplot raise once there are more than 12
    # crops; keep that layout but grow the number of rows when needed.
    n_rows = max(2, math.ceil(len(crops) / N_CROP_COLS))

    fig = plt.figure(figsize=(16, 9))
    for j, crop in enumerate(crops, 1):
        ax = fig.add_subplot(n_rows, N_CROP_COLS, j)
        ax.imshow(crop)
        # scatter takes (x, y): x runs along the width (axis 1) and y along
        # the height (axis 0). Using the height for both only marks the true
        # centre of square crops.
        height, width = crop.shape[:2]
        ax.scatter(width / 2, height / 2, marker="x")
    return fig


def main():
    """Load the NABirds training split and visualise a window of its images."""
    args = _parse_args()

    annot = NAB_Annotations(args.root)
    print(annot.labels.shape)
    data = Dataset(annot.train_uuids, annot)

    # Iterate exactly the requested window, clipped to the dataset bounds,
    # instead of scanning from 0 and skipping/breaking by hand.
    first = max(args.start, 0)
    stop = min(args.start + args.n_images, len(data))
    for i in range(first, stop):
        im, parts, label = data[i]
        idxs, xy = visible_part_locs(parts)
        print(label)
        print(idxs)

        fig1 = _show_part_locations(im, idxs, xy)
        fig2 = _show_part_crops(im, parts)
        plt.show()
        plt.close(fig1)
        plt.close(fig2)


# The guard at the top guarantees this module is running as a script.
main()