snr.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor, tensor
  17. from torchmetrics.functional.audio.snr import (
  18. complex_scale_invariant_signal_noise_ratio,
  19. scale_invariant_signal_noise_ratio,
  20. signal_noise_ratio,
  21. )
  22. from torchmetrics.metric import Metric
  23. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  24. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  25. if not _MATPLOTLIB_AVAILABLE:
  26. __doctest_skip__ = [
  27. "SignalNoiseRatio.plot",
  28. "ScaleInvariantSignalNoiseRatio.plot",
  29. "ComplexScaleInvariantSignalNoiseRatio.plot",
  30. ]
  31. class SignalNoiseRatio(Metric):
  32. r"""Calculate `Signal-to-noise ratio`_ (SNR_) meric for evaluating quality of audio.
  33. .. math::
  34. \text{SNR} = \frac{P_{signal}}{P_{noise}}
  35. where :math:`P` denotes the power of each signal. The SNR metric compares the level of the desired signal to
  36. the level of background noise. Therefore, a high value of SNR means that the audio is clear.
  37. As input to `forward` and `update` the metric accepts the following input
  38. - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
  39. - ``target`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
  40. As output of `forward` and `compute` the metric returns the following output
  41. - ``snr`` (:class:`~torch.Tensor`): float scalar tensor with average SNR value over samples
  42. Args:
  43. zero_mean: if to zero mean target and preds or not
  44. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  45. Raises:
  46. TypeError:
  47. if target and preds have a different shape
  48. Example:
  49. >>> from torch import tensor
  50. >>> from torchmetrics.audio import SignalNoiseRatio
  51. >>> target = tensor([3.0, -0.5, 2.0, 7.0])
  52. >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
  53. >>> snr = SignalNoiseRatio()
  54. >>> snr(preds, target)
  55. tensor(16.1805)
  56. """
  57. full_state_update: bool = False
  58. is_differentiable: bool = True
  59. higher_is_better: bool = True
  60. sum_snr: Tensor
  61. total: Tensor
  62. plot_lower_bound: Optional[float] = None
  63. plot_upper_bound: Optional[float] = None
  64. def __init__(
  65. self,
  66. zero_mean: bool = False,
  67. **kwargs: Any,
  68. ) -> None:
  69. super().__init__(**kwargs)
  70. self.zero_mean = zero_mean
  71. self.add_state("sum_snr", default=tensor(0.0), dist_reduce_fx="sum")
  72. self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
  73. def update(self, preds: Tensor, target: Tensor) -> None:
  74. """Update state with predictions and targets."""
  75. snr_batch = signal_noise_ratio(preds=preds, target=target, zero_mean=self.zero_mean)
  76. self.sum_snr += snr_batch.sum()
  77. self.total += snr_batch.numel()
  78. def compute(self) -> Tensor:
  79. """Compute metric."""
  80. return self.sum_snr / self.total
  81. def plot(
  82. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  83. ) -> _PLOT_OUT_TYPE:
  84. """Plot a single or multiple values from the metric.
  85. Args:
  86. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  87. If no value is provided, will automatically call `metric.compute` and plot that result.
  88. ax: An matplotlib axis object. If provided will add plot to that axis
  89. Returns:
  90. Figure and Axes object
  91. Raises:
  92. ModuleNotFoundError:
  93. If `matplotlib` is not installed
  94. .. plot::
  95. :scale: 75
  96. >>> # Example plotting a single value
  97. >>> import torch
  98. >>> from torchmetrics.audio import SignalNoiseRatio
  99. >>> metric = SignalNoiseRatio()
  100. >>> metric.update(torch.rand(4), torch.rand(4))
  101. >>> fig_, ax_ = metric.plot()
  102. .. plot::
  103. :scale: 75
  104. >>> # Example plotting multiple values
  105. >>> import torch
  106. >>> from torchmetrics.audio import SignalNoiseRatio
  107. >>> metric = SignalNoiseRatio()
  108. >>> values = [ ]
  109. >>> for _ in range(10):
  110. ... values.append(metric(torch.rand(4), torch.rand(4)))
  111. >>> fig_, ax_ = metric.plot(values)
  112. """
  113. return self._plot(val, ax)
  114. class ScaleInvariantSignalNoiseRatio(Metric):
  115. """Calculate `Scale-invariant signal-to-noise ratio`_ (SI-SNR) metric for evaluating quality of audio.
  116. As input to `forward` and `update` the metric accepts the following input
  117. - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
  118. - ``target`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
  119. As output of `forward` and `compute` the metric returns the following output
  120. - ``si_snr`` (:class:`~torch.Tensor`): float scalar tensor with average SI-SNR value over samples
  121. Args:
  122. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  123. Raises:
  124. TypeError:
  125. if target and preds have a different shape
  126. Example:
  127. >>> import torch
  128. >>> from torch import tensor
  129. >>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
  130. >>> target = tensor([3.0, -0.5, 2.0, 7.0])
  131. >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
  132. >>> si_snr = ScaleInvariantSignalNoiseRatio()
  133. >>> si_snr(preds, target)
  134. tensor(15.0918)
  135. """
  136. is_differentiable = True
  137. sum_si_snr: Tensor
  138. total: Tensor
  139. higher_is_better = True
  140. plot_lower_bound: Optional[float] = None
  141. plot_upper_bound: Optional[float] = None
  142. def __init__(
  143. self,
  144. **kwargs: Any,
  145. ) -> None:
  146. super().__init__(**kwargs)
  147. self.add_state("sum_si_snr", default=tensor(0.0), dist_reduce_fx="sum")
  148. self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
  149. def update(self, preds: Tensor, target: Tensor) -> None:
  150. """Update state with predictions and targets."""
  151. si_snr_batch = scale_invariant_signal_noise_ratio(preds=preds, target=target)
  152. self.sum_si_snr += si_snr_batch.sum()
  153. self.total += si_snr_batch.numel()
  154. def compute(self) -> Tensor:
  155. """Compute metric."""
  156. return self.sum_si_snr / self.total
  157. def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
  158. """Plot a single or multiple values from the metric.
  159. Args:
  160. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  161. If no value is provided, will automatically call `metric.compute` and plot that result.
  162. ax: An matplotlib axis object. If provided will add plot to that axis
  163. Returns:
  164. Figure and Axes object
  165. Raises:
  166. ModuleNotFoundError:
  167. If `matplotlib` is not installed
  168. .. plot::
  169. :scale: 75
  170. >>> # Example plotting a single value
  171. >>> import torch
  172. >>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
  173. >>> metric = ScaleInvariantSignalNoiseRatio()
  174. >>> metric.update(torch.rand(4), torch.rand(4))
  175. >>> fig_, ax_ = metric.plot()
  176. .. plot::
  177. :scale: 75
  178. >>> # Example plotting multiple values
  179. >>> import torch
  180. >>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
  181. >>> metric = ScaleInvariantSignalNoiseRatio()
  182. >>> values = [ ]
  183. >>> for _ in range(10):
  184. ... values.append(metric(torch.rand(4), torch.rand(4)))
  185. >>> fig_, ax_ = metric.plot(values)
  186. """
  187. return self._plot(val, ax)
  188. class ComplexScaleInvariantSignalNoiseRatio(Metric):
  189. """Calculate `Complex scale-invariant signal-to-noise ratio`_ (C-SI-SNR) metric for evaluating quality of audio.
  190. As input to `forward` and `update` the metric accepts the following input
  191. - ``preds`` (:class:`~torch.Tensor`): real float tensor with shape ``(...,frequency,time,2)`` or complex float
  192. tensor with shape ``(..., frequency,time)``
  193. - ``target`` (:class:`~torch.Tensor`): real float tensor with shape ``(...,frequency,time,2)`` or complex float
  194. tensor with shape ``(..., frequency,time)``
  195. As output of `forward` and `compute` the metric returns the following output
  196. - ``c_si_snr`` (:class:`~torch.Tensor`): float scalar tensor with average C-SI-SNR value over samples
  197. Args:
  198. zero_mean: if to zero mean target and preds or not
  199. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  200. Raises:
  201. ValueError:
  202. If ``zero_mean`` is not an bool
  203. TypeError:
  204. If ``preds`` is not the shape (..., frequency, time, 2) (after being converted to real if it is complex).
  205. If ``preds`` and ``target`` does not have the same shape.
  206. Example:
  207. >>> from torch import randn
  208. >>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
  209. >>> preds = randn((1,257,100,2))
  210. >>> target = randn((1,257,100,2))
  211. >>> c_si_snr = ComplexScaleInvariantSignalNoiseRatio()
  212. >>> c_si_snr(preds, target)
  213. tensor(-38.8832)
  214. """
  215. is_differentiable = True
  216. ci_snr_sum: Tensor
  217. num: Tensor
  218. higher_is_better = True
  219. plot_lower_bound: Optional[float] = None
  220. plot_upper_bound: Optional[float] = None
  221. def __init__(
  222. self,
  223. zero_mean: bool = False,
  224. **kwargs: Any,
  225. ) -> None:
  226. super().__init__(**kwargs)
  227. if not isinstance(zero_mean, bool):
  228. raise ValueError(f"Expected argument `zero_mean` to be an bool, but got {zero_mean}")
  229. self.zero_mean = zero_mean
  230. self.add_state("ci_snr_sum", default=tensor(0.0), dist_reduce_fx="sum")
  231. self.add_state("num", default=tensor(0), dist_reduce_fx="sum")
  232. def update(self, preds: Tensor, target: Tensor) -> None:
  233. """Update state with predictions and targets."""
  234. v = complex_scale_invariant_signal_noise_ratio(preds=preds, target=target, zero_mean=self.zero_mean)
  235. self.ci_snr_sum += v.sum()
  236. self.num += v.numel()
  237. def compute(self) -> Tensor:
  238. """Compute metric."""
  239. return self.ci_snr_sum / self.num
  240. def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
  241. """Plot a single or multiple values from the metric.
  242. Args:
  243. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  244. If no value is provided, will automatically call `metric.compute` and plot that result.
  245. ax: An matplotlib axis object. If provided will add plot to that axis
  246. Returns:
  247. Figure and Axes object
  248. Raises:
  249. ModuleNotFoundError:
  250. If `matplotlib` is not installed
  251. .. plot::
  252. :scale: 75
  253. >>> # Example plotting a single value
  254. >>> import torch
  255. >>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
  256. >>> metric = ComplexScaleInvariantSignalNoiseRatio()
  257. >>> metric.update(torch.rand(1,257,100,2), torch.rand(1,257,100,2))
  258. >>> fig_, ax_ = metric.plot()
  259. .. plot::
  260. :scale: 75
  261. >>> # Example plotting multiple values
  262. >>> import torch
  263. >>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
  264. >>> metric = ComplexScaleInvariantSignalNoiseRatio()
  265. >>> values = [ ]
  266. >>> for _ in range(10):
  267. ... values.append(metric(torch.rand(1,257,100,2), torch.rand(1,257,100,2)))
  268. >>> fig_, ax_ = metric.plot(values)
  269. """
  270. return self._plot(val, ax)