pr_curve.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. from __future__ import annotations
  2. import numbers
  3. from collections.abc import Iterable
  4. from typing import TYPE_CHECKING, TypeVar
  5. import wandb
  6. from wandb import util
  7. from wandb.plot.custom_chart import plot_table
  8. from wandb.plot.utils import test_missing, test_types
  9. if TYPE_CHECKING:
  10. from wandb.plot.custom_chart import CustomChart
# Generic type variable: ties the element type of `y_true` to the element
# type of `classes_to_plot` in the `pr_curve` signature below.
T = TypeVar("T")
  12. def pr_curve(
  13. y_true: Iterable[T] | None = None,
  14. y_probas: Iterable[numbers.Number] | None = None,
  15. labels: list[str] | None = None,
  16. classes_to_plot: list[T] | None = None,
  17. interp_size: int = 21,
  18. title: str = "Precision-Recall Curve",
  19. split_table: bool = False,
  20. ) -> CustomChart:
  21. """Constructs a Precision-Recall (PR) curve.
  22. The Precision-Recall curve is particularly useful for evaluating classifiers
  23. on imbalanced datasets. A high area under the PR curve signifies both high
  24. precision (a low false positive rate) and high recall (a low false negative
  25. rate). The curve provides insights into the balance between false positives
  26. and false negatives at various threshold levels, aiding in the assessment of
  27. a model's performance.
  28. Args:
  29. y_true: True binary labels. The shape should be (`num_samples`,).
  30. y_probas: Predicted scores or probabilities for each class.
  31. These can be probability estimates, confidence scores, or non-thresholded
  32. decision values. The shape should be (`num_samples`, `num_classes`).
  33. labels: Optional list of class names to replace
  34. numeric values in `y_true` for easier plot interpretation.
  35. For example, `labels = ['dog', 'cat', 'owl']` will replace 0 with
  36. 'dog', 1 with 'cat', and 2 with 'owl' in the plot. If not provided,
  37. numeric values from `y_true` will be used.
  38. classes_to_plot: Optional list of unique class values from
  39. y_true to be included in the plot. If not specified, all unique
  40. classes in y_true will be plotted.
  41. interp_size: Number of points to interpolate recall values. The
  42. recall values will be fixed to `interp_size` uniformly distributed
  43. points in the range [0, 1], and the precision will be interpolated
  44. accordingly.
  45. title: Title of the plot. Defaults to "Precision-Recall Curve".
  46. split_table: Whether the table should be split into a separate section
  47. in the W&B UI. If `True`, the table will be displayed in a section named
  48. "Custom Chart Tables". Default is `False`.
  49. Returns:
  50. CustomChart: A custom chart object that can be logged to W&B. To log the
  51. chart, pass it to `wandb.log()`.
  52. Raises:
  53. wandb.Error: If NumPy, pandas, or scikit-learn is not installed.
  54. Example:
  55. ```python
  56. import wandb
  57. # Example for spam detection (binary classification)
  58. y_true = [0, 1, 1, 0, 1] # 0 = not spam, 1 = spam
  59. y_probas = [
  60. [0.9, 0.1], # Predicted probabilities for the first sample (not spam)
  61. [0.2, 0.8], # Second sample (spam), and so on
  62. [0.1, 0.9],
  63. [0.8, 0.2],
  64. [0.3, 0.7],
  65. ]
  66. labels = ["not spam", "spam"] # Optional class names for readability
  67. with wandb.init(project="spam-detection") as run:
  68. pr_curve = wandb.plot.pr_curve(
  69. y_true=y_true,
  70. y_probas=y_probas,
  71. labels=labels,
  72. title="Precision-Recall Curve for Spam Detection",
  73. )
  74. run.log({"pr-curve": pr_curve})
  75. ```
  76. """
  77. np = util.get_module(
  78. "numpy",
  79. required="roc requires the numpy library, install with `pip install numpy`",
  80. )
  81. pd = util.get_module(
  82. "pandas",
  83. required="roc requires the pandas library, install with `pip install pandas`",
  84. )
  85. sklearn_metrics = util.get_module(
  86. "sklearn.metrics",
  87. "roc requires the scikit library, install with `pip install scikit-learn`",
  88. )
  89. sklearn_utils = util.get_module(
  90. "sklearn.utils",
  91. "roc requires the scikit library, install with `pip install scikit-learn`",
  92. )
  93. def _step(x):
  94. y = np.array(x)
  95. for i in range(1, len(y)):
  96. y[i] = max(y[i], y[i - 1])
  97. return y
  98. y_true = np.array(y_true)
  99. y_probas = np.array(y_probas)
  100. if not test_missing(y_true=y_true, y_probas=y_probas):
  101. return
  102. if not test_types(y_true=y_true, y_probas=y_probas):
  103. return
  104. classes = np.unique(y_true)
  105. if classes_to_plot is None:
  106. classes_to_plot = classes
  107. precision = {}
  108. interp_recall = np.linspace(0, 1, interp_size)[::-1]
  109. indices_to_plot = np.where(np.isin(classes, classes_to_plot))[0]
  110. for i in indices_to_plot:
  111. if labels is not None and (
  112. isinstance(classes[i], int) or isinstance(classes[0], np.integer)
  113. ):
  114. class_label = labels[classes[i]]
  115. else:
  116. class_label = classes[i]
  117. cur_precision, cur_recall, _ = sklearn_metrics.precision_recall_curve(
  118. y_true, y_probas[:, i], pos_label=classes[i]
  119. )
  120. # smooth the precision (monotonically increasing)
  121. cur_precision = _step(cur_precision)
  122. # reverse order so that recall in ascending
  123. cur_precision = cur_precision[::-1]
  124. cur_recall = cur_recall[::-1]
  125. indices = np.searchsorted(cur_recall, interp_recall, side="left")
  126. precision[class_label] = cur_precision[indices]
  127. df = pd.DataFrame(
  128. {
  129. "class": np.hstack([[k] * len(v) for k, v in precision.items()]),
  130. "precision": np.hstack(list(precision.values())),
  131. "recall": np.tile(interp_recall, len(precision)),
  132. }
  133. ).round(3)
  134. if len(df) > wandb.Table.MAX_ROWS:
  135. wandb.termwarn(
  136. f"Table has a limit of {wandb.Table.MAX_ROWS} rows. Resampling to fit."
  137. )
  138. # different sampling could be applied, possibly to ensure endpoints are kept
  139. df = sklearn_utils.resample(
  140. df,
  141. replace=False,
  142. n_samples=wandb.Table.MAX_ROWS,
  143. random_state=42,
  144. stratify=df["class"],
  145. ).sort_values(["precision", "recall", "class"])
  146. return plot_table(
  147. data_table=wandb.Table(dataframe=df),
  148. vega_spec_name="wandb/area-under-curve/v0",
  149. fields={
  150. "x": "recall",
  151. "y": "precision",
  152. "class": "class",
  153. },
  154. string_fields={
  155. "title": title,
  156. "x-axis-title": "Recall",
  157. "y-axis-title": "Precision",
  158. },
  159. split_table=split_table,
  160. )