roc_curve.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. from __future__ import annotations
  2. import numbers
  3. from collections.abc import Sequence
  4. from typing import TYPE_CHECKING
  5. import wandb
  6. from wandb import util
  7. from wandb.plot.custom_chart import plot_table
  8. from wandb.plot.utils import test_missing, test_types
  9. if TYPE_CHECKING:
  10. from wandb.plot.custom_chart import CustomChart
  def roc_curve(
      y_true: Sequence[numbers.Number],
      y_probas: Sequence[Sequence[float]] | None = None,
      labels: list[str] | None = None,
      classes_to_plot: list[numbers.Number] | None = None,
      title: str = "ROC Curve",
      split_table: bool = False,
  ) -> CustomChart | None:
      """Constructs Receiver Operating Characteristic (ROC) curve chart.

      Args:
          y_true: The true class labels (ground truth)
              for the target variable. Shape should be (num_samples,).
          y_probas: The predicted probabilities or
              decision scores for each class. Shape should be
              (num_samples, num_classes).
          labels: Human-readable labels corresponding to the class
              indices in `y_true`. For example, if `labels=['dog', 'cat']`,
              class 0 will be displayed as 'dog' and class 1 as 'cat' in the
              plot. If None, the raw class indices from `y_true` will be used.
              Default is None.
          classes_to_plot: A subset of unique class labels
              to include in the ROC curve. If None, all classes in `y_true`
              will be plotted. Default is None.
          title: Title of the ROC curve plot. Default is "ROC Curve".
          split_table: Whether the table should be split into a separate
              section in the W&B UI. If `True`, the table will be displayed in
              a section named "Custom Chart Tables". Default is `False`.

      Returns:
          CustomChart: A custom chart object that can be logged to W&B. To log
              the chart, pass it to `wandb.log()`. `None` is returned instead
              when the `test_missing` or `test_types` validation helper
              rejects `y_true` / `y_probas`.

      Raises:
          wandb.Error: If numpy, pandas, or scikit-learn are not found.
          ValueError: If none of the classes in `classes_to_plot` occur in
              `y_true`, so there is no curve to draw.

      Example:
      ```python
      import numpy as np
      import wandb

      # Simulate a medical diagnosis classification problem with three diseases
      n_samples = 200
      n_classes = 3

      # True labels: assign "Diabetes", "Hypertension", or "Heart Disease" to
      # each sample
      disease_labels = ["Diabetes", "Hypertension", "Heart Disease"]
      # 0: Diabetes, 1: Hypertension, 2: Heart Disease
      y_true = np.random.choice([0, 1, 2], size=n_samples)

      # Predicted probabilities: simulate predictions, ensuring they sum to 1
      # for each sample
      y_probas = np.random.dirichlet(np.ones(n_classes), size=n_samples)

      # Specify classes to plot (plotting all three diseases)
      classes_to_plot = [0, 1, 2]

      # Initialize a W&B run and log a ROC curve plot for disease classification
      with wandb.init(project="medical_diagnosis") as run:
          roc_plot = wandb.plot.roc_curve(
              y_true=y_true,
              y_probas=y_probas,
              labels=disease_labels,
              classes_to_plot=classes_to_plot,
              title="ROC Curve for Disease Classification",
          )
          run.log({"roc-curve": roc_plot})
      ```
      """
      # Heavy optional dependencies are resolved lazily through wandb's helper
      # instead of being imported at module load time.
      np = util.get_module(
          "numpy",
          required="roc requires the numpy library, install with `pip install numpy`",
      )
      pd = util.get_module(
          "pandas",
          required="roc requires the pandas library, install with `pip install pandas`",
      )
      sklearn_metrics = util.get_module(
          "sklearn.metrics",
          "roc requires the scikit library, install with `pip install scikit-learn`",
      )
      sklearn_utils = util.get_module(
          "sklearn.utils",
          "roc requires the scikit library, install with `pip install scikit-learn`",
      )

      y_true = np.array(y_true)
      y_probas = np.array(y_probas)

      # Bail out (returning None) when the shared validation helpers reject
      # the inputs.
      if not test_missing(y_true=y_true, y_probas=y_probas):
          return None
      if not test_types(y_true=y_true, y_probas=y_probas):
          return None

      classes = np.unique(y_true)
      if classes_to_plot is None:
          classes_to_plot = classes

      indices_to_plot = np.where(np.isin(classes, classes_to_plot))[0]
      if indices_to_plot.size == 0:
          # Without this guard the np.hstack calls below fail with an opaque
          # "need at least one array to concatenate" ValueError.
          raise ValueError(
              "roc_curve: none of the classes in `classes_to_plot` occur in `y_true`"
          )

      fpr = {}
      tpr = {}
      for i in indices_to_plot:
          current_class = classes[i]
          # Only integer-valued classes can be used as an index into `labels`.
          if labels is not None and isinstance(current_class, (int, np.integer)):
              class_label = labels[current_class]
          else:
              class_label = current_class

          # NOTE(review): the score column is chosen by the class's position
          # among the unique values of `y_true`, which presumably assumes every
          # class that has a column in `y_probas` also appears in `y_true` --
          # confirm against callers.
          fpr[class_label], tpr[class_label], _ = sklearn_metrics.roc_curve(
              y_true, y_probas[..., i], pos_label=current_class
          )

      # Flatten the per-class curves into one long-format table:
      # one row per (class, fpr, tpr) point.
      df = pd.DataFrame(
          {
              "class": np.hstack([[k] * len(v) for k, v in fpr.items()]),
              "fpr": np.hstack(list(fpr.values())),
              "tpr": np.hstack(list(tpr.values())),
          }
      ).round(3)

      if len(df) > wandb.Table.MAX_ROWS:
          wandb.termwarn(
              f"wandb uses only {wandb.Table.MAX_ROWS} data points to create the plots."
          )
          # different sampling could be applied, possibly to ensure endpoints are kept
          df = sklearn_utils.resample(
              df,
              replace=False,
              n_samples=wandb.Table.MAX_ROWS,
              random_state=42,
              stratify=df["class"],
          ).sort_values(["fpr", "tpr", "class"])

      return plot_table(
          data_table=wandb.Table(dataframe=df),
          vega_spec_name="wandb/area-under-curve/v0",
          fields={
              "x": "fpr",
              "y": "tpr",
              "class": "class",
          },
          string_fields={
              "title": title,
              "x-axis-title": "False positive rate",
              "y-axis-title": "True positive rate",
          },
          split_table=split_table,
      )