# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Any, Optional, Union, cast

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import ModuleList

from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.wrappers.abstract import WrapperMetric

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["MultioutputWrapper.plot"]


def _get_nan_indices(*tensors: Tensor) -> Tensor:
    """Get indices of rows along dim 0 which have NaN values."""
    if len(tensors) == 0:
        raise ValueError("Must pass at least one tensor as argument")
    sentinel = tensors[0]
    nan_idxs = torch.zeros(len(sentinel), dtype=torch.bool, device=sentinel.device)
    for tensor in tensors:
        permuted_tensor = tensor.flatten(start_dim=1)
        nan_idxs |= torch.any(torch.isnan(permuted_tensor), dim=1)
    return nan_idxs


class MultioutputWrapper(WrapperMetric):
    """Wrap a base metric to enable it to support multiple outputs.

    Several torchmetrics metrics, such as :class:`~torchmetrics.regression.spearman.SpearmanCorrCoef` lack support for
    multioutput mode. This class wraps such metrics to support computing one metric per output.
    Unlike specific torchmetrics metrics, it doesn't support any aggregation across outputs.
    This means if you set ``num_outputs`` to 2, ``.compute()`` will return a Tensor of dimension
    ``(2, ...)`` where ``...`` represents the dimensions the metric returns when not wrapped.

    In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude
    fashion, dealing with missing labels (or other data). When ``remove_nans`` is enabled (the default), the class
    will remove, upon each update and separately for each output, every "row" in which any input contains a NaN.
    For example, suppose a user uses `MultioutputWrapper` to wrap :class:`torchmetrics.regression.r2.R2Score` with
    2 outputs, one of which occasionally has missing labels. Because the removal happens on a per-output basis, the
    rows with missing labels are dropped only for that output, while the other output still receives every row.

    Args:
        base_metric: Metric being wrapped.
        num_outputs: Expected dimensionality of the output dimension.
            This parameter is used to determine the number of distinct metrics we need to track.
        output_dim:
            Dimension on which output is expected. Note that while this provides some flexibility, the output dimension
            must be the same for all inputs to update. This applies even for metrics such as `Accuracy` where the labels
            can have a different number of dimensions than the predictions. This can be worked around if the output
            dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs.
        remove_nans:
            Whether to remove rows containing NaNs (in any input) from the values passed through to each underlying
            metric. Proper operation requires all tensors passed to update to have dimension ``(N, ...)`` where N
            represents the length of the batch or dataset being passed in.
        squeeze_outputs:
            If ``True``, will squeeze the 1-item dimensions left after ``index_select`` is applied.
            This is sometimes unnecessary but harmless for metrics such as `R2Score`, and useful
            for certain classification metrics that can't handle additional 1-item dimensions.

    Example:
        >>> # Mimic R2Score in `multioutput`, `raw_values` mode:
        >>> import torch
        >>> from torchmetrics.wrappers import MultioutputWrapper
        >>> from torchmetrics.regression import R2Score
        >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
        >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
        >>> r2score = MultioutputWrapper(R2Score(), 2)
        >>> r2score(preds, target)
        tensor([0.9654, 0.9082])

    """

    is_differentiable = False

    def __init__(
        self,
        base_metric: Metric,
        num_outputs: int,
        output_dim: int = -1,
        remove_nans: bool = True,
        squeeze_outputs: bool = True,
    ) -> None:
        super().__init__()
        # One independent copy of the base metric per output.
        self.metrics = ModuleList([deepcopy(base_metric) for _ in range(num_outputs)])
        self.output_dim = output_dim
        self.remove_nans = remove_nans
        self.squeeze_outputs = squeeze_outputs

    def _get_args_kwargs_by_output(self, *args: Tensor, **kwargs: Tensor) -> list[tuple[Tensor, Tensor]]:
        """Get args and kwargs reshaped to be output-specific and (maybe) having NaNs stripped out."""
        args_kwargs_by_output = []
        for i in range(len(self.metrics)):
            # Select slice ``i`` along the output dimension; ``index_select`` keeps that dimension with size 1.
            selected_args = apply_to_collection(
                args, Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
            )
            selected_kwargs = apply_to_collection(
                kwargs, Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
            )
            if self.remove_nans:
                # Drop rows (dim 0) in which any input holds a NaN for this output.
                args_kwargs = selected_args + tuple(selected_kwargs.values())
                nan_idxs = _get_nan_indices(*args_kwargs)
                selected_args = [arg[~nan_idxs] for arg in selected_args]
                selected_kwargs = {k: v[~nan_idxs] for k, v in selected_kwargs.items()}

            if self.squeeze_outputs:
                selected_args = [arg.squeeze(self.output_dim) for arg in selected_args]
                selected_kwargs = {k: v.squeeze(self.output_dim) for k, v in selected_kwargs.items()}
            args_kwargs_by_output.append((selected_args, selected_kwargs))
        return args_kwargs_by_output

    def update(self, *args: Any, **kwargs: Any) -> None:
        """Update each underlying metric with the corresponding output."""
        reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs)
        for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs):
            cast(Metric, metric).update(*selected_args, **cast(Mapping, selected_kwargs))

    def compute(self) -> Tensor:
        """Compute metrics."""
        return torch.stack([cast(Metric, m).compute() for m in self.metrics], 0)

    @torch.jit.unused
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Call underlying forward methods and aggregate the results if they're non-null.

        We override this method to ensure that state variables get copied over on the underlying metrics.

        """
        reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs)
        results = [
            metric(*selected_args, **cast(Mapping, selected_kwargs))
            for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs)
        ]
        if results[0] is None:
            return None
        return torch.stack(results, 0)

    def reset(self) -> None:
        """Reset all underlying metrics."""
        for metric in self.metrics:
            cast(Metric, metric).reset()
        super().reset()

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, will add the plot to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.wrappers import MultioutputWrapper
            >>> from torchmetrics.regression import R2Score
            >>> metric = MultioutputWrapper(R2Score(), 2)
            >>> metric.update(torch.randn(20, 2), torch.randn(20, 2))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.wrappers import MultioutputWrapper
            >>> from torchmetrics.regression import R2Score
            >>> metric = MultioutputWrapper(R2Score(), 2)
            >>> values = []
            >>> for _ in range(3):
            ...     values.append(metric(torch.randn(20, 2), torch.randn(20, 2)))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
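The per-output slicing and NaN-row removal done by `_get_args_kwargs_by_output` can be hard to picture from the tensor code alone. Below is a minimal, torch-free sketch of the same idea for the simplest case: two 2-D inputs with the output dimension last. The helper `split_by_output` and the sample data are hypothetical and exist only for illustration; they are not part of torchmetrics, and plain Python lists stand in for tensors.

```python
import math


def split_by_output(preds, target, remove_nans=True):
    """Split row-major 2-D lists into one (preds, target) pair per output column.

    Hypothetical helper for illustration only; not part of torchmetrics.
    """
    num_outputs = len(preds[0])
    per_output = []
    for i in range(num_outputs):
        # Pick column i from both inputs (the analogue of index_select followed by squeeze).
        p_col = [row[i] for row in preds]
        t_col = [row[i] for row in target]
        if remove_nans:
            # Drop a row when either input holds a NaN there, for this output only.
            keep = [not (math.isnan(p) or math.isnan(t)) for p, t in zip(p_col, t_col)]
            p_col = [p for p, k in zip(p_col, keep) if k]
            t_col = [t for t, k in zip(t_col, keep) if k]
        per_output.append((p_col, t_col))
    return per_output


nan = float("nan")
preds = [[0.0, 2.0], [-1.0, 2.0], [8.0, -5.0]]
target = [[0.5, 1.0], [-1.0, nan], [7.0, -6.0]]  # second output misses one label
result = split_by_output(preds, target)
```

In this sketch the first output keeps all three rows, while the second output loses only the row whose label is missing. This mirrors the per-output behaviour of `remove_nans`: a NaN in one output does not discard data for the others.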