| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Generator, Sequence
- from itertools import product
- from math import ceil, floor, sqrt
- from typing import Any, List, Optional, Union, no_type_check
- import numpy as np
- import torch
- from torch import Tensor
- from torchmetrics.utilities.imports import _LATEX_AVAILABLE, _MATPLOTLIB_AVAILABLE, _SCIENCEPLOT_AVAILABLE
# Matplotlib is an optional dependency. When it is importable, bind the plotting type aliases and the
# `style_change` context manager to the real matplotlib objects; otherwise bind them to inert stand-ins
# so the definitions below can still be evaluated.
if _MATPLOTLIB_AVAILABLE:
    import matplotlib
    import matplotlib.axes
    import matplotlib.pyplot as plt

    # Return annotation of the plot functions: (figure, single Axes or array of Axes)
    _PLOT_OUT_TYPE = tuple[plt.Figure, Union[matplotlib.axes.Axes, np.ndarray]]
    _AX_TYPE = matplotlib.axes.Axes
    _CMAP_TYPE = Union[matplotlib.colors.Colormap, str]

    style_change = plt.style.context
else:
    _PLOT_OUT_TYPE = tuple[object, object]  # type: ignore[misc]
    _AX_TYPE = object
    _CMAP_TYPE = object  # type: ignore[misc]

    from contextlib import contextmanager

    @contextmanager
    def style_change(*args: Any, **kwargs: Any) -> Generator:
        """Stand in for ``plt.style.context`` when matplotlib is not installed; all arguments are ignored."""
        yield


if _SCIENCEPLOT_AVAILABLE:
    import scienceplots  # noqa: F401

    # NOTE(review): this value is overwritten unconditionally by the assignment right below, so the
    # "no-latex" variant is never used -- confirm whether it was meant as the fallback without LaTeX.
    _style = ["science", "no-latex"]

# Style handed to the `style_change` decorators in this module: "science" only when both scienceplots
# and LaTeX are available, "default" otherwise.
_style = ["science"] if _SCIENCEPLOT_AVAILABLE and _LATEX_AVAILABLE else ["default"]
def _error_on_missing_matplotlib() -> None:
    """Guard for the plotting entry points: raise ``ModuleNotFoundError`` unless matplotlib is available."""
    if _MATPLOTLIB_AVAILABLE:
        return
    raise ModuleNotFoundError(
        "Plot function expects `matplotlib` to be installed. Please install with `pip install matplotlib`"
    )
