plot.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Generator, Sequence
  15. from itertools import product
  16. from math import ceil, floor, sqrt
  17. from typing import Any, List, Optional, Union, no_type_check
  18. import numpy as np
  19. import torch
  20. from torch import Tensor
  21. from torchmetrics.utilities.imports import _LATEX_AVAILABLE, _MATPLOTLIB_AVAILABLE, _SCIENCEPLOT_AVAILABLE
# The aliases below are bound to real matplotlib types when `_MATPLOTLIB_AVAILABLE` is truthy and to
# plain `object` placeholders otherwise, so the names exist in both cases.
if _MATPLOTLIB_AVAILABLE:
    import matplotlib
    import matplotlib.axes
    import matplotlib.pyplot as plt

    # Return type of the plot functions: a figure plus either a single Axes or an array of Axes.
    _PLOT_OUT_TYPE = tuple[plt.Figure, Union[matplotlib.axes.Axes, np.ndarray]]
    _AX_TYPE = matplotlib.axes.Axes
    # NOTE(review): `matplotlib.colors` is not imported explicitly in this block; presumably it is loaded
    # as a side effect of the imports above -- confirm.
    _CMAP_TYPE = Union[matplotlib.colors.Colormap, str]

    # Applied in this module as a decorator, `@style_change(_style)`, on the plot functions.
    style_change = plt.style.context
