srmr.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
# Note: unless stated otherwise, the functions in this file are translated from
# the SRMRpy package and adapted for batched processing with PyTorch.
  16. from functools import lru_cache
  17. from math import ceil, pi
  18. from typing import Optional
  19. import torch
  20. from torch import Tensor
  21. from torch.nn.functional import pad
  22. from torchmetrics.utilities import rank_zero_warn
  23. from torchmetrics.utilities.imports import (
  24. _GAMMATONE_AVAILABLE,
  25. _TORCHAUDIO_AVAILABLE,
  26. )
  27. if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE:
  28. __doctest_skip__ = ["speech_reverberation_modulation_energy_ratio"]
  29. @lru_cache(maxsize=100)
  30. def _calc_erbs(low_freq: float, fs: int, n_filters: int, device: torch.device) -> Tensor:
  31. from gammatone.filters import centre_freqs
  32. ear_q = 9.26449 # Glasberg and Moore Parameters
  33. min_bw = 24.7
  34. order = 1
  35. erbs = ((centre_freqs(fs, n_filters, low_freq) / ear_q) ** order + min_bw**order) ** (1 / order)
  36. return torch.tensor(erbs, device=device)
  37. @lru_cache(maxsize=100)
  38. def _make_erb_filters(fs: int, num_freqs: int, cutoff: float, device: torch.device) -> Tensor:
  39. from gammatone.filters import centre_freqs, make_erb_filters
  40. cfs = centre_freqs(fs, num_freqs, cutoff)
  41. fcoefs = make_erb_filters(fs, cfs)
  42. return torch.tensor(fcoefs, device=device)
  43. @lru_cache(maxsize=100)
  44. def _compute_modulation_filterbank_and_cutoffs(
  45. min_cf: float, max_cf: float, n: int, fs: float, q: int, device: torch.device
  46. ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
  47. # this function is translated from the SRMRpy packaged
  48. spacing_factor = (max_cf / min_cf) ** (1.0 / (n - 1))
  49. cfs = torch.zeros(n, dtype=torch.float64)
  50. cfs[0] = min_cf
  51. for k in range(1, n):
  52. cfs[k] = cfs[k - 1] * spacing_factor
  53. def _make_modulation_filter(w0: Tensor, q: int) -> Tensor:
  54. w0 = torch.tan(w0 / 2)
  55. b0 = w0 / q
  56. b = torch.tensor([b0, 0, -b0], dtype=torch.float64)
  57. a = torch.tensor([(1 + b0 + w0**2), (2 * w0**2 - 2), (1 - b0 + w0**2)], dtype=torch.float64)
  58. return torch.stack([b, a], dim=0)
  59. mfb = torch.stack([_make_modulation_filter(w0, q) for w0 in 2 * pi * cfs / fs], dim=0)
  60. def _calc_cutoffs(cfs: Tensor, fs: float, q: int) -> tuple[Tensor, Tensor]:
  61. # Calculates cutoff frequencies (3 dB) for 2nd order bandpass
  62. w0 = 2 * pi * cfs / fs
  63. b0 = torch.tan(w0 / 2) / q
  64. ll = cfs - (b0 * fs / (2 * pi))
  65. rr = cfs + (b0 * fs / (2 * pi))
  66. return ll, rr
  67. cfs = cfs.to(device=device)
  68. mfb = mfb.to(device=device)
  69. ll, rr = _calc_cutoffs(cfs, fs, q)
  70. return cfs, mfb, ll, rr
  71. def _hilbert(x: Tensor, n: Optional[int] = None) -> Tensor:
  72. if x.is_complex():
  73. raise ValueError("x must be real.")
  74. if n is None:
  75. n = x.shape[-1]
  76. # Make N multiple of 16 to make sure the transform will be fast
  77. if n % 16:
  78. n = ceil(n / 16) * 16
  79. if n <= 0:
  80. raise ValueError("N must be positive.")
  81. x_fft = torch.fft.fft(x, n=n, dim=-1)
  82. h = torch.zeros(n, dtype=x.dtype, device=x.device, requires_grad=False)
  83. if n % 2 == 0:
  84. h[0] = h[n // 2] = 1
  85. h[1 : n // 2] = 2
  86. else:
  87. h[0] = 1
  88. h[1 : (n + 1) // 2] = 2
  89. y = torch.fft.ifft(x_fft * h, dim=-1)
  90. return y[..., : x.shape[-1]]
  91. def _erb_filterbank(wave: Tensor, coefs: Tensor) -> Tensor:
  92. """Translated from gammatone package.
  93. Args:
  94. wave: shape [B, time]
  95. coefs: shape [N, 10]
  96. Returns:
  97. Tensor: shape [B, N, time]
  98. """
  99. from torchaudio.functional.filtering import lfilter
  100. num_batch, time = wave.shape
  101. wave = wave.to(dtype=coefs.dtype).reshape(num_batch, 1, time) # [B, time]
  102. wave = wave.expand(-1, coefs.shape[0], -1) # [B, N, time]
  103. gain = coefs[:, 9]
  104. as1 = coefs[:, (0, 1, 5)] # A0, A11, A2
  105. as2 = coefs[:, (0, 2, 5)] # A0, A12, A2
  106. as3 = coefs[:, (0, 3, 5)] # A0, A13, A2
  107. as4 = coefs[:, (0, 4, 5)] # A0, A14, A2
  108. bs = coefs[:, 6:9] # B0, B1, B2
  109. y1 = lfilter(wave, bs, as1, batching=True)
  110. y2 = lfilter(y1, bs, as2, batching=True)
  111. y3 = lfilter(y2, bs, as3, batching=True)
  112. y4 = lfilter(y3, bs, as4, batching=True)
  113. return y4 / gain.reshape(1, -1, 1)
  114. def _normalize_energy(energy: Tensor, drange: float = 30.0) -> Tensor:
  115. """Normalize energy to a dynamic range of 30 dB.
  116. Args:
  117. energy: shape [B, N_filters, 8, n_frames]
  118. drange: dynamic range in dB
  119. """
  120. peak_energy = torch.mean(energy, dim=1, keepdim=True).max(dim=2, keepdim=True).values
  121. peak_energy = peak_energy.max(dim=3, keepdim=True).values
  122. min_energy = peak_energy * 10.0 ** (-drange / 10.0)
  123. energy = torch.where(energy < min_energy, min_energy, energy)
  124. return torch.where(energy > peak_energy, peak_energy, energy)
  125. def _cal_srmr_score(bw: Tensor, avg_energy: Tensor, cutoffs: Tensor) -> Tensor:
  126. """Calculate srmr score."""
  127. if (cutoffs[4] <= bw) and (cutoffs[5] > bw):
  128. kstar = 5
  129. elif (cutoffs[5] <= bw) and (cutoffs[6] > bw):
  130. kstar = 6
  131. elif (cutoffs[6] <= bw) and (cutoffs[7] > bw):
  132. kstar = 7
  133. elif cutoffs[7] <= bw:
  134. kstar = 8
  135. else:
  136. raise ValueError("Something wrong with the cutoffs compared to bw values.")
  137. return torch.sum(avg_energy[:, :4]) / torch.sum(avg_energy[:, 4:kstar])
  138. def speech_reverberation_modulation_energy_ratio(
  139. preds: Tensor,
  140. fs: int,
  141. n_cochlear_filters: int = 23,
  142. low_freq: float = 125,
  143. min_cf: float = 4,
  144. max_cf: Optional[float] = None,
  145. norm: bool = False,
  146. fast: bool = False,
  147. ) -> Tensor:
  148. """Calculate `Speech-to-Reverberation Modulation Energy Ratio`_ (SRMR).
  149. SRMR is a non-intrusive metric for speech quality and intelligibility based on
  150. a modulation spectral representation of the speech signal.
  151. This code is translated from SRMRToolbox and `SRMRpy`_.
  152. Args:
  153. preds: shape ``(..., time)``
  154. fs: the sampling rate
  155. n_cochlear_filters: Number of filters in the acoustic filterbank
  156. low_freq: determines the frequency cutoff for the corresponding gammatone filterbank.
  157. min_cf: Center frequency in Hz of the first modulation filter.
  158. max_cf: Center frequency in Hz of the last modulation filter. If None is given,
  159. then 30 Hz will be used for `norm==False`, otherwise 128 Hz will be used.
  160. norm: Use modulation spectrum energy normalization
  161. fast: Use the faster version based on the gammatonegram.
  162. Note: this argument is inherited from `SRMRpy`_. As the translated code is based to pytorch,
  163. setting `fast=True` may slow down the speed for calculating this metric on GPU.
  164. .. hint::
  165. Usingsing this metrics requires you to have ``gammatone`` and ``torchaudio`` installed.
  166. Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio``
  167. and ``pip install git+https://github.com/detly/gammatone``.
  168. .. attention::
  169. This implementation is experimental, and might not be consistent with the matlab
  170. implementation SRMRToolbox, especially the fast implementation.
  171. The slow versions, a) ``fast=False, norm=False, max_cf=128``, b) ``fast=False, norm=True, max_cf=30``,
  172. have a relatively small inconsistency.
  173. Returns:
  174. Scalar tensor with srmr value with shape ``(...)``
  175. Raises:
  176. ModuleNotFoundError:
  177. If ``gammatone`` or ``torchaudio`` package is not installed
  178. Example:
  179. >>> from torch import randn
  180. >>> from torchmetrics.functional.audio import speech_reverberation_modulation_energy_ratio
  181. >>> preds = randn(8000)
  182. >>> speech_reverberation_modulation_energy_ratio(preds, 8000)
  183. tensor([0.3191], dtype=torch.float64)
  184. """
  185. if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE:
  186. raise ModuleNotFoundError(
  187. "speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and"
  188. " `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or "
  189. "``pip install torchaudio>=0.10`` and ``pip install git+https://github.com/detly/gammatone``"
  190. )
  191. from gammatone.fftweight import fft_gtgram
  192. from torchaudio.functional.filtering import lfilter
  193. _srmr_arg_validate(
  194. fs=fs,
  195. n_cochlear_filters=n_cochlear_filters,
  196. low_freq=low_freq,
  197. min_cf=min_cf,
  198. max_cf=max_cf,
  199. norm=norm,
  200. fast=fast,
  201. )
  202. shape = preds.shape
  203. preds = preds.reshape(1, -1) if len(shape) == 1 else preds.reshape(-1, shape[-1])
  204. num_batch, time = preds.shape
  205. # convert int type to float
  206. if not torch.is_floating_point(preds):
  207. preds = preds.to(torch.float64) / torch.finfo(preds.dtype).max
  208. # norm values in preds to [-1, 1], as lfilter requires an input in this range
  209. max_vals = preds.abs().max(dim=-1, keepdim=True).values
  210. val_norm = torch.where(
  211. max_vals > 1,
  212. max_vals,
  213. torch.tensor(1.0, dtype=max_vals.dtype, device=max_vals.device),
  214. )
  215. preds = preds / val_norm
  216. w_length_s = 0.256
  217. w_inc_s = 0.064
  218. # Computing gammatone envelopes
  219. if fast:
  220. rank_zero_warn("`fast=True` may slow down the speed of SRMR metric on GPU.")
  221. mfs = 400.0
  222. temp = []
  223. preds_np = preds.detach().cpu().numpy()
  224. for b in range(num_batch):
  225. gt_env_b = fft_gtgram(preds_np[b], fs, 0.010, 0.0025, n_cochlear_filters, low_freq)
  226. temp.append(torch.tensor(gt_env_b))
  227. gt_env = torch.stack(temp, dim=0).to(device=preds.device)
  228. else:
  229. fcoefs = _make_erb_filters(fs, n_cochlear_filters, low_freq, device=preds.device) # [N_filters, 10]
  230. gt_env = torch.abs(_hilbert(_erb_filterbank(preds, fcoefs))) # [B, N_filters, time]
  231. mfs = fs
  232. w_length = ceil(w_length_s * mfs)
  233. w_inc = ceil(w_inc_s * mfs)
  234. # Computing modulation filterbank with Q = 2 and 8 channels
  235. if max_cf is None:
  236. max_cf = 30 if norm else 128
  237. _, mf, cutoffs, _ = _compute_modulation_filterbank_and_cutoffs(
  238. min_cf, max_cf, n=8, fs=mfs, q=2, device=preds.device
  239. )
  240. num_frames = int(1 + (time - w_length) // w_inc)
  241. w = torch.hamming_window(w_length + 1, dtype=torch.float64, device=preds.device)[:-1]
  242. mod_out = lfilter(
  243. gt_env.unsqueeze(-2).expand(-1, -1, mf.shape[0], -1), mf[:, 1, :], mf[:, 0, :], clamp=False, batching=True
  244. ) # [B, N_filters, 8, time]
  245. # pad signal if it's shorter than window or it is not multiple of wInc
  246. padding = (0, max(ceil(time / w_inc) * w_inc - time, w_length - time))
  247. mod_out_pad = pad(mod_out, pad=padding, mode="constant", value=0)
  248. mod_out_frame = mod_out_pad.unfold(-1, w_length, w_inc)
  249. energy = ((mod_out_frame[..., :num_frames, :] * w) ** 2).sum(dim=-1) # [B, N_filters, 8, n_frames]
  250. if norm:
  251. energy = _normalize_energy(energy)
  252. erbs = torch.flipud(_calc_erbs(low_freq, fs, n_cochlear_filters, device=preds.device))
  253. avg_energy = torch.mean(energy, dim=-1)
  254. total_energy = torch.sum(avg_energy.reshape(num_batch, -1), dim=-1)
  255. ac_energy = torch.sum(avg_energy, dim=2)
  256. ac_perc = ac_energy * 100 / total_energy.reshape(-1, 1)
  257. ac_perc_cumsum = ac_perc.flip(-1).cumsum(-1)
  258. k90perc_idx = torch.nonzero((ac_perc_cumsum > 90).cumsum(-1) == 1)[:, 1]
  259. bw = erbs[k90perc_idx]
  260. temp = []
  261. for b in range(num_batch):
  262. score = _cal_srmr_score(bw[b], avg_energy[b], cutoffs=cutoffs)
  263. temp.append(score)
  264. score = torch.stack(temp)
  265. return score.reshape(*shape[:-1]) if len(shape) > 1 else score # recover original shape
  266. def _srmr_arg_validate(
  267. fs: int,
  268. n_cochlear_filters: int = 23,
  269. low_freq: float = 125,
  270. min_cf: float = 4,
  271. max_cf: Optional[float] = 128,
  272. norm: bool = False,
  273. fast: bool = False,
  274. ) -> None:
  275. """Validate the arguments for speech_reverberation_modulation_energy_ratio.
  276. Args:
  277. fs: the sampling rate
  278. n_cochlear_filters: Number of filters in the acoustic filterbank
  279. low_freq: determines the frequency cutoff for the corresponding gammatone filterbank.
  280. min_cf: Center frequency in Hz of the first modulation filter.
  281. max_cf: Center frequency in Hz of the last modulation filter. If None is given,
  282. norm: Use modulation spectrum energy normalization
  283. fast: Use the faster version based on the gammatonegram.
  284. """
  285. if not (isinstance(fs, int) and fs > 0):
  286. raise ValueError(f"Expected argument `fs` to be an int larger than 0, but got {fs}")
  287. if not (isinstance(n_cochlear_filters, int) and n_cochlear_filters > 0):
  288. raise ValueError(
  289. f"Expected argument `n_cochlear_filters` to be an int larger than 0, but got {n_cochlear_filters}"
  290. )
  291. if not ((isinstance(low_freq, (float, int))) and low_freq > 0):
  292. raise ValueError(f"Expected argument `low_freq` to be a float larger than 0, but got {low_freq}")
  293. if not ((isinstance(min_cf, (float, int))) and min_cf > 0):
  294. raise ValueError(f"Expected argument `min_cf` to be a float larger than 0, but got {min_cf}")
  295. if max_cf is not None and not ((isinstance(max_cf, (float, int))) and max_cf > 0):
  296. raise ValueError(f"Expected argument `max_cf` to be a float larger than 0, but got {max_cf}")
  297. if not isinstance(norm, bool):
  298. raise ValueError("Expected argument `norm` to be a bool value")
  299. if not isinstance(fast, bool):
  300. raise ValueError("Expected argument `fast` to be a bool value")