@style_change(_style)
def plot_single_or_multi_val(
    val: Union[Tensor, Sequence[Tensor], dict[str, Tensor], Sequence[dict[str, Tensor]]],
    ax: Optional[_AX_TYPE] = None,  # type: ignore[valid-type]
    higher_is_better: Optional[bool] = None,
    lower_bound: Optional[float] = None,
    upper_bound: Optional[float] = None,
    legend_name: Optional[str] = None,
    name: Optional[str] = None,
) -> _PLOT_OUT_TYPE:
    """Plot a single metric value or multiple, including bounds of value if existing.

    Args:
        val: A single tensor with one or multiple values (multiclass/label/output format), a dict of named tensors,
            or a sequence of either. If a sequence is provided the values are interpreted as a time series of
            evolving values.
        ax: Axis from a figure. If not provided, a new figure and axis are created.
        higher_is_better: If set, an "Optimal value" label is added next to the upper bound (when ``True``) or the
            lower bound (when ``False``), provided that bound is given.
        lower_bound: lower value that the metric can take
        upper_bound: upper value that the metric can take
        legend_name: for class based metrics specify the legend prefix e.g. Class or Label to use when multiple values
            are provided
        name: Name of the metric to use for the y-axis label

    Returns:
        A tuple ``(fig, ax)``; ``fig`` is ``None`` when an existing ``ax`` was passed in.

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed
        ValueError:
            If `val` is an empty sequence, or is neither a tensor, a dict nor a sequence

    """
    _error_on_missing_matplotlib()

    # Validate before creating a figure or touching a caller-supplied axis, so bad input leaves nothing behind.
    if not isinstance(val, (Tensor, dict, Sequence)):
        raise ValueError("Got unknown format for argument `val`.")
    if isinstance(val, Sequence) and len(val) == 0:
        raise ValueError("Expected argument `val` to contain at least one element but got an empty sequence.")

    fig, ax = plt.subplots() if ax is None else (None, ax)
    # The x-axis is hidden by default; it is switched back on below whenever values are plotted over steps.
    ax.get_xaxis().set_visible(False)

    if isinstance(val, Tensor):
        if val.numel() == 1:
            ax.plot([val.detach().cpu()], marker="o", markersize=10)
        else:
            for i, v in enumerate(val):
                label = f"{legend_name} {i}" if legend_name else f"{i}"
                ax.plot(i, v.detach().cpu(), marker="o", markersize=10, linestyle="None", label=label)
    elif isinstance(val, dict):
        for i, (k, v) in enumerate(val.items()):
            if v.numel() != 1:
                # Multi-element entry: drawn as a line over its own index
                ax.plot(v.detach().cpu(), marker="o", markersize=10, linestyle="-", label=k)
                ax.get_xaxis().set_visible(True)
                ax.set_xlabel("Step")
                ax.set_xticks(torch.arange(len(v)))
            else:
                ax.plot(i, v.detach().cpu(), marker="o", markersize=10, label=k)
    else:  # non-empty sequence: one entry per step
        n_steps = len(val)
        if isinstance(val[0], dict):
            # Regroup the per-step dicts into one stacked tensor per key (keys taken from the first step)
            series = {k: torch.stack([step[k] for step in val]) for k in val[0]}  # type: ignore
            for k, v in series.items():
                ax.plot(v.detach().cpu(), marker="o", markersize=10, linestyle="-", label=k)
        else:
            stacked = torch.stack(val, 0)  # type: ignore
            multi_series = stacked.ndim != 1
            # Iterate over series rather than steps: transpose multi-value input, wrap a single series
            stacked = stacked.T if multi_series else stacked.unsqueeze(0)
            for i, v in enumerate(stacked):
                label = (f"{legend_name} {i}" if legend_name else f"{i}") if multi_series else ""
                ax.plot(v.detach().cpu(), marker="o", markersize=10, linestyle="-", label=label)
        ax.get_xaxis().set_visible(True)
        ax.set_xlabel("Step")
        ax.set_xticks(torch.arange(n_steps))

    handles, labels = ax.get_legend_handles_labels()
    if handles and labels:
        ax.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.15), ncol=3, fancybox=True, shadow=True)

    # Pad the y-range by 10% of the bound span when both bounds are known, else 10% of the current data span
    ylim = ax.get_ylim()
    if lower_bound is not None and upper_bound is not None:
        factor = 0.1 * (upper_bound - lower_bound)
    else:
        factor = 0.1 * (ylim[1] - ylim[0])
    ax.set_ylim(
        bottom=lower_bound - factor if lower_bound is not None else ylim[0] - factor,
        top=upper_bound + factor if upper_bound is not None else ylim[1] + factor,
    )

    ax.grid(True)
    ax.set_ylabel(name)

    xlim = ax.get_xlim()
    factor = 0.1 * (xlim[1] - xlim[0])

    # Dashed horizontal lines mark whichever bounds were given
    y_lines = [bound for bound in (lower_bound, upper_bound) if bound is not None]
    ax.hlines(y_lines, xlim[0], xlim[1], linestyles="dashed", colors="k")

    if higher_is_better is not None:
        optimal = upper_bound if higher_is_better else lower_bound
        if optimal is not None:
            # Widen the x-range to the left to make room for the annotation
            ax.set_xlim(xlim[0] - factor, xlim[1])
            ax.text(xlim[0], optimal, s="Optimal \n value", horizontalalignment="center", verticalalignment="center")
    return fig, ax
