PlotUtils.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. import matplotlib.pyplot as plt
  2. from sklearn.metrics import roc_curve, auc
  3. def plot_roc_curve(test_labels: list, test_df: list, title: str, figsize=(8, 8), savefile = None, show: bool = True):
  4. fpr, tpr, thresholds = roc_curve(test_labels, test_df)
  5. auc_score = auc(fpr, tpr)
  6. if not show:
  7. plt.ioff()
  8. plt.figure(figsize=figsize)
  9. plt.plot(fpr, tpr, lw=1)
  10. plt.fill_between(fpr, tpr, label=f"AUC = {auc_score:.4f}", alpha=0.5)
  11. plt.plot([0, 1], [0, 1], color="gray", linestyle="dotted")
  12. plt.xlim([0.0, 1.0])
  13. plt.ylim([0.0, 1.0])
  14. plt.xlabel("FPR")
  15. plt.ylabel("TPR")
  16. plt.title(f"{title}")
  17. plt.legend(loc="lower right")
  18. if savefile is not None:
  19. plt.savefig(f"{savefile}.png", bbox_inches="tight")
  20. plt.savefig(f"{savefile}.pdf", bbox_inches="tight")
  21. if show:
  22. plt.show()
  23. return fpr, tpr, thresholds, auc_score
  24. def get_percentiles(fpr, tpr, thresholds, percentiles=[0.9, 0.95, 0.98, 0.99], verbose = True):
  25. assert percentiles == sorted(percentiles)
  26. tnrs = []
  27. for percentile in percentiles:
  28. for i, tp in enumerate(tpr):
  29. if tp >= percentile:
  30. tnrs.append(1 - fpr[i]) # append tnr
  31. if verbose:
  32. print(f"{percentile} percentile : TPR = {tp:.4f}, FPR = {fpr[i]:.4f} <-> TNR = {(1 - fpr[i]):.4f} @ thresh {thresholds[i]}")
  33. break
  34. return tnrs