| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # Note: unless mentioned otherwise, the functions in this file are mainly translated from
- # the SRMRpy package for batched processing with PyTorch.
- from functools import lru_cache
- from math import ceil, pi
- from typing import Optional
- import torch
- from torch import Tensor
- from torch.nn.functional import pad
- from torchmetrics.utilities import rank_zero_warn
- from torchmetrics.utilities.imports import (
- _GAMMATONE_AVAILABLE,
- _TORCHAUDIO_AVAILABLE,
- )
if not (_TORCHAUDIO_AVAILABLE and _GAMMATONE_AVAILABLE):
    # The docstring example needs both optional backends; when either is missing, list it in
    # ``__doctest_skip__`` so doctest tooling that honours this variable skips it.
    __doctest_skip__ = ["speech_reverberation_modulation_energy_ratio"]
@lru_cache(maxsize=100)
def _calc_erbs(low_freq: float, fs: int, n_filters: int, device: torch.device) -> Tensor:
    """Return the equivalent rectangular bandwidth of every gammatone channel.

    The centre frequencies come from ``gammatone.filters.centre_freqs(fs, n_filters, low_freq)``
    and are mapped through the Glasberg & Moore ERB formula. The result is memoised by
    ``lru_cache``, so repeated calls return the same tensor object; do not modify it in place.

    Args:
        low_freq: lowest frequency handed to ``centre_freqs``.
        fs: sampling rate.
        n_filters: number of gammatone channels.
        device: device of the returned tensor.

    Returns:
        Tensor with one ERB value per channel, in ``centre_freqs`` order.
    """
    from gammatone.filters import centre_freqs

    # Glasberg and Moore parameters
    ear_q, min_bw, order = 9.26449, 24.7, 1

    centres = centre_freqs(fs, n_filters, low_freq)
    scaled = (centres / ear_q) ** order
    erbs = (scaled + min_bw**order) ** (1 / order)
    return torch.tensor(erbs, device=device)
@lru_cache(maxsize=100)
def _make_erb_filters(fs: int, num_freqs: int, cutoff: float, device: torch.device) -> Tensor:
    """Build the gammatone ERB filter coefficients and return them as a tensor.

    Thin wrapper around ``gammatone.filters.make_erb_filters`` applied to
    ``centre_freqs(fs, num_freqs, cutoff)``. Memoised by ``lru_cache``, so the same tensor
    object is returned for repeated arguments.

    Args:
        fs: sampling rate.
        num_freqs: number of gammatone channels.
        cutoff: lowest centre frequency handed to ``centre_freqs``.
        device: device of the returned tensor.

    Returns:
        Coefficient tensor, one row per channel (the caller documents it as ``[N_filters, 10]``).
    """
    from gammatone.filters import centre_freqs, make_erb_filters

    return torch.tensor(make_erb_filters(fs, centre_freqs(fs, num_freqs, cutoff)), device=device)
