srmr.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor, tensor
  17. from torchmetrics.functional.audio.srmr import (
  18. _srmr_arg_validate,
  19. speech_reverberation_modulation_energy_ratio,
  20. )
  21. from torchmetrics.metric import Metric
  22. from torchmetrics.utilities.imports import (
  23. _GAMMATONE_AVAILABLE,
  24. _MATPLOTLIB_AVAILABLE,
  25. _TORCHAUDIO_AVAILABLE,
  26. )
  27. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  28. if not all([_GAMMATONE_AVAILABLE, _TORCHAUDIO_AVAILABLE]):
  29. __doctest_skip__ = ["SpeechReverberationModulationEnergyRatio", "SpeechReverberationModulationEnergyRatio.plot"]
  30. elif not _MATPLOTLIB_AVAILABLE:
  31. __doctest_skip__ = ["SpeechReverberationModulationEnergyRatio.plot"]
  32. class SpeechReverberationModulationEnergyRatio(Metric):
  33. """Calculate `Speech-to-Reverberation Modulation Energy Ratio`_ (SRMR).
  34. SRMR is a non-intrusive metric for speech quality and intelligibility based on
  35. a modulation spectral representation of the speech signal.
  36. This code is translated from SRMRToolbox and `SRMRpy`_.
  37. As input to ``forward`` and ``update`` the metric accepts the following input
  38. - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
  39. As output of `forward` and `compute` the metric returns the following output
  40. - ``srmr`` (:class:`~torch.Tensor`): float scaler tensor
  41. .. hint::
  42. Using this metrics requires you to have ``gammatone`` and ``torchaudio`` installed.
  43. Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio``
  44. and ``pip install git+https://github.com/detly/gammatone``.
  45. .. attention::
  46. This implementation is experimental, and might not be consistent with the matlab
  47. implementation SRMRToolbox, especially the fast implementation.
  48. The slow versions, a) ``fast=False, norm=False, max_cf=128``, b) ``fast=False, norm=True, max_cf=30``,
  49. have a relatively small inconsistency.
  50. Args:
  51. fs: the sampling rate
  52. n_cochlear_filters: Number of filters in the acoustic filterbank
  53. low_freq: determines the frequency cutoff for the corresponding gammatone filterbank.
  54. min_cf: Center frequency in Hz of the first modulation filter.
  55. max_cf: Center frequency in Hz of the last modulation filter. If None is given,
  56. then 30 Hz will be used for `norm==False`, otherwise 128 Hz will be used.
  57. norm: Use modulation spectrum energy normalization
  58. fast: Use the faster version based on the gammatonegram.
  59. Note: this argument is inherited from `SRMRpy`_. As the translated code is based to pytorch,
  60. setting `fast=True` may slow down the speed for calculating this metric on GPU.
  61. Raises:
  62. ModuleNotFoundError:
  63. If ``gammatone`` or ``torchaudio`` package is not installed
  64. Example:
  65. >>> from torch import randn
  66. >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
  67. >>> preds = randn(8000)
  68. >>> srmr = SpeechReverberationModulationEnergyRatio(8000)
  69. >>> srmr(preds)
  70. tensor(0.3191)
  71. """
  72. msum: Tensor
  73. total: Tensor
  74. full_state_update: bool = False
  75. is_differentiable: bool = True
  76. higher_is_better: bool = True
  77. plot_lower_bound: Optional[float] = None
  78. plot_upper_bound: Optional[float] = None
  79. def __init__(
  80. self,
  81. fs: int,
  82. n_cochlear_filters: int = 23,
  83. low_freq: float = 125,
  84. min_cf: float = 4,
  85. max_cf: Optional[float] = None,
  86. norm: bool = False,
  87. fast: bool = False,
  88. **kwargs: Any,
  89. ) -> None:
  90. super().__init__(**kwargs)
  91. if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE:
  92. raise ModuleNotFoundError(
  93. "speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and"
  94. " `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or "
  95. "``pip install torchaudio>=0.10`` and ``pip install git+https://github.com/detly/gammatone``"
  96. )
  97. _srmr_arg_validate(
  98. fs=fs,
  99. n_cochlear_filters=n_cochlear_filters,
  100. low_freq=low_freq,
  101. min_cf=min_cf,
  102. max_cf=max_cf,
  103. norm=norm,
  104. fast=fast,
  105. )
  106. self.fs = fs
  107. self.n_cochlear_filters = n_cochlear_filters
  108. self.low_freq = low_freq
  109. self.min_cf = min_cf
  110. self.max_cf = max_cf
  111. self.norm = norm
  112. self.fast = fast
  113. self.add_state("msum", default=tensor(0.0), dist_reduce_fx="sum")
  114. self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
  115. def update(self, preds: Tensor) -> None:
  116. """Update state with predictions."""
  117. metric_val_batch = speech_reverberation_modulation_energy_ratio(
  118. preds, self.fs, self.n_cochlear_filters, self.low_freq, self.min_cf, self.max_cf, self.norm, self.fast
  119. ).to(self.msum.device)
  120. self.msum += metric_val_batch.sum()
  121. self.total += metric_val_batch.numel()
  122. def compute(self) -> Tensor:
  123. """Compute metric."""
  124. return self.msum / self.total
  125. def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
  126. """Plot a single or multiple values from the metric.
  127. Args:
  128. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  129. If no value is provided, will automatically call `metric.compute` and plot that result.
  130. ax: An matplotlib axis object. If provided will add plot to that axis
  131. Returns:
  132. Figure and Axes object
  133. Raises:
  134. ModuleNotFoundError:
  135. If `matplotlib` is not installed
  136. .. plot::
  137. :scale: 75
  138. >>> # Example plotting a single value
  139. >>> import torch
  140. >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
  141. >>> metric = SpeechReverberationModulationEnergyRatio(8000)
  142. >>> metric.update(torch.rand(8000))
  143. >>> fig_, ax_ = metric.plot()
  144. .. plot::
  145. :scale: 75
  146. >>> # Example plotting multiple values
  147. >>> import torch
  148. >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
  149. >>> metric = SpeechReverberationModulationEnergyRatio(8000)
  150. >>> values = [ ]
  151. >>> for _ in range(10):
  152. ... values.append(metric(torch.rand(8000)))
  153. >>> fig_, ax_ = metric.plot(values)
  154. """
  155. return self._plot(val, ax)