sdr.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor, tensor
  17. from torchmetrics.functional.audio.sdr import (
  18. scale_invariant_signal_distortion_ratio,
  19. signal_distortion_ratio,
  20. source_aggregated_signal_distortion_ratio,
  21. )
  22. from torchmetrics.metric import Metric
  23. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  24. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  25. __doctest_requires__ = {"SignalDistortionRatio": ["fast_bss_eval"]}
  26. if not _MATPLOTLIB_AVAILABLE:
  27. __doctest_skip__ = [
  28. "SignalDistortionRatio.plot",
  29. "ScaleInvariantSignalDistortionRatio.plot",
  30. "SourceAggregatedSignalDistortionRatio.plot",
  31. ]
  32. class SignalDistortionRatio(Metric):
  33. r"""Calculate Signal to Distortion Ratio (SDR) metric.
  34. See `SDR ref1`_ and `SDR ref2`_ for details on the metric.
  35. As input to ``forward`` and ``update`` the metric accepts the following input
  36. - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
  37. - ``target`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
  38. As output of `forward` and `compute` the metric returns the following output
  39. - ``sdr`` (:class:`~torch.Tensor`): float scalar tensor with average SDR value over samples
  40. .. note:
  41. The metric currently does not seem to work with Pytorch v1.11 and specific GPU hardware.
  42. Args:
  43. use_cg_iter:
  44. If provided, conjugate gradient descent is used to solve for the distortion
  45. filter coefficients instead of direct Gaussian elimination, which requires that
  46. ``fast-bss-eval`` is installed and pytorch version >= 1.8.
  47. This can speed up the computation of the metrics in case the filters
  48. are long. Using a value of 10 here has been shown to provide
  49. good accuracy in most cases and is sufficient when using this
  50. loss to train neural separation networks.
  51. filter_length: The length of the distortion filter allowed
  52. zero_mean:
  53. When set to True, the mean of all signals is subtracted prior to computation of the metrics
  54. load_diag:
  55. If provided, this small value is added to the diagonal coefficients of the system metrics when solving
  56. for the filter coefficients. This can help stabilize the metric in the case where some reference
  57. signals may sometimes be zero
  58. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  59. Example:
  60. >>> from torch import randn
  61. >>> from torchmetrics.audio import SignalDistortionRatio
  62. >>> preds = randn(8000)
  63. >>> target = randn(8000)
  64. >>> sdr = SignalDistortionRatio()
  65. >>> sdr(preds, target)
  66. tensor(-11.9930)
  67. >>> # use with pit
  68. >>> from torchmetrics.audio import PermutationInvariantTraining
  69. >>> from torchmetrics.functional.audio import signal_distortion_ratio
  70. >>> preds = randn(4, 2, 8000) # [batch, spk, time]
  71. >>> target = randn(4, 2, 8000)
  72. >>> pit = PermutationInvariantTraining(signal_distortion_ratio,
  73. ... mode="speaker-wise", eval_func="max")
  74. >>> pit(preds, target)
  75. tensor(-11.7277)
  76. """
  77. sum_sdr: Tensor
  78. total: Tensor
  79. full_state_update: bool = False
  80. is_differentiable: bool = True
  81. higher_is_better: bool = True
  82. plot_lower_bound: Optional[float] = None
  83. plot_upper_bound: Optional[float] = None
  84. def __init__(
  85. self,
  86. use_cg_iter: Optional[int] = None,
  87. filter_length: int = 512,
  88. zero_mean: bool = False,
  89. load_diag: Optional[float] = None,
  90. **kwargs: Any,
  91. ) -> None:
  92. super().__init__(**kwargs)
  93. self.use_cg_iter = use_cg_iter
  94. self.filter_length = filter_length
  95. self.zero_mean = zero_mean
  96. self.load_diag = load_diag
  97. self.add_state("sum_sdr", default=tensor(0.0), dist_reduce_fx="sum")
  98. self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
  99. def update(self, preds: Tensor, target: Tensor) -> None:
  100. """Update state with predictions and targets."""
  101. sdr_batch = signal_distortion_ratio(
  102. preds, target, self.use_cg_iter, self.filter_length, self.zero_mean, self.load_diag
  103. )
  104. self.sum_sdr += sdr_batch.sum()
  105. self.total += sdr_batch.numel()
  106. def compute(self) -> Tensor:
  107. """Compute metric."""
  108. return self.sum_sdr / self.total
  109. def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
  110. """Plot a single or multiple values from the metric.
  111. Args:
  112. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  113. If no value is provided, will automatically call `metric.compute` and plot that result.
  114. ax: An matplotlib axis object. If provided will add plot to that axis
  115. Returns:
  116. Figure and Axes object
  117. Raises:
  118. ModuleNotFoundError:
  119. If `matplotlib` is not installed
  120. .. plot::
  121. :scale: 75
  122. >>> # Example plotting a single value
  123. >>> import torch
  124. >>> from torchmetrics.audio import SignalDistortionRatio
  125. >>> metric = SignalDistortionRatio()
  126. >>> metric.update(torch.rand(8000), torch.rand(8000))
  127. >>> fig_, ax_ = metric.plot()
  128. .. plot::
  129. :scale: 75
  130. >>> # Example plotting multiple values
  131. >>> import torch
  132. >>> from torchmetrics.audio import SignalDistortionRatio
  133. >>> metric = SignalDistortionRatio()
  134. >>> values = [ ]
  135. >>> for _ in range(10):
  136. ... values.append(metric(torch.rand(8000), torch.rand(8000)))
  137. >>> fig_, ax_ = metric.plot(values)
  138. """
  139. return self._plot(val, ax)
  140. class ScaleInvariantSignalDistortionRatio(Metric):
  141. """`Scale-invariant signal-to-distortion ratio`_ (SI-SDR).
  142. The SI-SDR value is in general considered an overall measure of how good a source sound.
  143. As input to `forward` and `update` the metric accepts the following input
  144. - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
  145. - ``target`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
  146. As output of `forward` and `compute` the metric returns the following output
  147. - ``si_sdr`` (:class:`~torch.Tensor`): float scalar tensor with average SI-SDR value over samples
  148. Args:
  149. zero_mean: if to zero mean target and preds or not
  150. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  151. Raises:
  152. TypeError:
  153. if target and preds have a different shape
  154. Example:
  155. >>> from torch import tensor
  156. >>> from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
  157. >>> target = tensor([3.0, -0.5, 2.0, 7.0])
  158. >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
  159. >>> si_sdr = ScaleInvariantSignalDistortionRatio()
  160. >>> si_sdr(preds, target)
  161. tensor(18.4030)
  162. """
  163. is_differentiable = True
  164. higher_is_better = True
  165. sum_si_sdr: Tensor
  166. total: Tensor
  167. plot_lower_bound: Optional[float] = None
  168. plot_upper_bound: Optional[float] = None
  169. def __init__(
  170. self,
  171. zero_mean: bool = False,
  172. **kwargs: Any,
  173. ) -> None:
  174. super().__init__(**kwargs)
  175. self.zero_mean = zero_mean
  176. self.add_state("sum_si_sdr", default=tensor(0.0), dist_reduce_fx="sum")
  177. self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
  178. def update(self, preds: Tensor, target: Tensor) -> None:
  179. """Update state with predictions and targets."""
  180. si_sdr_batch = scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=self.zero_mean)
  181. self.sum_si_sdr += si_sdr_batch.sum()
  182. self.total += si_sdr_batch.numel()
  183. def compute(self) -> Tensor:
  184. """Compute metric."""
  185. return self.sum_si_sdr / self.total
  186. def plot(
  187. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  188. ) -> _PLOT_OUT_TYPE:
  189. """Plot a single or multiple values from the metric.
  190. Args:
  191. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  192. If no value is provided, will automatically call `metric.compute` and plot that result.
  193. ax: An matplotlib axis object. If provided will add plot to that axis
  194. Returns:
  195. Figure and Axes object
  196. Raises:
  197. ModuleNotFoundError:
  198. If `matplotlib` is not installed
  199. .. plot::
  200. :scale: 75
  201. >>> # Example plotting a single value
  202. >>> import torch
  203. >>> from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
  204. >>> target = torch.randn(5)
  205. >>> preds = torch.randn(5)
  206. >>> metric = ScaleInvariantSignalDistortionRatio()
  207. >>> metric.update(preds, target)
  208. >>> fig_, ax_ = metric.plot()
  209. .. plot::
  210. :scale: 75
  211. >>> # Example plotting multiple values
  212. >>> import torch
  213. >>> from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
  214. >>> target = torch.randn(5)
  215. >>> preds = torch.randn(5)
  216. >>> metric = ScaleInvariantSignalDistortionRatio()
  217. >>> values = [ ]
  218. >>> for _ in range(10):
  219. ... values.append(metric(preds, target))
  220. >>> fig_, ax_ = metric.plot(values)
  221. """
  222. return self._plot(val, ax)
  223. class SourceAggregatedSignalDistortionRatio(Metric):
  224. r"""`Source-aggregated signal-to-distortion ratio`_ (SA-SDR).
  225. The SA-SDR is proposed to provide a stable gradient for meeting style source separation, where
  226. one-speaker and multiple-speaker scenes coexist.
  227. As input to ``forward`` and ``update`` the metric accepts the following input
  228. - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(..., spk, time)``
  229. - ``target`` (:class:`~torch.Tensor`): float tensor with shape ``(..., spk, time)``
  230. As output of `forward` and `compute` the metric returns the following output
  231. - ``sa_sdr`` (:class:`~torch.Tensor`): float scalar tensor with average SA-SDR value over samples
  232. Args:
  233. preds: float tensor with shape ``(..., spk, time)``
  234. target: float tensor with shape ``(..., spk, time)``
  235. scale_invariant: if True, scale the targets of different speakers with the same alpha
  236. zero_mean: If to zero mean target and preds or not
  237. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  238. Example:
  239. >>> from torch import randn
  240. >>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
  241. >>> preds = randn(2, 8000) # [..., spk, time]
  242. >>> target = randn(2, 8000)
  243. >>> sasdr = SourceAggregatedSignalDistortionRatio()
  244. >>> sasdr(preds, target)
  245. tensor(-50.8171)
  246. >>> # use with pit
  247. >>> from torchmetrics.audio import PermutationInvariantTraining
  248. >>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio
  249. >>> preds = randn(4, 2, 8000) # [batch, spk, time]
  250. >>> target = randn(4, 2, 8000)
  251. >>> pit = PermutationInvariantTraining(source_aggregated_signal_distortion_ratio,
  252. ... mode="permutation-wise", eval_func="max")
  253. >>> pit(preds, target)
  254. tensor(-43.9780)
  255. """
  256. msum: Tensor
  257. mnum: Tensor
  258. full_state_update: bool = False
  259. is_differentiable: bool = True
  260. higher_is_better: bool = True
  261. plot_lower_bound: Optional[float] = None
  262. plot_upper_bound: Optional[float] = None
  263. def __init__(
  264. self,
  265. scale_invariant: bool = True,
  266. zero_mean: bool = False,
  267. **kwargs: Any,
  268. ) -> None:
  269. super().__init__(**kwargs)
  270. if not isinstance(scale_invariant, bool):
  271. raise ValueError(f"Expected argument `scale_invarint` to be a bool, but got {scale_invariant}")
  272. self.scale_invariant = scale_invariant
  273. if not isinstance(zero_mean, bool):
  274. raise ValueError(f"Expected argument `zero_mean` to be a bool, but got {zero_mean}")
  275. self.zero_mean = zero_mean
  276. self.add_state("msum", default=tensor(0.0), dist_reduce_fx="sum")
  277. self.add_state("mnum", default=tensor(0), dist_reduce_fx="sum")
  278. def update(self, preds: Tensor, target: Tensor) -> None:
  279. """Update state with predictions and targets."""
  280. mbatch = source_aggregated_signal_distortion_ratio(preds, target, self.scale_invariant, self.zero_mean)
  281. self.msum += mbatch.sum()
  282. self.mnum += mbatch.numel()
  283. def compute(self) -> Tensor:
  284. """Compute metric."""
  285. return self.msum / self.mnum
  286. def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
  287. """Plot a single or multiple values from the metric.
  288. Args:
  289. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  290. If no value is provided, will automatically call `metric.compute` and plot that result.
  291. ax: An matplotlib axis object. If provided will add plot to that axis
  292. Returns:
  293. Figure and Axes object
  294. Raises:
  295. ModuleNotFoundError:
  296. If `matplotlib` is not installed
  297. .. plot::
  298. :scale: 75
  299. >>> # Example plotting a single value
  300. >>> import torch
  301. >>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
  302. >>> metric = SourceAggregatedSignalDistortionRatio()
  303. >>> metric.update(torch.rand(2,8000), torch.rand(2,8000))
  304. >>> fig_, ax_ = metric.plot()
  305. .. plot::
  306. :scale: 75
  307. >>> # Example plotting multiple values
  308. >>> import torch
  309. >>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
  310. >>> metric = SourceAggregatedSignalDistortionRatio()
  311. >>> values = [ ]
  312. >>> for _ in range(10):
  313. ... values.append(metric(torch.rand(2,8000), torch.rand(2,8000)))
  314. >>> fig_, ax_ = metric.plot(values)
  315. """
  316. return self._plot(val, ax)