multitask.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # this is just a bypass for this module name collision with built-in one
  15. from collections.abc import Iterable, Sequence
  16. from copy import deepcopy
  17. from typing import Any, Optional, Union
  18. from torch import Tensor, nn
  19. from torchmetrics.collections import MetricCollection
  20. from torchmetrics.metric import Metric
  21. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  22. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  23. from torchmetrics.wrappers.abstract import WrapperMetric
# The ``MultitaskWrapper.plot`` docstring example needs matplotlib; when it is not
# available, list that doctest in ``__doctest_skip__``.
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["MultitaskWrapper.plot"]
  26. class MultitaskWrapper(WrapperMetric):
  27. """Wrapper class for computing different metrics on different tasks in the context of multitask learning.
  28. In multitask learning the different tasks require different metrics to be evaluated. This wrapper allows
  29. for easy evaluation in such cases by supporting multiple predictions and targets through a dictionary.
  30. Note that only metrics where the signature of `update` follows the standard `preds, target` is supported.
  31. Args:
  32. task_metrics:
  33. Dictionary associating each task to a Metric or a MetricCollection. The keys of the dictionary represent the
  34. names of the tasks, and the values represent the metrics to use for each task.
  35. prefix:
  36. A string to prepend to the metric keys. If not provided, will default to an empty string.
  37. postfix:
  38. A string to append after the keys of the output dict. If not provided, will default to an empty string.
  39. .. tip::
  40. The use of prefix and postfix allows for easily creating task wrappers for training, validation and test.
  41. The arguments are only changing the output keys of the computed metrics and not the input keys. This means
  42. that a ``MultitaskWrapper`` initialized as ``MultitaskWrapper({"task": Metric()}, prefix="train_")`` will
  43. still expect the input to be a dictionary with the key "task", but the output will be a dictionary with the key
  44. "train_task".
  45. Raises:
  46. TypeError:
  47. If argument `task_metrics` is not a dictionary
  48. TypeError:
  49. If not all values in the `task_metrics` dictionary are instances of `Metric` or `MetricCollection`
  50. ValueError:
  51. If `prefix` is not a string
  52. ValueError:
  53. If `postfix` is not a string
  54. Example (with a single metric per task):
  55. >>> import torch
  56. >>> from torchmetrics.wrappers import MultitaskWrapper
  57. >>> from torchmetrics.regression import MeanSquaredError
  58. >>> from torchmetrics.classification import BinaryAccuracy
  59. >>>
  60. >>> classification_target = torch.tensor([0, 1, 0])
  61. >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
  62. >>> targets = {"Classification": classification_target, "Regression": regression_target}
  63. >>>
  64. >>> classification_preds = torch.tensor([0, 0, 1])
  65. >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
  66. >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
  67. >>>
  68. >>> metrics = MultitaskWrapper({
  69. ... "Classification": BinaryAccuracy(),
  70. ... "Regression": MeanSquaredError()
  71. ... })
  72. >>> metrics.update(preds, targets)
  73. >>> metrics.compute()
  74. {'Classification': tensor(0.3333), 'Regression': tensor(0.8333)}
  75. Example (with several metrics per task):
  76. >>> import torch
  77. >>> from torchmetrics import MetricCollection
  78. >>> from torchmetrics.wrappers import MultitaskWrapper
  79. >>> from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
  80. >>> from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
  81. >>>
  82. >>> classification_target = torch.tensor([0, 1, 0])
  83. >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
  84. >>> targets = {"Classification": classification_target, "Regression": regression_target}
  85. >>>
  86. >>> classification_preds = torch.tensor([0, 0, 1])
  87. >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
  88. >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
  89. >>>
  90. >>> metrics = MultitaskWrapper({
  91. ... "Classification": MetricCollection(BinaryAccuracy(), BinaryF1Score()),
  92. ... "Regression": MetricCollection(MeanSquaredError(), MeanAbsoluteError())
  93. ... })
  94. >>> metrics.update(preds, targets)
  95. >>> metrics.compute()
  96. {'Classification': {'BinaryAccuracy': tensor(0.3333), 'BinaryF1Score': tensor(0.)},
  97. 'Regression': {'MeanSquaredError': tensor(0.8333), 'MeanAbsoluteError': tensor(0.6667)}}
  98. Example (with a prefix and postfix):
  99. >>> import torch
  100. >>> from torchmetrics.wrappers import MultitaskWrapper
  101. >>> from torchmetrics.regression import MeanSquaredError
  102. >>> from torchmetrics.classification import BinaryAccuracy
  103. >>>
  104. >>> classification_target = torch.tensor([0, 1, 0])
  105. >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
  106. >>> targets = {"Classification": classification_target, "Regression": regression_target}
  107. >>> classification_preds = torch.tensor([0, 0, 1])
  108. >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
  109. >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
  110. >>>
  111. >>> metrics = MultitaskWrapper({
  112. ... "Classification": BinaryAccuracy(),
  113. ... "Regression": MeanSquaredError()
  114. ... }, prefix="train_")
  115. >>> metrics.update(preds, targets)
  116. >>> metrics.compute()
  117. {'train_Classification': tensor(0.3333), 'train_Regression': tensor(0.8333)}
  118. """
  119. is_differentiable: bool = False
  120. def __init__(
  121. self,
  122. task_metrics: dict[str, Union[Metric, MetricCollection]],
  123. prefix: Optional[str] = None,
  124. postfix: Optional[str] = None,
  125. ) -> None:
  126. super().__init__()
  127. if not isinstance(task_metrics, dict):
  128. raise TypeError(f"Expected argument `task_metrics` to be a dict. Found task_metrics = {task_metrics}")
  129. for metric in task_metrics.values():
  130. if not (isinstance(metric, (Metric, MetricCollection))):
  131. raise TypeError(
  132. "Expected each task's metric to be a Metric or a MetricCollection. "
  133. f"Found a metric of type {type(metric)}"
  134. )
  135. self.task_metrics = nn.ModuleDict(task_metrics)
  136. if prefix is not None and not isinstance(prefix, str):
  137. raise ValueError(f"Expected argument `prefix` to either be `None` or a string but got {prefix}")
  138. self._prefix = prefix or ""
  139. if postfix is not None and not isinstance(postfix, str):
  140. raise ValueError(f"Expected argument `postfix` to either be `None` or a string but got {postfix}")
  141. self._postfix = postfix or ""
  142. def items(self, flatten: bool = True) -> Iterable[tuple[str, nn.Module]]:
  143. """Iterate over task and task metrics.
  144. Args:
  145. flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection.
  146. If False, will iterate over the task names and the corresponding metrics.
  147. """
  148. for task_name, metric in self.task_metrics.items():
  149. if flatten and isinstance(metric, MetricCollection):
  150. for sub_metric_name, sub_metric in metric.items():
  151. yield f"{self._prefix}{task_name}_{sub_metric_name}{self._postfix}", sub_metric
  152. else:
  153. yield f"{self._prefix}{task_name}{self._postfix}", metric
  154. def keys(self, flatten: bool = True) -> Iterable[str]:
  155. """Iterate over task names.
  156. Args:
  157. flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection.
  158. If False, will iterate over the task names and the corresponding metrics.
  159. """
  160. for task_name, metric in self.task_metrics.items():
  161. if flatten and isinstance(metric, MetricCollection):
  162. for sub_metric_name in metric:
  163. yield f"{self._prefix}{task_name}_{sub_metric_name}{self._postfix}"
  164. else:
  165. yield f"{self._prefix}{task_name}{self._postfix}"
  166. def values(self, flatten: bool = True) -> Iterable[nn.Module]:
  167. """Iterate over task metrics.
  168. Args:
  169. flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection.
  170. If False, will iterate over the task names and the corresponding metrics.
  171. """
  172. for metric in self.task_metrics.values():
  173. if flatten and isinstance(metric, MetricCollection):
  174. yield from metric.values()
  175. else:
  176. yield metric
  177. def update(self, task_preds: dict[str, Any], task_targets: dict[str, Any]) -> None:
  178. """Update each task's metric with its corresponding pred and target.
  179. Args:
  180. task_preds: Dictionary associating each task to a Tensor of pred.
  181. task_targets: Dictionary associating each task to a Tensor of target.
  182. """
  183. if not self.task_metrics.keys() == task_preds.keys() == task_targets.keys():
  184. raise ValueError(
  185. "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`"
  186. f". Found task_preds.keys() = {task_preds.keys()}, task_targets.keys() = {task_targets.keys()} "
  187. f"and self.task_metrics.keys() = {self.task_metrics.keys()}"
  188. )
  189. for task_name, metric in self.task_metrics.items():
  190. pred = task_preds[task_name]
  191. target = task_targets[task_name]
  192. if not (isinstance(metric, (Metric, MetricCollection))):
  193. raise TypeError(
  194. "Expected each task's metric to be a Metric or a MetricCollection. "
  195. f"Found a metric of type {type(metric)}"
  196. )
  197. metric.update(pred, target)
  198. def _convert_output(self, output: dict[str, Any]) -> dict[str, Any]:
  199. """Convert the output of the underlying metrics to a dictionary with the task names as keys."""
  200. return {f"{self._prefix}{task_name}{self._postfix}": task_output for task_name, task_output in output.items()}
  201. def compute(self) -> dict[str, Any]:
  202. """Compute metrics for all tasks."""
  203. output: dict[str, Any] = {}
  204. for task_name, metric in self.task_metrics.items():
  205. if not isinstance(metric, (Metric, MetricCollection)):
  206. raise TypeError(
  207. "Expected each task's metric to be a Metric or a MetricCollection. "
  208. f"Found a metric of type {type(metric)}"
  209. )
  210. output[task_name] = metric.compute()
  211. return self._convert_output(output)
  212. def forward(self, task_preds: dict[str, Tensor], task_targets: dict[str, Tensor]) -> dict[str, Any]:
  213. """Call underlying forward methods for all tasks and return the result as a dictionary."""
  214. # This method is overridden because we do not need the complex version defined in Metric, that relies on the
  215. # value of full_state_update, and that also accumulates the results. Here, all computations are handled by the
  216. # underlying metrics, which all have their own value of full_state_update, and which all accumulate the results
  217. # by themselves.
  218. return self._convert_output({
  219. task_name: metric(task_preds[task_name], task_targets[task_name])
  220. for task_name, metric in self.task_metrics.items()
  221. })
  222. def reset(self) -> None:
  223. """Reset all underlying metrics."""
  224. for metric in self.task_metrics.values():
  225. if not isinstance(metric, (Metric, MetricCollection)):
  226. raise TypeError(
  227. "Expected each task's metric to be a Metric or a MetricCollection. "
  228. f"Found a metric of type {type(metric)}"
  229. )
  230. metric.reset()
  231. super().reset()
  232. @staticmethod
  233. def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
  234. if arg is None or isinstance(arg, str):
  235. return arg
  236. raise ValueError(f"Expected input `{name}` to be a string, but got {type(arg)}")
  237. def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MultitaskWrapper":
  238. """Make a copy of the metric.
  239. Args:
  240. prefix: a string to append in front of the metric keys
  241. postfix: a string to append after the keys of the output dict.
  242. """
  243. multitask_copy = deepcopy(self)
  244. multitask_copy._prefix = self._check_arg(prefix, "prefix") or ""
  245. multitask_copy._postfix = self._check_arg(postfix, "prefix") or ""
  246. return multitask_copy
  247. def plot(
  248. self, val: Optional[Union[dict, Sequence[dict]]] = None, axes: Optional[Sequence[_AX_TYPE]] = None
  249. ) -> Sequence[_PLOT_OUT_TYPE]:
  250. """Plot a single or multiple values from the metric.
  251. All tasks' results are plotted on individual axes.
  252. Args:
  253. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  254. If no value is provided, will automatically call `metric.compute` and plot that result.
  255. axes: Sequence of matplotlib axis objects. If provided, will add the plots to the provided axis objects.
  256. If not provided, will create them.
  257. Returns:
  258. Sequence of tuples with Figure and Axes object for each task.
  259. .. plot::
  260. :scale: 75
  261. >>> # Example plotting a single value
  262. >>> import torch
  263. >>> from torchmetrics.wrappers import MultitaskWrapper
  264. >>> from torchmetrics.regression import MeanSquaredError
  265. >>> from torchmetrics.classification import BinaryAccuracy
  266. >>>
  267. >>> classification_target = torch.tensor([0, 1, 0])
  268. >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
  269. >>> targets = {"Classification": classification_target, "Regression": regression_target}
  270. >>>
  271. >>> classification_preds = torch.tensor([0, 0, 1])
  272. >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
  273. >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
  274. >>>
  275. >>> metrics = MultitaskWrapper({
  276. ... "Classification": BinaryAccuracy(),
  277. ... "Regression": MeanSquaredError()
  278. ... })
  279. >>> metrics.update(preds, targets)
  280. >>> value = metrics.compute()
  281. >>> fig_, ax_ = metrics.plot(value)
  282. .. plot::
  283. :scale: 75
  284. >>> # Example plotting multiple values
  285. >>> import torch
  286. >>> from torchmetrics.wrappers import MultitaskWrapper
  287. >>> from torchmetrics.regression import MeanSquaredError
  288. >>> from torchmetrics.classification import BinaryAccuracy
  289. >>>
  290. >>> classification_target = torch.tensor([0, 1, 0])
  291. >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
  292. >>> targets = {"Classification": classification_target, "Regression": regression_target}
  293. >>>
  294. >>> classification_preds = torch.tensor([0, 0, 1])
  295. >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
  296. >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
  297. >>>
  298. >>> metrics = MultitaskWrapper({
  299. ... "Classification": BinaryAccuracy(),
  300. ... "Regression": MeanSquaredError()
  301. ... })
  302. >>> values = []
  303. >>> for _ in range(10):
  304. ... values.append(metrics(preds, targets))
  305. >>> fig_, ax_ = metrics.plot(values)
  306. """
  307. if axes is not None:
  308. if not isinstance(axes, Sequence):
  309. raise TypeError(f"Expected argument `axes` to be a Sequence. Found type(axes) = {type(axes)}")
  310. if not all(isinstance(ax, _AX_TYPE) for ax in axes):
  311. raise TypeError("Expected each ax in argument `axes` to be a matplotlib axis object")
  312. if len(axes) != len(self.task_metrics):
  313. raise ValueError(
  314. "Expected argument `axes` to be a Sequence of the same length as the number of tasks."
  315. f"Found len(axes) = {len(axes)} and {len(self.task_metrics)} tasks"
  316. )
  317. val = val if val is not None else self.compute()
  318. fig_axs = []
  319. for i, (task_name, task_metric) in enumerate(self.task_metrics.items()):
  320. ax = axes[i] if axes is not None else None
  321. if not isinstance(task_metric, (Metric, MetricCollection)):
  322. raise TypeError(
  323. "Expected each task's metric to be a Metric or a MetricCollection. "
  324. f"Found a metric of type {type(task_metric)}"
  325. )
  326. if isinstance(val, dict):
  327. f, a = task_metric.plot(val[task_name], ax=ax)
  328. elif isinstance(val, Sequence):
  329. f, a = task_metric.plot([v[task_name] for v in val], ax=ax)
  330. else:
  331. raise TypeError(
  332. "Expected argument `val` to be None or of type Dict or Sequence[Dict]. "
  333. f"Found type(val)= {type(val)}"
  334. )
  335. fig_axs.append((f, a))
  336. return fig_axs