- def _get_col_row_split(n: int) -> tuple[int, int]:
- """Split `n` figures into `rows` x `cols` figures."""
- nsq = sqrt(n)
- if int(nsq) == nsq: # square number
- return int(nsq), int(nsq)
- if floor(nsq) * ceil(nsq) >= n:
- return floor(nsq), ceil(nsq)
- return ceil(nsq), ceil(nsq)
- def _get_text_color(patch_color: tuple[float, float, float, float]) -> str:
- """Get the text color for a given value and colormap.
- Following Wikipedia's recommendations: https://en.wikipedia.org/wiki/Relative_luminance.
- Args:
- patch_color: RGBA color tuple
- """
- # Convert to linear color space
- r, g, b, a = patch_color
- r, g, b = (c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4 for c in (r, g, b))
- # Get the relative luminance
- y = 0.2126 * r + 0.7152 * g + 0.0722 * b
- return ".1" if y > 0.4 else "white"
def trim_axs(axs: Union[_AX_TYPE, np.ndarray], nb: int) -> Union[np.ndarray, _AX_TYPE]:  # type: ignore[valid-type]
    """Keep only the first `nb` Axes of `axs`.

    A single Axes object is handed back untouched. For an array of Axes, the array is flattened, every Axes past
    the first `nb` is removed from the figure, and the first `nb` Axes are returned.
    """
    if not isinstance(axs, _AX_TYPE):
        flat_axes = axs.flat  # type: ignore[union-attr]
        for surplus_ax in flat_axes[nb:]:
            surplus_ax.remove()
        return flat_axes[:nb]
    return axs
@style_change(_style)
@no_type_check
def plot_confusion_matrix(
    confmat: Tensor,
    ax: Optional[_AX_TYPE] = None,
    add_text: bool = True,
    labels: Optional[list[Union[int, str]]] = None,
    cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
    """Plot a confusion matrix as a heatmap.

    Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/confusion_matrix.py.
    Works for binary, multiclass and multilabel confusion matrices. A multilabel (3-dimensional) input is drawn as
    one 2x2 heatmap per label, arranged on a grid of subplots.

    Args:
        confmat: the confusion matrix. Either should be an [N,N] matrix in the binary and multiclass cases or an
            [N, 2, 2] matrix for multilabel classification
        ax: Axis from a figure. If not provided, a new figure and axis will be created
        add_text: if text should be added to each cell with the given value
        labels: in the binary/multiclass case, the tick labels for the x- and y-axis. In the multilabel case, the
            names used in the subplot titles; the ticks are then always "0" and "1".
        cmap: matplotlib colormap to use for the confusion matrix
            https://matplotlib.org/stable/users/explain/colors/colormaps.html

    Returns:
        A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed
        ValueError:
            If `labels` is given for a binary/multiclass matrix and its length differs from the number of classes

    """
    _error_on_missing_matplotlib()

    is_multilabel = confmat.ndim == 3
    if is_multilabel:
        nb, n_classes = confmat.shape[0], 2
        rows, cols = _get_col_row_split(nb)
    else:
        nb, n_classes = 1, confmat.shape[0]
        rows = cols = 1

    if labels is not None and not is_multilabel and len(labels) != n_classes:
        raise ValueError(
            "Expected number of elements in arg `labels` to match number of labels in confmat but "
            f"got {len(labels)} and {n_classes}"
        )

    if is_multilabel:
        # User labels name the subplots; the cells of each 2x2 matrix are ticked "0"/"1"
        fig_label = labels or np.arange(nb)
        labels = list(map(str, range(n_classes)))
    else:
        fig_label = None
        labels = labels or np.arange(n_classes).tolist()

    if ax is None:
        fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True)
    else:
        fig, axs = ax.get_figure(), ax
    axs = trim_axs(axs, nb)

    single_axis = rows == 1 and cols == 1
    tick_positions = list(range(n_classes))
    for i in range(nb):
        ax = axs if single_axis else axs[i]
        if fig_label is not None:
            ax.set_title(f"Label {fig_label[i]}", fontsize=15)

        matrix = confmat[i] if is_multilabel else confmat
        im = ax.imshow(matrix.cpu().detach(), cmap=cmap)

        if i // cols == rows - 1:  # bottom row only
            ax.set_xlabel("Predicted class", fontsize=15)
        if i % cols == 0:  # leftmost column only
            ax.set_ylabel("True class", fontsize=15)
        ax.set_xticks(tick_positions)
        ax.set_yticks(tick_positions)
        ax.set_xticklabels(labels, rotation=45, fontsize=10)
        ax.set_yticklabels(labels, rotation=25, fontsize=10)

        if not add_text:
            continue
        for ii, jj in product(range(n_classes), repeat=2):
            cell_value = matrix[ii, jj].item()
            # Choose the text color from the color the heatmap assigned to this cell
            text_color = _get_text_color(im.cmap(im.norm(cell_value)))
            ax.text(jj, ii, str(round(cell_value, 2)), ha="center", va="center", fontsize=15, color=text_color)
    return fig, axs
