__init__.py 676 B

12345678910111213141516171819202122232425
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. def plot_crops(crops, title, scatter_mid=False, names=None):
  4. n_crops = len(crops)
  5. if n_crops == 0: return
  6. rows = int(np.ceil(np.sqrt(n_crops)))
  7. cols = int(np.ceil(n_crops / rows))
  8. fig, axs = plt.subplots(rows, cols, figsize=(16,9))
  9. fig.suptitle(title, fontsize=16)
  10. [axs[np.unravel_index(i, axs.shape)].axis("off") for i in range(cols*rows)]
  11. for i, crop in enumerate(crops):
  12. ax = axs[np.unravel_index(i, axs.shape)]
  13. if names is not None:
  14. ax.set_title(names[i])
  15. ax.imshow(crop)
  16. if scatter_mid:
  17. middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
  18. ax.scatter(middle_w, middle_h, marker="x")