PlotUtils.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # Copyright (c) 2023 Felix Kleinsteuber and Computer Vision Group, Friedrich Schiller University Jena
  2. # This file defines helper functions for plotting.
  3. import matplotlib.pyplot as plt
  4. from sklearn.metrics import roc_curve, auc
  5. def plot_roc_curve(test_labels: list, test_df: list, title: str, figsize=(8, 8), savefile = None, show: bool = True):
  6. """Plots the roc curve of a classifier.
  7. Args:
  8. test_labels (list): Labels for the test examples.
  9. test_df (list): Decision function values for the test examples.
  10. title (str): Title of the plot.
  11. figsize (tuple, optional): Size of the plot. Defaults to (8, 8).
  12. savefile (_type_, optional): Output file without ending. Will be saved as pdf and png. If None, the plot is not saved. Defaults to None.
  13. show (bool, optional): If False, do not show the plot. Defaults to True.
  14. Returns:
  15. fpr (list of float), tpr (list of float), thresholds (list of float), auc_score (float): Points on roc curves, their thresholds, and the area under ROC curve.
  16. """
  17. fpr, tpr, thresholds = roc_curve(test_labels, test_df)
  18. auc_score = auc(fpr, tpr)
  19. if not show:
  20. plt.ioff()
  21. plt.figure(figsize=figsize)
  22. plt.plot(fpr, tpr, lw=1)
  23. plt.fill_between(fpr, tpr, label=f"AUC = {auc_score:.4f}", alpha=0.5)
  24. plt.plot([0, 1], [0, 1], color="gray", linestyle="dotted")
  25. plt.xlim([0.0, 1.0])
  26. plt.ylim([0.0, 1.0])
  27. plt.xlabel("FPR")
  28. plt.ylabel("TPR")
  29. plt.title(f"{title}")
  30. plt.legend(loc="lower right")
  31. if savefile is not None:
  32. plt.savefig(f"{savefile}.png", bbox_inches="tight")
  33. plt.savefig(f"{savefile}.pdf", bbox_inches="tight")
  34. if show:
  35. plt.show()
  36. return fpr, tpr, thresholds, auc_score
  37. def get_percentiles(fpr, tpr, thresholds, percentiles=[0.9, 0.95, 0.98, 0.99], verbose = True):
  38. """Returns the maximum possible TNR (elimination rate) for given minimum TPR.
  39. Args:
  40. fpr (list of float): FPR values from ROC curve.
  41. tpr (list of float): TPR values from ROC curve.
  42. thresholds (list of float): Thresholds from ROC curve.
  43. percentiles (list of float, optional): List of minimum TPR values to use as input. Defaults to [0.9, 0.95, 0.98, 0.99].
  44. verbose (bool, optional): If True, print the results. Defaults to True.
  45. Returns:
  46. list of float: TNR values aka elimination rates.
  47. """
  48. assert percentiles == sorted(percentiles)
  49. tnrs = []
  50. for percentile in percentiles:
  51. for i, tp in enumerate(tpr):
  52. if tp >= percentile:
  53. tnrs.append(1 - fpr[i]) # append tnr
  54. if verbose:
  55. print(f"{percentile} percentile : TPR = {tp:.4f}, FPR = {fpr[i]:.4f} <-> TNR = {(1 - fpr[i]):.4f} @ thresh {thresholds[i]}")
  56. break
  57. return tnrs