@lru_cache(maxsize=100)
def _compute_modulation_filterbank_and_cutoffs(
    min_cf: float, max_cf: float, n: int, fs: float, q: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Build the modulation filterbank and the 3 dB cutoffs of each filter (translated from SRMRpy).

    The ``n`` centre frequencies form a geometric progression from ``min_cf`` to ``max_cf``.
    Each channel is a 2nd-order band-pass with quality factor ``q`` for sampling rate ``fs``.
    Results are memoised by ``lru_cache``, so the returned tensors are shared between calls.

    Args:
        min_cf: centre frequency of the first modulation filter.
        max_cf: centre frequency of the last modulation filter.
        n: number of modulation filters (must be at least 2, because the spacing divides by ``n - 1``).
        fs: sampling rate of the signal the filters will be applied to.
        q: quality factor of every band-pass.
        device: device of the returned tensors.

    Returns:
        ``(cfs, mfb, ll, rr)``: centre frequencies ``[n]``; coefficients ``[n, 2, 3]`` with row 0
        the numerator ``b`` and row 1 the denominator ``a``; lower and upper cutoffs ``[n]``.
    """
    ratio = (max_cf / min_cf) ** (1.0 / (n - 1))
    # Repeated multiplication (rather than a power law) reproduces SRMRpy bit for bit.
    cfs = torch.zeros(n, dtype=torch.float64)
    cfs[0] = min_cf
    for idx in range(1, n):
        cfs[idx] = cfs[idx - 1] * ratio

    def _bandpass(omega: Tensor) -> Tensor:
        # Coefficients of one 2nd-order band-pass: stacked [numerator, denominator].
        warped = torch.tan(omega / 2)
        gain = warped / q
        numer = torch.tensor([gain, 0, -gain], dtype=torch.float64)
        denom = torch.tensor(
            [(1 + gain + warped**2), (2 * warped**2 - 2), (1 - gain + warped**2)], dtype=torch.float64
        )
        return torch.stack([numer, denom], dim=0)

    omegas = 2 * pi * cfs / fs
    mfb = torch.stack([_bandpass(omega) for omega in omegas], dim=0)

    cfs = cfs.to(device=device)
    mfb = mfb.to(device=device)

    # 3 dB cutoff frequencies of each 2nd-order band-pass, symmetric around its centre
    w0 = 2 * pi * cfs / fs
    b0 = torch.tan(w0 / 2) / q
    offset = b0 * fs / (2 * pi)
    return cfs, mfb, cfs - offset, cfs + offset
def _hilbert(x: Tensor, n: Optional[int] = None) -> Tensor:
    """Compute the analytic signal of ``x`` along its last dimension via the FFT.

    Port of SRMRpy's variant of ``scipy.signal.hilbert``.

    Args:
        x: real-valued tensor; the transform runs over the last dimension.
        n: FFT length. When ``None``, the length of the last dimension is used, rounded up
            to a multiple of 16 so that the FFT is fast.

    Returns:
        Complex tensor with the analytic signal, truncated to at most ``x.shape[-1]`` samples.

    Raises:
        ValueError: if ``x`` is complex or the FFT length is not positive.
    """
    if x.is_complex():
        raise ValueError("x must be real.")
    if n is None:
        n = x.shape[-1]
        # Make N multiple of 16 to make sure the transform will be fast
        if n % 16:
            n = ceil(n / 16) * 16
    if n <= 0:
        raise ValueError("N must be positive.")

    spectrum = torch.fft.fft(x, n=n, dim=-1)

    # Spectral weights: DC (and Nyquist for even n) kept once, positive frequencies doubled,
    # negative frequencies removed.
    half = n // 2
    weights = torch.zeros(n, dtype=x.dtype, device=x.device)
    weights[0] = 1
    if n % 2:
        weights[1 : half + 1] = 2
    else:
        weights[half] = 1
        weights[1:half] = 2

    analytic = torch.fft.ifft(spectrum * weights, dim=-1)
    return analytic[..., : x.shape[-1]]
def _erb_filterbank(wave: Tensor, coefs: Tensor) -> Tensor:
    """Filter a batch of waveforms through a gammatone ERB filterbank (translated from gammatone).

    Each channel is a cascade of four IIR sections. All four share ``coefs[:, 6:9]``
    (B0, B1, B2), which is passed to ``lfilter`` as ``a_coeffs``. Section ``i`` (i = 1..4) uses
    columns (A0, A1i, A2) = ``coefs[:, (0, i, 5)]`` as ``b_coeffs``. The cascade output is divided
    by the per-channel gain in column 9.

    Args:
        wave: shape [B, time]
        coefs: shape [N, 10]

    Returns:
        Tensor: shape [B, N, time]
    """
    from torchaudio.functional.filtering import lfilter

    batch_size, num_samples = wave.shape
    num_channels = coefs.shape[0]
    # [B, time] -> [B, 1, time] -> [B, N, time]: one view of the signal per channel
    signal = wave.to(dtype=coefs.dtype).reshape(batch_size, 1, num_samples).expand(-1, num_channels, -1)

    shared_coeffs = coefs[:, 6:9]  # B0, B1, B2
    for col in (1, 2, 3, 4):
        section_coeffs = coefs[:, (0, col, 5)]  # A0, A1<col>, A2
        signal = lfilter(signal, shared_coeffs, section_coeffs, batching=True)

    return signal / coefs[:, 9].reshape(1, -1, 1)
def _normalize_energy(energy: Tensor, drange: float = 30.0) -> Tensor:
    """Clip modulation energies into a ``drange`` dB window below each item's peak.

    For every batch item, the peak is the largest value, over modulation channels and frames, of
    the energy averaged across acoustic filters. Values above the peak are lowered to it.
    Values more than ``drange`` dB below it are raised to that floor.

    Args:
        energy: shape [B, N_filters, 8, n_frames]
        drange: dynamic range in dB (30 dB by default)

    Returns:
        Tensor with the same shape as ``energy``.
    """
    filter_mean = energy.mean(dim=1, keepdim=True)  # [B, 1, 8, n_frames]
    peak = filter_mean.max(dim=2, keepdim=True).values.max(dim=3, keepdim=True).values  # [B, 1, 1, 1]
    floor = peak * 10.0 ** (-drange / 10.0)
    raised = torch.where(energy < floor, floor, energy)
    return torch.where(raised > peak, peak, raised)
def _cal_srmr_score(bw: Tensor, avg_energy: Tensor, cutoffs: Tensor) -> Tensor:
    """Calculate the SRMR score of a single batch item.

    ``kstar`` is the first k in 5, 6, 7 with ``cutoffs[k - 1] <= bw < cutoffs[k]``, or 8 when
    ``bw >= cutoffs[7]``. The score is the energy in modulation channels ``[:4]`` divided by the
    energy in channels ``[4:kstar]``.

    Args:
        bw: scalar bandwidth compared against ``cutoffs``.
        avg_energy: energies indexed as ``[acoustic_filter, modulation_channel]``
            (presumably ``[N_filters, 8]``; confirm against the caller).
        cutoffs: per-modulation-filter cutoff frequencies, at least 8 entries.

    Raises:
        ValueError: if ``bw`` falls below ``cutoffs[4]`` or cannot be compared (e.g. NaN).
    """
    for kstar in (5, 6, 7):
        if cutoffs[kstar - 1] <= bw < cutoffs[kstar]:
            break
    else:
        if not (cutoffs[7] <= bw):
            raise ValueError("Something wrong with the cutoffs compared to bw values.")
        kstar = 8
    return avg_energy[:, :4].sum() / avg_energy[:, 4:kstar].sum()
def speech_reverberation_modulation_energy_ratio(
    preds: Tensor,
    fs: int,
    n_cochlear_filters: int = 23,
    low_freq: float = 125,
    min_cf: float = 4,
    max_cf: Optional[float] = None,
    norm: bool = False,
    fast: bool = False,
) -> Tensor:
    """Calculate `Speech-to-Reverberation Modulation Energy Ratio`_ (SRMR).

    SRMR is a non-intrusive metric for speech quality and intelligibility based on
    a modulation spectral representation of the speech signal.
    This code is translated from SRMRToolbox and `SRMRpy`_.

    Args:
        preds: shape ``(..., time)``. Integer inputs are converted to float64 and divided by the
            maximum value of their integer dtype.
        fs: the sampling rate
        n_cochlear_filters: Number of filters in the acoustic filterbank
        low_freq: determines the frequency cutoff for the corresponding gammatone filterbank.
        min_cf: Center frequency in Hz of the first modulation filter.
        max_cf: Center frequency in Hz of the last modulation filter. If None is given,
            then 128 Hz will be used for `norm==False`, otherwise 30 Hz will be used.
        norm: Use modulation spectrum energy normalization
        fast: Use the faster version based on the gammatonegram.
            Note: this argument is inherited from `SRMRpy`_. As the translated code is based on pytorch,
            setting `fast=True` may slow down the speed for calculating this metric on GPU.

    .. hint::
        Using this metric requires you to have ``gammatone`` and ``torchaudio`` installed.
        Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio``
        and ``pip install git+https://github.com/detly/gammatone``.

    .. attention::
        This implementation is experimental, and might not be consistent with the matlab
        implementation SRMRToolbox, especially the fast implementation.
        The slow versions, a) ``fast=False, norm=False, max_cf=128``, b) ``fast=False, norm=True, max_cf=30``,
        have a relatively small inconsistency.

    Returns:
        Tensor of SRMR values with shape ``preds.shape[:-1]``; a 1-D ``preds`` yields shape ``(1,)``.

    Raises:
        ModuleNotFoundError:
            If ``gammatone`` or ``torchaudio`` package is not installed
        ValueError:
            If any argument fails validation.

    Example:
        >>> from torch import randn
        >>> from torchmetrics.functional.audio import speech_reverberation_modulation_energy_ratio
        >>> preds = randn(8000)
        >>> speech_reverberation_modulation_energy_ratio(preds, 8000)
        tensor([0.3191], dtype=torch.float64)
    """
    if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE:
        raise ModuleNotFoundError(
            "speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and"
            " `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or "
            "``pip install torchaudio>=0.10`` and ``pip install git+https://github.com/detly/gammatone``"
        )
    from gammatone.fftweight import fft_gtgram
    from torchaudio.functional.filtering import lfilter

    _srmr_arg_validate(
        fs=fs,
        n_cochlear_filters=n_cochlear_filters,
        low_freq=low_freq,
        min_cf=min_cf,
        max_cf=max_cf,
        norm=norm,
        fast=fast,
    )

    # flatten all leading dims into one batch dim: [B, time]
    shape = preds.shape
    preds = preds.reshape(1, -1) if len(shape) == 1 else preds.reshape(-1, shape[-1])
    num_batch = preds.shape[0]

    # convert int type to float; ``torch.finfo`` rejects integer dtypes, so the integer
    # range has to come from ``torch.iinfo``
    if not torch.is_floating_point(preds):
        preds = preds.to(torch.float64) / torch.iinfo(preds.dtype).max

    # norm values in preds to [-1, 1], as lfilter requires an input in this range
    max_vals = preds.abs().max(dim=-1, keepdim=True).values
    val_norm = torch.where(
        max_vals > 1,
        max_vals,
        torch.tensor(1.0, dtype=max_vals.dtype, device=max_vals.device),
    )
    preds = preds / val_norm

    w_length_s = 0.256
    w_inc_s = 0.064

    # Computing gammatone envelopes: [B, N_filters, env_len], sampled at ``mfs`` Hz
    if fast:
        rank_zero_warn("`fast=True` may slow down the speed of SRMR metric on GPU.")
        mfs = 400.0
        preds_np = preds.detach().cpu().numpy()
        gt_env = torch.stack(
            [
                torch.tensor(fft_gtgram(preds_np[b], fs, 0.010, 0.0025, n_cochlear_filters, low_freq))
                for b in range(num_batch)
            ],
            dim=0,
        ).to(device=preds.device)
    else:
        fcoefs = _make_erb_filters(fs, n_cochlear_filters, low_freq, device=preds.device)  # [N_filters, 10]
        gt_env = torch.abs(_hilbert(_erb_filterbank(preds, fcoefs)))  # [B, N_filters, time]
        mfs = fs

    # Frame the envelope, not the raw waveform. In the fast path the gammatonegram is sampled at
    # ``mfs`` = 400 Hz rather than ``fs``, so its length differs from ``preds.shape[-1]``.
    env_len = gt_env.shape[-1]
    w_length = ceil(w_length_s * mfs)
    w_inc = ceil(w_inc_s * mfs)

    # Computing modulation filterbank with Q = 2 and 8 channels
    if max_cf is None:
        max_cf = 30 if norm else 128
    _, mf, cutoffs, _ = _compute_modulation_filterbank_and_cutoffs(
        min_cf, max_cf, n=8, fs=mfs, q=2, device=preds.device
    )

    # The padding below always provides at least one full window. Keeping at least one frame lets
    # envelopes shorter than the window produce an energy instead of an empty slice.
    num_frames = max(1, int(1 + (env_len - w_length) // w_inc))

    # Periodic Hamming window of length ``w_length``. This equals SRMRpy's symmetric
    # ``hamming(w_length + 1)[:-1]``; torch's ``periodic=True`` already drops the last point,
    # so trimming again would shift the window.
    w = torch.hamming_window(w_length, periodic=True, dtype=torch.float64, device=preds.device)
    mod_out = lfilter(
        gt_env.unsqueeze(-2).expand(-1, -1, mf.shape[0], -1), mf[:, 1, :], mf[:, 0, :], clamp=False, batching=True
    )  # [B, N_filters, 8, env_len]
    # pad signal if it's shorter than window or it is not multiple of wInc
    padding = (0, max(ceil(env_len / w_inc) * w_inc - env_len, w_length - env_len))
    mod_out_pad = pad(mod_out, pad=padding, mode="constant", value=0)
    mod_out_frame = mod_out_pad.unfold(-1, w_length, w_inc)
    energy = ((mod_out_frame[..., :num_frames, :] * w) ** 2).sum(dim=-1)  # [B, N_filters, 8, n_frames]
    if norm:
        energy = _normalize_energy(energy)

    erbs = torch.flipud(_calc_erbs(low_freq, fs, n_cochlear_filters, device=preds.device))

    avg_energy = torch.mean(energy, dim=-1)  # [B, N_filters, 8]
    total_energy = torch.sum(avg_energy.reshape(num_batch, -1), dim=-1)  # [B]
    ac_energy = torch.sum(avg_energy, dim=2)  # [B, N_filters]
    ac_perc = ac_energy * 100 / total_energy.reshape(-1, 1)
    ac_perc_cumsum = ac_perc.flip(-1).cumsum(-1)
    # First (flipped) acoustic channel at which the cumulative energy share exceeds 90 %.
    # argmax returns exactly one index per batch item, so ``bw`` stays aligned with the batch.
    k90perc_idx = torch.argmax((ac_perc_cumsum > 90).to(torch.int64), dim=-1)
    bw = erbs[k90perc_idx]

    score = torch.stack([_cal_srmr_score(bw[b], avg_energy[b], cutoffs=cutoffs) for b in range(num_batch)])
    return score.reshape(*shape[:-1]) if len(shape) > 1 else score  # recover original shape
def _srmr_arg_validate(
    fs: int,
    n_cochlear_filters: int = 23,
    low_freq: float = 125,
    min_cf: float = 4,
    max_cf: Optional[float] = None,
    norm: bool = False,
    fast: bool = False,
) -> None:
    """Validate the arguments for speech_reverberation_modulation_energy_ratio.

    Args:
        fs: the sampling rate; must be a positive int.
        n_cochlear_filters: Number of filters in the acoustic filterbank; must be a positive int.
        low_freq: determines the frequency cutoff for the corresponding gammatone filterbank;
            must be a positive number.
        min_cf: Center frequency in Hz of the first modulation filter; must be a positive number.
        max_cf: Center frequency in Hz of the last modulation filter; must be a positive number,
            or None to let the caller pick its default.
        norm: Use modulation spectrum energy normalization; must be a bool.
        fast: Use the faster version based on the gammatonegram; must be a bool.

    Raises:
        ValueError: if any argument has the wrong type or is out of range.
    """

    def _is_positive_int(value: object) -> bool:
        # ``bool`` subclasses ``int`` in Python, so ``True`` would otherwise pass as ``1``
        return isinstance(value, int) and not isinstance(value, bool) and value > 0

    def _is_positive_number(value: object) -> bool:
        return isinstance(value, (float, int)) and not isinstance(value, bool) and value > 0

    if not _is_positive_int(fs):
        raise ValueError(f"Expected argument `fs` to be an int larger than 0, but got {fs}")
    if not _is_positive_int(n_cochlear_filters):
        raise ValueError(
            f"Expected argument `n_cochlear_filters` to be an int larger than 0, but got {n_cochlear_filters}"
        )
    if not _is_positive_number(low_freq):
        raise ValueError(f"Expected argument `low_freq` to be a float larger than 0, but got {low_freq}")
    if not _is_positive_number(min_cf):
        raise ValueError(f"Expected argument `min_cf` to be a float larger than 0, but got {min_cf}")
    if max_cf is not None and not _is_positive_number(max_cf):
        raise ValueError(f"Expected argument `max_cf` to be a float larger than 0, but got {max_cf}")
    if not isinstance(norm, bool):
        raise ValueError("Expected argument `norm` to be a bool value")
    if not isinstance(fast, bool):
        raise ValueError("Expected argument `fast` to be a bool value")
|