PlotUtils.py 785 B

123456789101112131415161718192021
  1. import matplotlib.pyplot as plt
  2. from sklearn.metrics import roc_curve, auc
  3. def plot_roc_curve(test_labels: list, test_df: list, title: str, figsize=(8, 8), savefile = None):
  4. fpr, tpr, thresholds = roc_curve(test_labels, test_df)
  5. auc_score = auc(fpr, tpr)
  6. plt.figure(figsize=figsize)
  7. plt.plot(fpr, tpr, lw=1, label="ROC Curve")
  8. plt.plot([0, 1], [0, 1], color="lime", linestyle="--")
  9. plt.xlim([0.0, 1.0])
  10. plt.ylim([0.0, 1.05])
  11. plt.xlabel("FPR")
  12. plt.ylabel("TPR")
  13. plt.title(f"{title} (AUC = {auc_score})")
  14. plt.legend(loc="lower right")
  15. if savefile is not None:
  16. plt.savefig(f"{savefile}.png", bbox_inches="tight")
  17. plt.savefig(f"{savefile}.pdf", bbox_inches="tight")
  18. plt.show()
  19. return fpr, tpr, thresholds, auc_score