sdr.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from typing import Optional
  16. import torch
  17. from torch import Tensor
  18. # import or define the norm/solve functions
  19. from torch.linalg import norm
  20. from torchmetrics.utilities import rank_zero_warn
  21. from torchmetrics.utilities.checks import _check_same_shape
  22. from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE
  23. def _symmetric_toeplitz(vector: Tensor) -> Tensor:
  24. """Construct a symmetric Toeplitz matrix using one vector.
  25. Args:
  26. vector: shape [..., L]
  27. Example:
  28. >>> from torch import tensor
  29. >>> from torchmetrics.functional.audio.sdr import _symmetric_toeplitz
  30. >>> v = tensor([0, 1, 2, 3, 4])
  31. >>> _symmetric_toeplitz(v)
  32. tensor([[0, 1, 2, 3, 4],
  33. [1, 0, 1, 2, 3],
  34. [2, 1, 0, 1, 2],
  35. [3, 2, 1, 0, 1],
  36. [4, 3, 2, 1, 0]])
  37. Returns:
  38. a symmetric Toeplitz matrix of shape [..., L, L]
  39. """
  40. vec_exp = torch.cat([torch.flip(vector, dims=(-1,)), vector[..., 1:]], dim=-1)
  41. v_len = vector.shape[-1]
  42. return torch.as_strided(
  43. vec_exp, size=(*vec_exp.shape[:-1], v_len, v_len), stride=(*vec_exp.stride()[:-1], 1, 1)
  44. ).flip(dims=(-1,))
  45. def _compute_autocorr_crosscorr(target: Tensor, preds: Tensor, corr_len: int) -> tuple[Tensor, Tensor]:
  46. r"""Compute the auto correlation of `target` and the cross correlation of `target` and `preds`.
  47. This calculation is done using the fast Fourier transform (FFT). Let's denotes the symmetric Toeplitz metric of the
  48. auto correlation of `target` as `R`, the cross correlation as 'b', then solving the equation `Rh=b` could have `h`
  49. as the coordinate of `preds` in the column space of the `corr_len` shifts of `target`.
  50. Args:
  51. target: the target (reference) signal of shape [..., time]
  52. preds: the preds (estimated) signal of shape [..., time]
  53. corr_len: the length of the auto correlation and cross correlation
  54. Returns:
  55. the auto correlation of `target` of shape [..., corr_len]
  56. the cross correlation of `target` and `preds` of shape [..., corr_len]
  57. """
  58. # the valid length for the signal after convolution
  59. n_fft = 2 ** math.ceil(math.log2(preds.shape[-1] + target.shape[-1] - 1))
  60. # computes the auto correlation of `target`
  61. # r_0 is the first row of the symmetric Toeplitz metric
  62. t_fft = torch.fft.rfft(target, n=n_fft, dim=-1)
  63. r_0 = torch.fft.irfft(t_fft.real**2 + t_fft.imag**2, n=n_fft)[..., :corr_len]
  64. # computes the cross-correlation of `target` and `preds`
  65. p_fft = torch.fft.rfft(preds, n=n_fft, dim=-1)
  66. b = torch.fft.irfft(t_fft.conj() * p_fft, n=n_fft, dim=-1)[..., :corr_len]
  67. return r_0, b
  68. def signal_distortion_ratio(
  69. preds: Tensor,
  70. target: Tensor,
  71. use_cg_iter: Optional[int] = None,
  72. filter_length: int = 512,
  73. zero_mean: bool = False,
  74. load_diag: Optional[float] = None,
  75. ) -> Tensor:
  76. r"""Calculate Signal to Distortion Ratio (SDR) metric. See `SDR ref1`_ and `SDR ref2`_ for details on the metric.
  77. .. note:
  78. The metric currently does not seem to work with Pytorch v1.11 and specific GPU hardware.
  79. Args:
  80. preds: float tensor with shape ``(...,time)``
  81. target: float tensor with shape ``(...,time)``
  82. use_cg_iter:
  83. If provided, conjugate gradient descent is used to solve for the distortion
  84. filter coefficients instead of direct Gaussian elimination, which requires that
  85. ``fast-bss-eval`` is installed and pytorch version >= 1.8.
  86. This can speed up the computation of the metrics in case the filters
  87. are long. Using a value of 10 here has been shown to provide
  88. good accuracy in most cases and is sufficient when using this
  89. loss to train neural separation networks.
  90. filter_length: The length of the distortion filter allowed
  91. zero_mean: When set to True, the mean of all signals is subtracted prior to computation of the metrics
  92. load_diag:
  93. If provided, this small value is added to the diagonal coefficients of
  94. the system metrics when solving for the filter coefficients.
  95. This can help stabilize the metric in the case where some reference signals may sometimes be zero
  96. Returns:
  97. Float tensor with shape ``(...,)`` of SDR values per sample
  98. Raises:
  99. RuntimeError:
  100. If ``preds`` and ``target`` does not have the same shape
  101. Example:
  102. >>> from torch import randn
  103. >>> from torchmetrics.functional.audio import signal_distortion_ratio
  104. >>> preds = randn(8000)
  105. >>> target = randn(8000)
  106. >>> signal_distortion_ratio(preds, target)
  107. tensor(-11.9930)
  108. >>> # use with permutation_invariant_training
  109. >>> from torchmetrics.functional.audio import permutation_invariant_training
  110. >>> preds = randn(4, 2, 8000) # [batch, spk, time]
  111. >>> target = randn(4, 2, 8000)
  112. >>> best_metric, best_perm = permutation_invariant_training(preds, target, signal_distortion_ratio)
  113. >>> best_metric
  114. tensor([-11.7748, -11.7948, -11.7160, -11.6254])
  115. >>> best_perm
  116. tensor([[1, 0],
  117. [1, 0],
  118. [1, 0],
  119. [0, 1]])
  120. """
  121. _check_same_shape(preds, target)
  122. # use double precision
  123. preds_dtype = preds.dtype
  124. preds = preds.double()
  125. target = target.double()
  126. if zero_mean:
  127. preds = preds - preds.mean(dim=-1, keepdim=True)
  128. target = target - target.mean(dim=-1, keepdim=True)
  129. # normalize along time-axis to make preds and target have unit norm
  130. target = target / torch.clamp(norm(target, dim=-1, keepdim=True), min=1e-6)
  131. preds = preds / torch.clamp(norm(preds, dim=-1, keepdim=True), min=1e-6)
  132. # solve for the optimal filter
  133. # compute auto-correlation and cross-correlation
  134. r_0, b = _compute_autocorr_crosscorr(target, preds, corr_len=filter_length)
  135. if load_diag is not None:
  136. # the diagonal factor of the Toeplitz matrix is the first coefficient of r_0
  137. r_0[..., 0] += load_diag
  138. if use_cg_iter is not None and _FAST_BSS_EVAL_AVAILABLE:
  139. from fast_bss_eval.torch.cgd import toeplitz_conjugate_gradient
  140. # use preconditioned conjugate gradient
  141. sol = toeplitz_conjugate_gradient(r_0, b, n_iter=use_cg_iter)
  142. else:
  143. if use_cg_iter is not None and not _FAST_BSS_EVAL_AVAILABLE:
  144. rank_zero_warn(
  145. "The `use_cg_iter` parameter of `SDR` requires that `fast-bss-eval` is installed. "
  146. "To make this this warning disappear, you could install `fast-bss-eval` using "
  147. "`pip install fast-bss-eval` or set `use_cg_iter=None`. For this time, the solver "
  148. "provided by Pytorch is used.",
  149. UserWarning,
  150. )
  151. # regular matrix solver
  152. r = _symmetric_toeplitz(r_0) # the auto-correlation of the L shifts of `target`
  153. sol = torch.linalg.solve(r, b)
  154. # compute the coherence
  155. coh = torch.einsum("...l,...l->...", b, sol)
  156. # transform to decibels
  157. ratio = coh / (1 - coh)
  158. val = 10.0 * torch.log10(ratio)
  159. if preds_dtype == torch.float64:
  160. return val
  161. return val.float()
  162. def scale_invariant_signal_distortion_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
  163. """`Scale-invariant signal-to-distortion ratio`_ (SI-SDR).
  164. The SI-SDR value is in general considered an overall measure of how good a source sound.
  165. Args:
  166. preds: float tensor with shape ``(...,time)``
  167. target: float tensor with shape ``(...,time)``
  168. zero_mean: If to zero mean target and preds or not
  169. Returns:
  170. Float tensor with shape ``(...,)`` of SDR values per sample
  171. Raises:
  172. RuntimeError:
  173. If ``preds`` and ``target`` does not have the same shape
  174. Example:
  175. >>> from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio
  176. >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
  177. >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
  178. >>> scale_invariant_signal_distortion_ratio(preds, target)
  179. tensor(18.4030)
  180. """
  181. _check_same_shape(preds, target)
  182. eps = torch.finfo(preds.dtype).eps
  183. if zero_mean:
  184. target = target - torch.mean(target, dim=-1, keepdim=True)
  185. preds = preds - torch.mean(preds, dim=-1, keepdim=True)
  186. alpha = (torch.sum(preds * target, dim=-1, keepdim=True) + eps) / (torch.sum(target**2, dim=-1, keepdim=True) + eps)
  187. target_scaled = alpha * target
  188. noise = target_scaled - preds
  189. val = (torch.sum(target_scaled**2, dim=-1) + eps) / (torch.sum(noise**2, dim=-1) + eps)
  190. return 10 * torch.log10(val)
  191. def source_aggregated_signal_distortion_ratio(
  192. preds: Tensor,
  193. target: Tensor,
  194. scale_invariant: bool = True,
  195. zero_mean: bool = False,
  196. ) -> Tensor:
  197. """`Source-aggregated signal-to-distortion ratio`_ (SA-SDR).
  198. The SA-SDR is proposed to provide a stable gradient for meeting style source separation, where
  199. one-speaker and multiple-speaker scenes coexist.
  200. Args:
  201. preds: float tensor with shape ``(..., spk, time)``
  202. target: float tensor with shape ``(..., spk, time)``
  203. scale_invariant: if True, scale the targets of different speakers with the same alpha
  204. zero_mean: If to zero mean target and preds or not
  205. Returns:
  206. SA-SDR with shape ``(...)``
  207. Example:
  208. >>> from torch import randn
  209. >>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio
  210. >>> preds = randn(2, 8000) # [..., spk, time]
  211. >>> target = randn(2, 8000)
  212. >>> source_aggregated_signal_distortion_ratio(preds, target)
  213. tensor(-50.8171)
  214. >>> # use with permutation_invariant_training
  215. >>> from torchmetrics.functional.audio import permutation_invariant_training
  216. >>> preds = randn(4, 2, 8000) # [batch, spk, time]
  217. >>> target = randn(4, 2, 8000)
  218. >>> best_metric, best_perm = permutation_invariant_training(preds, target,
  219. ... source_aggregated_signal_distortion_ratio, mode="permutation-wise")
  220. >>> best_metric
  221. tensor([-42.6290, -44.3500, -34.7503, -54.1828])
  222. >>> best_perm
  223. tensor([[0, 1],
  224. [1, 0],
  225. [0, 1],
  226. [1, 0]])
  227. """
  228. _check_same_shape(preds, target)
  229. if preds.ndim < 2:
  230. raise RuntimeError(f"The preds and target should have the shape (..., spk, time), but {preds.shape} found")
  231. eps = torch.finfo(preds.dtype).eps
  232. if zero_mean:
  233. target = target - torch.mean(target, dim=-1, keepdim=True)
  234. preds = preds - torch.mean(preds, dim=-1, keepdim=True)
  235. if scale_invariant:
  236. # scale the targets of different speakers with the same alpha (shape [..., 1, 1])
  237. alpha = ((preds * target).sum(dim=-1, keepdim=True).sum(dim=-2, keepdim=True) + eps) / (
  238. (target**2).sum(dim=-1, keepdim=True).sum(dim=-2, keepdim=True) + eps
  239. )
  240. target = alpha * target
  241. distortion = target - preds
  242. val = ((target**2).sum(dim=-1).sum(dim=-1) + eps) / ((distortion**2).sum(dim=-1).sum(dim=-1) + eps)
  243. return 10 * torch.log10(val)