@style_change(_style)
def plot_curve(
    curve: Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]],
    score: Optional[Tensor] = None,
    ax: Optional[_AX_TYPE] = None,  # type: ignore[valid-type]
    label_names: Optional[tuple[str, str]] = None,
    legend_name: Optional[str] = None,
    name: Optional[str] = None,
    labels: Optional[list[Union[int, str]]] = None,
) -> _PLOT_OUT_TYPE:
    """Plot a curve object, either a single curve or one curve per class.

    Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/roc_curve.py.

    Args:
        curve: a tuple of (x, y, t) where x and y are the coordinates of the curve and t are the thresholds used
            to compute the curve. Only x and y are used. They are either 1d tensors (a single curve), or lists of
            tensors / 2d tensors (one curve per entry).
        score: optional area under the curve added as label to the plot; indexed per curve when several are drawn
        ax: Axis from a figure. If not provided, a new figure and axis are created.
        label_names: Tuple containing the names of the x and y axis
        legend_name: Prefix for the legend entries when several curves are drawn (``"{legend_name}_{i}"``). Takes
            precedence over `labels`.
        name: Custom name to describe the metric, used as the plot title
        labels: Optional labels for the different curves that will be added to the plot

    Returns:
        A tuple ``(fig, ax)``; ``fig`` is ``None`` when an existing ``ax`` was passed in.

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed
        ValueError:
            If `curve` has fewer than 2 elements, if `labels` does not have one entry per curve, or if x and y
            are not in one of the supported formats

    """
    if len(curve) < 2:
        raise ValueError(f"Expected 2 or 3 elements in curve but got {len(curve)}")
    x, y = curve[:2]

    _error_on_missing_matplotlib()
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = None

    both_tensors = isinstance(x, Tensor) and isinstance(y, Tensor)
    both_lists = isinstance(x, list) and isinstance(y, list)

    if both_tensors and x.ndim == 1 and y.ndim == 1:
        # Single curve
        label = None if score is None else f"AUC={score.item():0.3f}"
        ax.plot(x.detach().cpu(), y.detach().cpu(), linestyle="-", linewidth=2, label=label)
        if label is not None:
            ax.legend()
    elif both_lists or (both_tensors and x.ndim == 2 and y.ndim == 2):
        # One curve per entry of x / y
        n_classes = len(x)
        if labels is not None and len(labels) != n_classes:
            raise ValueError(
                "Expected number of elements in arg `labels` to match number of labels in roc curves but "
                f"got {len(labels)} and {n_classes}"
            )

        for i, (x_, y_) in enumerate(zip(x, y)):
            if legend_name is not None:
                label = f"{legend_name}_{i}"
            elif labels is None:
                label = str(i)
            else:
                label = str(labels[i])
            if score is not None:
                label += f" AUC={score[i].item():0.3f}"
            ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label)
            ax.legend()
    else:
        raise ValueError(
            f"Unknown format for argument `x` and `y`. Expected either list or tensors but got {type(x)} and {type(y)}."
        )

    if label_names is not None:
        x_axis_name, y_axis_name = label_names[0], label_names[1]
        ax.set_xlabel(x_axis_name)
        ax.set_ylabel(y_axis_name)
    ax.grid(True)
    ax.set_title(name)

    return fig, ax
|