__init__.py 594 B

1234567891011121314151617181920212223
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. def plot_crops(crops, title, scatter_mid=False, names=None):
  4. n_crops = crops.shape[0]
  5. rows = int(np.ceil(np.sqrt(n_crops)))
  6. cols = int(np.ceil(n_crops / rows))
  7. fig, axs = plt.subplots(rows, cols, figsize=(16,9))
  8. fig.suptitle(title, fontsize=16)
  9. for i, crop in enumerate(crops):
  10. ax = axs[np.unravel_index(i, axs.shape)]
  11. if names is not None:
  12. ax.set_title(names[i])
  13. ax.imshow(crop)
  14. ax.axis("off")
  15. if scatter_mid:
  16. middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
  17. ax.scatter(middle_w, middle_h, marker="x")