| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # this is just a bypass for this module name collision with built-in one
- from collections.abc import Iterable, Sequence
- from copy import deepcopy
- from typing import Any, Optional, Union
- from torch import Tensor, nn
- from torchmetrics.collections import MetricCollection
- from torchmetrics.metric import Metric
- from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
- from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
- from torchmetrics.wrappers.abstract import WrapperMetric
# The plotting examples in ``MultitaskWrapper.plot`` need matplotlib; mark that doctest to be skipped when the
# optional dependency is missing (the ``__doctest_skip__`` list is presumably read by the doctest plugin in use).
if not bool(_MATPLOTLIB_AVAILABLE):
    __doctest_skip__ = ["MultitaskWrapper.plot"]
class MultitaskWrapper(WrapperMetric):
    """Wrapper class for computing different metrics on different tasks in the context of multitask learning.

    In multitask learning the different tasks require different metrics to be evaluated. This wrapper allows
    for easy evaluation in such cases by supporting multiple predictions and targets through a dictionary.
    Note that only metrics where the signature of `update` follows the standard `preds, target` are supported.

    Args:
        task_metrics:
            Dictionary associating each task to a Metric or a MetricCollection. The keys of the dictionary represent
            the names of the tasks, and the values represent the metrics to use for each task.
        prefix:
            A string to prepend to the keys of the output dict. If not provided, will default to an empty string.
        postfix:
            A string to append after the keys of the output dict. If not provided, will default to an empty string.

    .. tip::
        The use of prefix and postfix allows for easily creating task wrappers for training, validation and test.
        The arguments only change the output keys of the computed metrics and not the input keys. This means
        that a ``MultitaskWrapper`` initialized as ``MultitaskWrapper({"task": Metric()}, prefix="train_")`` will
        still expect the input to be a dictionary with the key "task", but the output will be a dictionary with the
        key "train_task".

    Raises:
        TypeError:
            If argument `task_metrics` is not a dictionary
        TypeError:
            If not all values in the `task_metrics` dictionary are instances of `Metric` or `MetricCollection`
        ValueError:
            If `prefix` is not a string
        ValueError:
            If `postfix` is not a string

    Example (with a single metric per task):
        >>> import torch
        >>> from torchmetrics.wrappers import MultitaskWrapper
        >>> from torchmetrics.regression import MeanSquaredError
        >>> from torchmetrics.classification import BinaryAccuracy
        >>>
        >>> classification_target = torch.tensor([0, 1, 0])
        >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
        >>> targets = {"Classification": classification_target, "Regression": regression_target}
        >>>
        >>> classification_preds = torch.tensor([0, 0, 1])
        >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
        >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
        >>>
        >>> metrics = MultitaskWrapper({
        ...     "Classification": BinaryAccuracy(),
        ...     "Regression": MeanSquaredError()
        ... })
        >>> metrics.update(preds, targets)
        >>> metrics.compute()
        {'Classification': tensor(0.3333), 'Regression': tensor(0.8333)}

    Example (with several metrics per task):
        >>> import torch
        >>> from torchmetrics import MetricCollection
        >>> from torchmetrics.wrappers import MultitaskWrapper
        >>> from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
        >>> from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
        >>>
        >>> classification_target = torch.tensor([0, 1, 0])
        >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
        >>> targets = {"Classification": classification_target, "Regression": regression_target}
        >>>
        >>> classification_preds = torch.tensor([0, 0, 1])
        >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
        >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
        >>>
        >>> metrics = MultitaskWrapper({
        ...     "Classification": MetricCollection(BinaryAccuracy(), BinaryF1Score()),
        ...     "Regression": MetricCollection(MeanSquaredError(), MeanAbsoluteError())
        ... })
        >>> metrics.update(preds, targets)
        >>> metrics.compute()
        {'Classification': {'BinaryAccuracy': tensor(0.3333), 'BinaryF1Score': tensor(0.)},
        'Regression': {'MeanSquaredError': tensor(0.8333), 'MeanAbsoluteError': tensor(0.6667)}}

    Example (with a prefix and postfix):
        >>> import torch
        >>> from torchmetrics.wrappers import MultitaskWrapper
        >>> from torchmetrics.regression import MeanSquaredError
        >>> from torchmetrics.classification import BinaryAccuracy
        >>>
        >>> classification_target = torch.tensor([0, 1, 0])
        >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
        >>> targets = {"Classification": classification_target, "Regression": regression_target}
        >>> classification_preds = torch.tensor([0, 0, 1])
        >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
        >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
        >>>
        >>> metrics = MultitaskWrapper({
        ...     "Classification": BinaryAccuracy(),
        ...     "Regression": MeanSquaredError()
        ... }, prefix="train_")
        >>> metrics.update(preds, targets)
        >>> metrics.compute()
        {'train_Classification': tensor(0.3333), 'train_Regression': tensor(0.8333)}

    """

    is_differentiable: bool = False

    def __init__(
        self,
        task_metrics: dict[str, Union[Metric, MetricCollection]],
        prefix: Optional[str] = None,
        postfix: Optional[str] = None,
    ) -> None:
        super().__init__()

        if not isinstance(task_metrics, dict):
            raise TypeError(f"Expected argument `task_metrics` to be a dict. Found task_metrics = {task_metrics}")

        for metric in task_metrics.values():
            if not (isinstance(metric, (Metric, MetricCollection))):
                raise TypeError(
                    "Expected each task's metric to be a Metric or a MetricCollection. "
                    f"Found a metric of type {type(metric)}"
                )

        # Wrapping in ``nn.ModuleDict`` registers every task metric as a submodule of this wrapper.
        self.task_metrics = nn.ModuleDict(task_metrics)

        if prefix is not None and not isinstance(prefix, str):
            raise ValueError(f"Expected argument `prefix` to either be `None` or a string but got {prefix}")
        self._prefix = prefix or ""

        if postfix is not None and not isinstance(postfix, str):
            raise ValueError(f"Expected argument `postfix` to either be `None` or a string but got {postfix}")
        self._postfix = postfix or ""

    def _output_key(self, name: str) -> str:
        """Return ``name`` decorated with the configured prefix and postfix, i.e. the key used in outputs."""
        return f"{self._prefix}{name}{self._postfix}"

    def items(self, flatten: bool = True) -> Iterable[tuple[str, nn.Module]]:
        """Iterate over task and task metrics.

        Args:
            flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection, yielding keys of
                the form ``{prefix}{task_name}_{sub_metric_name}{postfix}``.
                If False, will iterate over the task names and the corresponding metrics.

        """
        for task_name, metric in self.task_metrics.items():
            if flatten and isinstance(metric, MetricCollection):
                for sub_metric_name, sub_metric in metric.items():
                    yield self._output_key(f"{task_name}_{sub_metric_name}"), sub_metric
            else:
                yield self._output_key(task_name), metric

    def keys(self, flatten: bool = True) -> Iterable[str]:
        """Iterate over task names.

        Args:
            flatten: If True, will iterate over all sub-metric names in the case of a MetricCollection, in the form
                ``{prefix}{task_name}_{sub_metric_name}{postfix}``.
                If False, will iterate over the (prefixed/postfixed) task names only.

        """
        for task_name, metric in self.task_metrics.items():
            if flatten and isinstance(metric, MetricCollection):
                for sub_metric_name in metric:
                    yield self._output_key(f"{task_name}_{sub_metric_name}")
            else:
                yield self._output_key(task_name)

    def values(self, flatten: bool = True) -> Iterable[nn.Module]:
        """Iterate over task metrics.

        Args:
            flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection.
                If False, will iterate over the task metrics exactly as they were given (one per task).

        """
        for metric in self.task_metrics.values():
            if flatten and isinstance(metric, MetricCollection):
                yield from metric.values()
            else:
                yield metric

    def update(self, task_preds: dict[str, Any], task_targets: dict[str, Any]) -> None:
        """Update each task's metric with its corresponding pred and target.

        Args:
            task_preds: Dictionary associating each task to a Tensor of pred.
            task_targets: Dictionary associating each task to a Tensor of target.

        Raises:
            ValueError:
                If the keys of ``task_preds`` and ``task_targets`` do not match the (undecorated) task names.

        """
        if not self.task_metrics.keys() == task_preds.keys() == task_targets.keys():
            raise ValueError(
                "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`"
                f". Found task_preds.keys() = {task_preds.keys()}, task_targets.keys() = {task_targets.keys()} "
                f"and self.task_metrics.keys() = {self.task_metrics.keys()}"
            )

        for task_name, metric in self.task_metrics.items():
            pred = task_preds[task_name]
            target = task_targets[task_name]
            # Re-checked here so static type checkers can narrow the ``nn.Module`` coming out of the ModuleDict.
            if not (isinstance(metric, (Metric, MetricCollection))):
                raise TypeError(
                    "Expected each task's metric to be a Metric or a MetricCollection. "
                    f"Found a metric of type {type(metric)}"
                )
            metric.update(pred, target)

    def _convert_output(self, output: dict[str, Any]) -> dict[str, Any]:
        """Convert the output of the underlying metrics to a dictionary with the decorated task names as keys."""
        return {self._output_key(task_name): task_output for task_name, task_output in output.items()}

    def compute(self) -> dict[str, Any]:
        """Compute metrics for all tasks.

        Returns:
            Dictionary mapping each task name (with prefix/postfix applied) to the result of its metric's
            ``compute``.

        """
        output: dict[str, Any] = {}
        for task_name, metric in self.task_metrics.items():
            if not isinstance(metric, (Metric, MetricCollection)):
                raise TypeError(
                    "Expected each task's metric to be a Metric or a MetricCollection. "
                    f"Found a metric of type {type(metric)}"
                )
            output[task_name] = metric.compute()
        return self._convert_output(output)

    def forward(self, task_preds: dict[str, Tensor], task_targets: dict[str, Tensor]) -> dict[str, Any]:
        """Call underlying forward methods for all tasks and return the result as a dictionary.

        The returned keys have prefix/postfix applied, exactly like :meth:`compute`.
        """
        # This method is overridden because we do not need the complex version defined in Metric, that relies on the
        # value of full_state_update, and that also accumulates the results. Here, all computations are handled by the
        # underlying metrics, which all have their own value of full_state_update, and which all accumulate the results
        # by themselves.
        return self._convert_output({
            task_name: metric(task_preds[task_name], task_targets[task_name])
            for task_name, metric in self.task_metrics.items()
        })

    def reset(self) -> None:
        """Reset all underlying metrics, then the wrapper itself."""
        for metric in self.task_metrics.values():
            if not isinstance(metric, (Metric, MetricCollection)):
                raise TypeError(
                    "Expected each task's metric to be a Metric or a MetricCollection. "
                    f"Found a metric of type {type(metric)}"
                )
            metric.reset()
        super().reset()

    @staticmethod
    def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
        """Return ``arg`` unchanged if it is ``None`` or a string; otherwise raise ``ValueError`` naming ``name``."""
        if arg is None or isinstance(arg, str):
            return arg
        raise ValueError(f"Expected input `{name}` to be a string, but got {type(arg)}")

    def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MultitaskWrapper":
        """Make a copy of the metric.

        Args:
            prefix: a string to prepend to the keys of the output dict of the copy.
            postfix: a string to append after the keys of the output dict of the copy.

        Note:
            An omitted (``None``) prefix or postfix resets the copy's value to an empty string; the values of
            the original wrapper are not carried over.

        Raises:
            ValueError:
                If ``prefix`` or ``postfix`` is neither ``None`` nor a string.

        """
        # Validate before deep-copying so an invalid argument does not pay for a full copy of all task metrics.
        prefix = self._check_arg(prefix, "prefix")
        postfix = self._check_arg(postfix, "postfix")
        multitask_copy = deepcopy(self)
        multitask_copy._prefix = prefix or ""
        multitask_copy._postfix = postfix or ""
        return multitask_copy

    def _get_task_value(self, output: dict[str, Any], task_name: str) -> Any:
        """Fetch the entry belonging to ``task_name`` from an output dictionary.

        :meth:`compute` and :meth:`forward` key their outputs with prefix/postfix applied, so that key is tried
        first. The bare task name is used as a fallback so dictionaries keyed by the input task names still work.
        """
        decorated_key = self._output_key(task_name)
        if decorated_key in output:
            return output[decorated_key]
        return output[task_name]

    def plot(
        self, val: Optional[Union[dict, Sequence[dict]]] = None, axes: Optional[Sequence[_AX_TYPE]] = None
    ) -> Sequence[_PLOT_OUT_TYPE]:
        """Plot a single or multiple values from the metric.

        All tasks' results are plotted on individual axes.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
                Keys may be either the task names with prefix/postfix applied (as returned by `compute`/`forward`)
                or the bare task names.
            axes: Sequence of matplotlib axis objects. If provided, will add the plots to the provided axis objects.
                If not provided, will create them.

        Returns:
            Sequence of tuples with Figure and Axes object for each task.

        Raises:
            TypeError:
                If ``axes`` is not a sequence of matplotlib axis objects, or ``val`` is not a dict / sequence of dicts.
            ValueError:
                If ``axes`` does not have one entry per task.

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.wrappers import MultitaskWrapper
            >>> from torchmetrics.regression import MeanSquaredError
            >>> from torchmetrics.classification import BinaryAccuracy
            >>>
            >>> classification_target = torch.tensor([0, 1, 0])
            >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
            >>> targets = {"Classification": classification_target, "Regression": regression_target}
            >>>
            >>> classification_preds = torch.tensor([0, 0, 1])
            >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
            >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
            >>>
            >>> metrics = MultitaskWrapper({
            ...     "Classification": BinaryAccuracy(),
            ...     "Regression": MeanSquaredError()
            ... })
            >>> metrics.update(preds, targets)
            >>> value = metrics.compute()
            >>> fig_, ax_ = metrics.plot(value)

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.wrappers import MultitaskWrapper
            >>> from torchmetrics.regression import MeanSquaredError
            >>> from torchmetrics.classification import BinaryAccuracy
            >>>
            >>> classification_target = torch.tensor([0, 1, 0])
            >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
            >>> targets = {"Classification": classification_target, "Regression": regression_target}
            >>>
            >>> classification_preds = torch.tensor([0, 0, 1])
            >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
            >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
            >>>
            >>> metrics = MultitaskWrapper({
            ...     "Classification": BinaryAccuracy(),
            ...     "Regression": MeanSquaredError()
            ... })
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metrics(preds, targets))
            >>> fig_, ax_ = metrics.plot(values)

        """
        if axes is not None:
            if not isinstance(axes, Sequence):
                raise TypeError(f"Expected argument `axes` to be a Sequence. Found type(axes) = {type(axes)}")

            if not all(isinstance(ax, _AX_TYPE) for ax in axes):
                raise TypeError("Expected each ax in argument `axes` to be a matplotlib axis object")

            if len(axes) != len(self.task_metrics):
                raise ValueError(
                    "Expected argument `axes` to be a Sequence of the same length as the number of tasks. "
                    f"Found len(axes) = {len(axes)} and {len(self.task_metrics)} tasks"
                )

        val = val if val is not None else self.compute()
        fig_axs = []
        for i, (task_name, task_metric) in enumerate(self.task_metrics.items()):
            ax = axes[i] if axes is not None else None
            if not isinstance(task_metric, (Metric, MetricCollection)):
                raise TypeError(
                    "Expected each task's metric to be a Metric or a MetricCollection. "
                    f"Found a metric of type {type(task_metric)}"
                )
            # Outputs of compute/forward carry prefix/postfix in their keys, so resolve the key per task.
            if isinstance(val, dict):
                f, a = task_metric.plot(self._get_task_value(val, task_name), ax=ax)
            elif isinstance(val, Sequence):
                f, a = task_metric.plot([self._get_task_value(v, task_name) for v in val], ax=ax)
            else:
                raise TypeError(
                    "Expected argument `val` to be None or of type Dict or Sequence[Dict]. "
                    f"Found type(val)= {type(val)}"
                )
            fig_axs.append((f, a))
        return fig_axs
|