| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Sequence
- from typing import Any, Optional, Union
- from torch import Tensor, tensor
- from torchmetrics.functional.audio.srmr import (
- _srmr_arg_validate,
- speech_reverberation_modulation_energy_ratio,
- )
- from torchmetrics.metric import Metric
- from torchmetrics.utilities.imports import (
- _GAMMATONE_AVAILABLE,
- _MATPLOTLIB_AVAILABLE,
- _TORCHAUDIO_AVAILABLE,
- )
- from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
# Doctests in this module need optional third-party packages. Without ``gammatone``
# or ``torchaudio`` none of the class examples can run; without ``matplotlib`` only
# the plotting example has to be skipped.
_SRMR_BACKENDS_AVAILABLE = bool(_GAMMATONE_AVAILABLE) and bool(_TORCHAUDIO_AVAILABLE)
if not _SRMR_BACKENDS_AVAILABLE:
    __doctest_skip__ = ["SpeechReverberationModulationEnergyRatio", "SpeechReverberationModulationEnergyRatio.plot"]
elif not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["SpeechReverberationModulationEnergyRatio.plot"]
class SpeechReverberationModulationEnergyRatio(Metric):
    """Calculate `Speech-to-Reverberation Modulation Energy Ratio`_ (SRMR).

    SRMR is a non-intrusive metric for speech quality and intelligibility based on
    a modulation spectral representation of the speech signal.
    This code is translated from SRMRToolbox and `SRMRpy`_.

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``

    As output of `forward` and `compute` the metric returns the following output

    - ``srmr`` (:class:`~torch.Tensor`): float scalar tensor, the mean SRMR over every value accumulated so far

    .. hint::
        Using this metric requires you to have ``gammatone`` and ``torchaudio`` installed.
        Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio``
        and ``pip install git+https://github.com/detly/gammatone``.

    .. attention::
        This implementation is experimental, and might not be consistent with the matlab
        implementation SRMRToolbox, especially the fast implementation.
        The slow versions, a) ``fast=False, norm=False, max_cf=128``, b) ``fast=False, norm=True, max_cf=30``,
        have a relatively small inconsistency.

    Args:
        fs: the sampling rate
        n_cochlear_filters: Number of filters in the acoustic filterbank
        low_freq: determines the frequency cutoff for the corresponding gammatone filterbank.
        min_cf: Center frequency in Hz of the first modulation filter.
        max_cf: Center frequency in Hz of the last modulation filter. If None is given,
            then 30 Hz will be used for ``norm==True``, otherwise 128 Hz will be used.
        norm: Use modulation spectrum energy normalization
        fast: Use the faster version based on the gammatonegram.
            Note: this argument is inherited from `SRMRpy`_. As the translated code is based on PyTorch,
            setting ``fast=True`` may slow down the speed for calculating this metric on GPU.
        kwargs: Additional keyword arguments, forwarded unchanged to the :class:`Metric` base class.

    Raises:
        ModuleNotFoundError:
            If ``gammatone`` or ``torchaudio`` package is not installed

    Example:
        >>> from torch import randn
        >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
        >>> preds = randn(8000)
        >>> srmr = SpeechReverberationModulationEnergyRatio(8000)
        >>> srmr(preds)
        tensor(0.3191)

    """

    # Metric states registered in ``__init__``: running sum of SRMR values and the number of values summed.
    msum: Tensor
    total: Tensor

    # Class-level metric metadata.
    full_state_update: bool = False
    is_differentiable: bool = True
    higher_is_better: bool = True
    plot_lower_bound: Optional[float] = None
    plot_upper_bound: Optional[float] = None

    def __init__(
        self,
        fs: int,
        n_cochlear_filters: int = 23,
        low_freq: float = 125,
        min_cf: float = 4,
        max_cf: Optional[float] = None,
        norm: bool = False,
        fast: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE:
            raise ModuleNotFoundError(
                "speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and"
                " `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or "
                "``pip install torchaudio>=0.10`` and ``pip install git+https://github.com/detly/gammatone``"
            )

        # Validate the configuration once at construction time, before any audio is processed.
        _srmr_arg_validate(
            fs=fs,
            n_cochlear_filters=n_cochlear_filters,
            low_freq=low_freq,
            min_cf=min_cf,
            max_cf=max_cf,
            norm=norm,
            fast=fast,
        )

        self.fs = fs
        self.n_cochlear_filters = n_cochlear_filters
        self.low_freq = low_freq
        self.min_cf = min_cf
        # ``max_cf`` is stored as given (possibly ``None``) and passed on to the functional implementation.
        self.max_cf = max_cf
        self.norm = norm
        self.fast = fast

        # Both states are summed across processes, so ``compute`` yields the mean over all accumulated values.
        self.add_state("msum", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor) -> None:
        """Update state with predictions.

        Args:
            preds: float tensor with shape ``(...,time)`` holding the signal(s) to evaluate.

        """
        # Move the result to the state's device before accumulating it.
        metric_val_batch = speech_reverberation_modulation_energy_ratio(
            preds, self.fs, self.n_cochlear_filters, self.low_freq, self.min_cf, self.max_cf, self.norm, self.fast
        ).to(self.msum.device)

        self.msum += metric_val_batch.sum()
        self.total += metric_val_batch.numel()

    def compute(self) -> Tensor:
        """Compute metric.

        Returns:
            The accumulated sum of SRMR values divided by the number of accumulated values.

        """
        return self.msum / self.total

    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
            >>> metric = SpeechReverberationModulationEnergyRatio(8000)
            >>> metric.update(torch.rand(8000))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
            >>> metric = SpeechReverberationModulationEnergyRatio(8000)
            >>> values = [ ]
            >>> for _ in range(10):
            ...     values.append(metric(torch.rand(8000)))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
|