else:
    _PLOT_OUT_TYPE = tuple[object, object]  # type: ignore[misc]
    _AX_TYPE = object
    _CMAP_TYPE = object  # type: ignore[misc]

    from contextlib import contextmanager

    @contextmanager
    def style_change(*args: Any, **kwargs: Any) -> Generator:
        """Stand-in for ``plt.style.context`` used when matplotlib is not installed.

        Accepts and ignores any arguments, and does nothing on entry or exit, so that
        ``@style_change(...)`` can still be applied to the plot functions.
        """
        yield
  39. if _SCIENCEPLOT_AVAILABLE:
  40. import scienceplots # noqa: F401
  41. _style = ["science", "no-latex"]
  42. _style = ["science"] if _SCIENCEPLOT_AVAILABLE and _LATEX_AVAILABLE else ["default"]
  43. def _error_on_missing_matplotlib() -> None:
  44. """Raise error if matplotlib is not installed."""
  45. if not _MATPLOTLIB_AVAILABLE:
  46. raise ModuleNotFoundError(
  47. "Plot function expects `matplotlib` to be installed. Please install with `pip install matplotlib`"
  48. )
  49. @style_change(_style)
  50. def plot_single_or_multi_val(
  51. val: Union[Tensor, Sequence[Tensor], dict[str, Tensor], Sequence[dict[str, Tensor]]],
  52. ax: Optional[_AX_TYPE] = None, # type: ignore[valid-type]
  53. higher_is_better: Optional[bool] = None,
  54. lower_bound: Optional[float] = None,
  55. upper_bound: Optional[float] = None,
  56. legend_name: Optional[str] = None,
  57. name: Optional[str] = None,
  58. ) -> _PLOT_OUT_TYPE:
  59. """Plot a single metric value or multiple, including bounds of value if existing.
  60. Args:
  61. val: A single tensor with one or multiple values (multiclass/label/output format) or a list of such tensors.
  62. If a list is provided the values are interpreted as a time series of evolving values.
  63. ax: Axis from a figure.
  64. higher_is_better: Indicates if a label indicating where the optimal value it should be added to the figure
  65. lower_bound: lower value that the metric can take
  66. upper_bound: upper value that the metric can take
  67. legend_name: for class based metrics specify the legend prefix e.g. Class or Label to use when multiple values
  68. are provided
  69. name: Name of the metric to use for the y-axis label
  70. Returns:
  71. A tuple consisting of the figure and respective ax objects of the generated figure
  72. Raises:
  73. ModuleNotFoundError:
  74. If `matplotlib` is not installed
  75. """
  76. _error_on_missing_matplotlib()
  77. fig, ax = plt.subplots() if ax is None else (None, ax)
  78. ax.get_xaxis().set_visible(False)
  79. if isinstance(val, Tensor):
  80. if val.numel() == 1:
  81. ax.plot([val.detach().cpu()], marker="o", markersize=10)
  82. else:
  83. for i, v in enumerate(val):
  84. label = f"{legend_name} {i}" if legend_name else f"{i}"
  85. ax.plot(i, v.detach().cpu(), marker="o", markersize=10, linestyle="None", label=label)
  86. elif isinstance(val, dict):
  87. for i, (k, v) in enumerate(val.items()):
  88. if v.numel() != 1:
  89. ax.plot(v.detach().cpu(), marker="o", markersize=10, linestyle="-", label=k)
  90. ax.get_xaxis().set_visible(True)
  91. ax.set_xlabel("Step")
  92. ax.set_xticks(torch.arange(len(v)))
  93. else:
  94. ax.plot(i, v.detach().cpu(), marker="o", markersize=10, label=k)
  95. elif isinstance(val, Sequence):
  96. n_steps = len(val)
  97. if isinstance(val[0], dict):
  98. val = {k: torch.stack([val[i][k] for i in range(n_steps)]) for k in val[0]} # type: ignore
  99. for k, v in val.items():
  100. ax.plot(v.detach().cpu(), marker="o", markersize=10, linestyle="-", label=k)
  101. else:
  102. val = torch.stack(val, 0) # type: ignore
  103. multi_series = val.ndim != 1
  104. val = val.T if multi_series else val.unsqueeze(0)
  105. for i, v in enumerate(val):
  106. label = (f"{legend_name} {i}" if legend_name else f"{i}") if multi_series else ""
  107. ax.plot(v.detach().cpu(), marker="o", markersize=10, linestyle="-", label=label)
  108. ax.get_xaxis().set_visible(True)
  109. ax.set_xlabel("Step")
  110. ax.set_xticks(torch.arange(n_steps))
  111. else:
  112. raise ValueError("Got unknown format for argument `val`.")
  113. handles, labels = ax.get_legend_handles_labels()
  114. if handles and labels:
  115. ax.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.15), ncol=3, fancybox=True, shadow=True)
  116. ylim = ax.get_ylim()
  117. if lower_bound is not None and upper_bound is not None:
  118. factor = 0.1 * (upper_bound - lower_bound)
  119. else:
  120. factor = 0.1 * (ylim[1] - ylim[0])
  121. ax.set_ylim(
  122. bottom=lower_bound - factor if lower_bound is not None else ylim[0] - factor,
  123. top=upper_bound + factor if upper_bound is not None else ylim[1] + factor,
  124. )
  125. ax.grid(True)
  126. ax.set_ylabel(name if name is not None else None)
  127. xlim = ax.get_xlim()
  128. factor = 0.1 * (xlim[1] - xlim[0])
  129. y_lines = []
  130. if lower_bound is not None:
  131. y_lines.append(lower_bound)
  132. if upper_bound is not None:
  133. y_lines.append(upper_bound)
  134. ax.hlines(y_lines, xlim[0], xlim[1], linestyles="dashed", colors="k")
  135. if higher_is_better is not None:
  136. if lower_bound is not None and not higher_is_better:
  137. ax.set_xlim(xlim[0] - factor, xlim[1])
  138. ax.text(
  139. xlim[0], lower_bound, s="Optimal \n value", horizontalalignment="center", verticalalignment="center"
  140. )
  141. if upper_bound is not None and higher_is_better:
  142. ax.set_xlim(xlim[0] - factor, xlim[1])
  143. ax.text(
  144. xlim[0], upper_bound, s="Optimal \n value", horizontalalignment="center", verticalalignment="center"
  145. )
  146. return fig, ax
  147. def _get_col_row_split(n: int) -> tuple[int, int]:
  148. """Split `n` figures into `rows` x `cols` figures."""
  149. nsq = sqrt(n)
  150. if int(nsq) == nsq: # square number
  151. return int(nsq), int(nsq)
  152. if floor(nsq) * ceil(nsq) >= n:
  153. return floor(nsq), ceil(nsq)
  154. return ceil(nsq), ceil(nsq)
  155. def _get_text_color(patch_color: tuple[float, float, float, float]) -> str:
  156. """Get the text color for a given value and colormap.
  157. Following Wikipedia's recommendations: https://en.wikipedia.org/wiki/Relative_luminance.
  158. Args:
  159. patch_color: RGBA color tuple
  160. """
  161. # Convert to linear color space
  162. r, g, b, a = patch_color
  163. r, g, b = (c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4 for c in (r, g, b))
  164. # Get the relative luminance
  165. y = 0.2126 * r + 0.7152 * g + 0.0722 * b
  166. return ".1" if y > 0.4 else "white"
  167. def trim_axs(axs: Union[_AX_TYPE, np.ndarray], nb: int) -> Union[np.ndarray, _AX_TYPE]: # type: ignore[valid-type]
  168. """Reduce `axs` to `nb` Axes.
  169. All further Axes are removed from the figure.
  170. """
  171. if isinstance(axs, _AX_TYPE):
  172. return axs
  173. axs = axs.flat # type: ignore[union-attr]
  174. for ax in axs[nb:]:
  175. ax.remove()
  176. return axs[:nb]
  177. @style_change(_style)
  178. @no_type_check
  179. def plot_confusion_matrix(
  180. confmat: Tensor,
  181. ax: Optional[_AX_TYPE] = None,
  182. add_text: bool = True,
  183. labels: Optional[list[Union[int, str]]] = None,
  184. cmap: Optional[_CMAP_TYPE] = None,
  185. ) -> _PLOT_OUT_TYPE:
  186. """Plot an confusion matrix.
  187. Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/confusion_matrix.py.
  188. Works for both binary, multiclass and multilabel confusion matrices.
  189. Args:
  190. confmat: the confusion matrix. Either should be an [N,N] matrix in the binary and multiclass cases or an
  191. [N, 2, 2] matrix for multilabel classification
  192. ax: Axis from a figure. If not provided, a new figure and axis will be created
  193. add_text: if text should be added to each cell with the given value
  194. labels: labels to add the x- and y-axis
  195. cmap: matplotlib colormap to use for the confusion matrix
  196. https://matplotlib.org/stable/users/explain/colors/colormaps.html
  197. Returns:
  198. A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure
  199. Raises:
  200. ModuleNotFoundError:
  201. If `matplotlib` is not installed
  202. """
  203. _error_on_missing_matplotlib()
  204. if confmat.ndim == 3: # multilabel
  205. nb, n_classes = confmat.shape[0], 2
  206. rows, cols = _get_col_row_split(nb)
  207. else:
  208. nb, n_classes, rows, cols = 1, confmat.shape[0], 1, 1
  209. if labels is not None and confmat.ndim != 3 and len(labels) != n_classes:
  210. raise ValueError(
  211. "Expected number of elements in arg `labels` to match number of labels in confmat but "
  212. f"got {len(labels)} and {n_classes}"
  213. )
  214. if confmat.ndim == 3:
  215. fig_label = labels or np.arange(nb)
  216. labels = list(map(str, range(n_classes)))
  217. else:
  218. fig_label = None
  219. labels = labels or np.arange(n_classes).tolist()
  220. fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True) if ax is None else (ax.get_figure(), ax)
  221. axs = trim_axs(axs, nb)
  222. for i in range(nb):
  223. ax = axs[i] if (rows != 1 or cols != 1) else axs
  224. if fig_label is not None:
  225. ax.set_title(f"Label {fig_label[i]}", fontsize=15)
  226. im = ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach(), cmap=cmap)
  227. if i // cols == rows - 1: # bottom row only
  228. ax.set_xlabel("Predicted class", fontsize=15)
  229. if i % cols == 0: # leftmost column only
  230. ax.set_ylabel("True class", fontsize=15)
  231. ax.set_xticks(list(range(n_classes)))
  232. ax.set_yticks(list(range(n_classes)))
  233. ax.set_xticklabels(labels, rotation=45, fontsize=10)
  234. ax.set_yticklabels(labels, rotation=25, fontsize=10)
  235. if add_text:
  236. for ii, jj in product(range(n_classes), range(n_classes)):
  237. val = confmat[i, ii, jj] if confmat.ndim == 3 else confmat[ii, jj]
  238. patch_color = im.cmap(im.norm(val.item()))
  239. c = _get_text_color(patch_color)
  240. ax.text(jj, ii, str(round(val.item(), 2)), ha="center", va="center", fontsize=15, color=c)
  241. return fig, axs
  242. @style_change(_style)
  243. def plot_curve(
  244. curve: Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]],
  245. score: Optional[Tensor] = None,
  246. ax: Optional[_AX_TYPE] = None, # type: ignore[valid-type]
  247. label_names: Optional[tuple[str, str]] = None,
  248. legend_name: Optional[str] = None,
  249. name: Optional[str] = None,
  250. labels: Optional[list[Union[int, str]]] = None,
  251. ) -> _PLOT_OUT_TYPE:
  252. """Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/roc_curve.py.
  253. Plots a curve object
  254. Args:
  255. curve: a tuple of (x, y, t) where x and y are the coordinates of the curve and t are the thresholds used
  256. to compute the curve
  257. score: optional area under the curve added as label to the plot
  258. ax: Axis from a figure
  259. label_names: Tuple containing the names of the x and y axis
  260. legend_name: Name of the curve to be used in the legend
  261. name: Custom name to describe the metric
  262. labels: Optional labels for the different curves that will be added to the plot
  263. Returns:
  264. A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure
  265. Raises:
  266. ModuleNotFoundError:
  267. If `matplotlib` is not installed
  268. ValueError:
  269. If `curve` does not have 3 elements, being in the wrong format
  270. """
  271. if len(curve) < 2:
  272. raise ValueError(f"Expected 2 or 3 elements in curve but got {len(curve)}")
  273. x, y = curve[:2]
  274. _error_on_missing_matplotlib()
  275. fig, ax = plt.subplots() if ax is None else (None, ax)
  276. if isinstance(x, Tensor) and isinstance(y, Tensor) and x.ndim == 1 and y.ndim == 1:
  277. label = f"AUC={score.item():0.3f}" if score is not None else None
  278. ax.plot(x.detach().cpu(), y.detach().cpu(), linestyle="-", linewidth=2, label=label)
  279. if label is not None:
  280. ax.legend()
  281. elif (isinstance(x, list) and isinstance(y, list)) or (
  282. isinstance(x, Tensor) and isinstance(y, Tensor) and x.ndim == 2 and y.ndim == 2
  283. ):
  284. n_classes = len(x)
  285. if labels is not None and len(labels) != n_classes:
  286. raise ValueError(
  287. "Expected number of elements in arg `labels` to match number of labels in roc curves but "
  288. f"got {len(labels)} and {n_classes}"
  289. )
  290. for i, (x_, y_) in enumerate(zip(x, y)):
  291. label = f"{legend_name}_{i}" if legend_name is not None else str(i) if labels is None else str(labels[i])
  292. label += f" AUC={score[i].item():0.3f}" if score is not None else ""
  293. ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label)
  294. ax.legend()
  295. else:
  296. raise ValueError(
  297. f"Unknown format for argument `x` and `y`. Expected either list or tensors but got {type(x)} and {type(y)}."
  298. )
  299. if label_names is not None:
  300. ax.set_xlabel(label_names[0])
  301. ax.set_ylabel(label_names[1])
  302. ax.grid(True)
  303. ax.set_title(name)
  304. return fig, ax