multioutput.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Mapping, Sequence
  15. from copy import deepcopy
  16. from typing import Any, Optional, Union, cast
  17. import torch
  18. from lightning_utilities import apply_to_collection
  19. from torch import Tensor
  20. from torch.nn import ModuleList
  21. from torchmetrics.metric import Metric
  22. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  23. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  24. from torchmetrics.wrappers.abstract import WrapperMetric
  25. if not _MATPLOTLIB_AVAILABLE:
  26. __doctest_skip__ = ["MultioutputWrapper.plot"]
  27. def _get_nan_indices(*tensors: Tensor) -> Tensor:
  28. """Get indices of rows along dim 0 which have NaN values."""
  29. if len(tensors) == 0:
  30. raise ValueError("Must pass at least one tensor as argument")
  31. sentinel = tensors[0]
  32. nan_idxs = torch.zeros(len(sentinel), dtype=torch.bool, device=sentinel.device)
  33. for tensor in tensors:
  34. permuted_tensor = tensor.flatten(start_dim=1)
  35. nan_idxs |= torch.any(torch.isnan(permuted_tensor), dim=1)
  36. return nan_idxs
  37. class MultioutputWrapper(WrapperMetric):
  38. """Wrap a base metric to enable it to support multiple outputs.
  39. Several torchmetrics metrics, such as :class:`~torchmetrics.regression.spearman.SpearmanCorrCoef` lack support for
  40. multioutput mode. This class wraps such metrics to support computing one metric per output.
  41. Unlike specific torchmetric metrics, it doesn't support any aggregation across outputs.
  42. This means if you set ``num_outputs`` to 2, ``.compute()`` will return a Tensor of dimension
  43. ``(2, ...)`` where ``...`` represents the dimensions the metric returns when not wrapped.
  44. In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude
  45. fashion, dealing with missing labels (or other data). When ``remove_nans`` is passed, the class will remove the
  46. intersection of NaN containing "rows" upon each update for each output. For example, suppose a user uses
  47. `MultioutputWrapper` to wrap :class:`torchmetrics.regression.r2.R2Score` with 2 outputs, one of which occasionally
  48. has missing labels for classes like ``R2Score`` is that this class supports removing ``NaN`` values
  49. (parameter ``remove_nans``) on a per-output basis. When ``remove_nans`` is passed the wrapper will remove all rows
  50. Args:
  51. base_metric: Metric being wrapped.
  52. num_outputs: Expected dimensionality of the output dimension.
  53. This parameter is used to determine the number of distinct metrics we need to track.
  54. output_dim:
  55. Dimension on which output is expected. Note that while this provides some flexibility, the output dimension
  56. must be the same for all inputs to update. This applies even for metrics such as `Accuracy` where the labels
  57. can have a different number of dimensions than the predictions. This can be worked around if the output
  58. dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs.
  59. remove_nans:
  60. Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying
  61. metric. Proper operation requires all tensors passed to update to have dimension ``(N, ...)`` where N
  62. represents the length of the batch or dataset being passed in.
  63. squeeze_outputs:
  64. If ``True``, will squeeze the 1-item dimensions left after ``index_select`` is applied.
  65. This is sometimes unnecessary but harmless for metrics such as `R2Score` but useful
  66. for certain classification metrics that can't handle additional 1-item dimensions.
  67. Example:
  68. >>> # Mimic R2Score in `multioutput`, `raw_values` mode:
  69. >>> import torch
  70. >>> from torchmetrics.wrappers import MultioutputWrapper
  71. >>> from torchmetrics.regression import R2Score
  72. >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
  73. >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
  74. >>> r2score = MultioutputWrapper(R2Score(), 2)
  75. >>> r2score(preds, target)
  76. tensor([0.9654, 0.9082])
  77. """
  78. is_differentiable = False
  79. def __init__(
  80. self,
  81. base_metric: Metric,
  82. num_outputs: int,
  83. output_dim: int = -1,
  84. remove_nans: bool = True,
  85. squeeze_outputs: bool = True,
  86. ) -> None:
  87. super().__init__()
  88. self.metrics = ModuleList([deepcopy(base_metric) for _ in range(num_outputs)])
  89. self.output_dim = output_dim
  90. self.remove_nans = remove_nans
  91. self.squeeze_outputs = squeeze_outputs
  92. def _get_args_kwargs_by_output(self, *args: Tensor, **kwargs: Tensor) -> list[tuple[Tensor, Tensor]]:
  93. """Get args and kwargs reshaped to be output-specific and (maybe) having NaNs stripped out."""
  94. args_kwargs_by_output = []
  95. for i in range(len(self.metrics)):
  96. selected_args = apply_to_collection(
  97. args, Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
  98. )
  99. selected_kwargs = apply_to_collection(
  100. kwargs, Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
  101. )
  102. if self.remove_nans:
  103. args_kwargs = selected_args + tuple(selected_kwargs.values())
  104. nan_idxs = _get_nan_indices(*args_kwargs)
  105. selected_args = [arg[~nan_idxs] for arg in selected_args]
  106. selected_kwargs = {k: v[~nan_idxs] for k, v in selected_kwargs.items()}
  107. if self.squeeze_outputs:
  108. selected_args = [arg.squeeze(self.output_dim) for arg in selected_args]
  109. selected_kwargs = {k: v.squeeze(self.output_dim) for k, v in selected_kwargs.items()}
  110. args_kwargs_by_output.append((selected_args, selected_kwargs))
  111. return args_kwargs_by_output
  112. def update(self, *args: Any, **kwargs: Any) -> None:
  113. """Update each underlying metric with the corresponding output."""
  114. reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs)
  115. for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs):
  116. cast(Metric, metric).update(*selected_args, **cast(Mapping, selected_kwargs))
  117. def compute(self) -> Tensor:
  118. """Compute metrics."""
  119. return torch.stack([cast(Metric, m).compute() for m in self.metrics], 0)
  120. @torch.jit.unused
  121. def forward(self, *args: Any, **kwargs: Any) -> Any:
  122. """Call underlying forward methods and aggregate the results if they're non-null.
  123. We override this method to ensure that state variables get copied over on the underlying metrics.
  124. """
  125. reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs)
  126. results = [
  127. metric(*selected_args, **cast(Mapping, selected_kwargs))
  128. for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs)
  129. ]
  130. if results[0] is None:
  131. return None
  132. return torch.stack(results, 0)
  133. def reset(self) -> None:
  134. """Reset all underlying metrics."""
  135. for metric in self.metrics:
  136. cast(Metric, metric).reset()
  137. super().reset()
  138. def plot(
  139. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  140. ) -> _PLOT_OUT_TYPE:
  141. """Plot a single or multiple values from the metric.
  142. Args:
  143. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  144. If no value is provided, will automatically call `metric.compute` and plot that result.
  145. ax: An matplotlib axis object. If provided will add plot to that axis
  146. Returns:
  147. Figure and Axes object
  148. Raises:
  149. ModuleNotFoundError:
  150. If `matplotlib` is not installed
  151. .. plot::
  152. :scale: 75
  153. >>> # Example plotting a single value
  154. >>> import torch
  155. >>> from torchmetrics.wrappers import MultioutputWrapper
  156. >>> from torchmetrics.regression import R2Score
  157. >>> metric = MultioutputWrapper(R2Score(), 2)
  158. >>> metric.update(torch.randn(20, 2), torch.randn(20, 2))
  159. >>> fig_, ax_ = metric.plot()
  160. .. plot::
  161. :scale: 75
  162. >>> # Example plotting multiple values
  163. >>> import torch
  164. >>> from torchmetrics.wrappers import MultioutputWrapper
  165. >>> from torchmetrics.regression import R2Score
  166. >>> metric = MultioutputWrapper(R2Score(), 2)
  167. >>> values = [ ]
  168. >>> for _ in range(3):
  169. ... values.append(metric(torch.randn(20, 2), torch.randn(20, 2)))
  170. >>> fig_, ax_ = metric.plot(values)
  171. """
  172. return self._plot(val, ax)