dnsmos.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. import torch
  17. from torch import Tensor, tensor
  18. from torchmetrics.functional.audio.dnsmos import deep_noise_suppression_mean_opinion_score
  19. from torchmetrics.metric import Metric
  20. from torchmetrics.utilities.imports import (
  21. _LIBROSA_AVAILABLE,
  22. _MATPLOTLIB_AVAILABLE,
  23. _ONNXRUNTIME_AVAILABLE,
  24. _REQUESTS_AVAILABLE,
  25. )
  26. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  27. __doctest_requires__ = {"DeepNoiseSuppressionMeanOpinionScore": ["requests", "librosa", "onnxruntime"]}
  28. if not _MATPLOTLIB_AVAILABLE:
  29. __doctest_skip__ = ["DeepNoiseSuppressionMeanOpinionScore.plot"]
  30. class DeepNoiseSuppressionMeanOpinionScore(Metric):
  31. """Calculate `Deep Noise Suppression performance evaluation based on Mean Opinion Score`_ (DNSMOS).
  32. Human subjective evaluation is the ”gold standard” to evaluate speech quality optimized for human perception.
  33. Perceptual objective metrics serve as a proxy for subjective scores. The conventional and widely used metrics
  34. require a reference clean speech signal, which is unavailable in real recordings. The no-reference approaches
  35. correlate poorly with human ratings and are not widely adopted in the research community. One of the biggest
  36. use cases of these perceptual objective metrics is to evaluate noise suppression algorithms. DNSMOS generalizes
  37. well in challenging test conditions with a high correlation to human ratings in stack ranking noise suppression
  38. methods. More details can be found in `DNSMOS paper <https://arxiv.org/abs/2010.15258>`_ and
  39. `DNSMOS P.835 paper <https://arxiv.org/abs/2110.01763>`_.
  40. As input to ``forward`` and ``update`` the metric accepts the following input
  41. - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
  42. As output of ``forward`` and ``compute`` the metric returns the following output
  43. - ``dnsmos`` (:class:`~torch.Tensor`): float tensor of DNSMOS values reduced across the batch
  44. with shape ``(...,4)`` indicating [p808_mos, mos_sig, mos_bak, mos_ovr] in the last dim.
  45. .. hint::
  46. Using this metric requires you to have ``librosa``, ``onnxruntime`` and ``requests`` installed.
  47. Install as ``pip install torchmetrics['audio']`` or alternatively `pip install librosa onnxruntime-gpu requests`
  48. (if you do not have GPU enabled machine install `onnxruntime` instead of `onnxruntime-gpu`)
  49. .. caution::
  50. The ``forward`` and ``compute`` methods in this class return a reduced DNSMOS value
  51. for a batch. To obtain the DNSMOS value for each sample, you may use the functional counterpart in
  52. :func:`~torchmetrics.functional.audio.dnsmos.deep_noise_suppression_mean_opinion_score`.
  53. Args:
  54. fs: sampling frequency
  55. personalized: whether interfering speaker is penalized
  56. device: the device used for calculating DNSMOS, can be cpu or cuda:n, where n is the index of gpu.
  57. If None is given, then the device of input is used.
  58. num_threads: number of threads to use for onnxruntime CPU inference.
  59. cache_session: whether to cache the onnx session. By default this is true, meaning that repeated calls to this
  60. method is faster than if this was set to False, the consequence is that the session will be cached in
  61. memory until the process is terminated.
  62. Raises:
  63. ModuleNotFoundError:
  64. If ``librosa``, ``onnxruntime`` or ``requests`` packages are not installed
  65. Example:
  66. >>> from torch import randn
  67. >>> from torchmetrics.audio import DeepNoiseSuppressionMeanOpinionScore
  68. >>> preds = randn(8000)
  69. >>> dnsmos = DeepNoiseSuppressionMeanOpinionScore(8000, False)
  70. >>> dnsmos(preds)
  71. tensor([2.2..., 2.0..., 1.1..., 1.2...], dtype=torch.float64)
  72. """
  73. sum_dnsmos: Tensor
  74. total: Tensor
  75. full_state_update: bool = False
  76. is_differentiable: bool = False
  77. higher_is_better: bool = True
  78. plot_lower_bound: float = 0
  79. plot_upper_bound: float = 5
  80. def __init__(
  81. self,
  82. fs: int,
  83. personalized: bool,
  84. device: Optional[str] = None,
  85. num_threads: Optional[int] = None,
  86. cache_sessions: bool = True,
  87. **kwargs: Any,
  88. ) -> None:
  89. super().__init__(**kwargs)
  90. if not _LIBROSA_AVAILABLE or not _ONNXRUNTIME_AVAILABLE or not _REQUESTS_AVAILABLE:
  91. raise ModuleNotFoundError(
  92. "DNSMOS metric requires that librosa, onnxruntime and requests are installed."
  93. " Install as `pip install librosa onnxruntime-gpu requests`."
  94. )
  95. if fs <= 0 or not isinstance(fs, int):
  96. raise ValueError("Argument `fs` must be a positive integer.")
  97. self.fs = fs
  98. if not isinstance(personalized, bool):
  99. raise ValueError("Argument `personalized` must be a boolean.")
  100. self.personalized = personalized
  101. self.cal_device = device
  102. self.num_threads = num_threads
  103. self.cache_sessions = cache_sessions
  104. self.add_state("sum_dnsmos", default=tensor([0, 0, 0, 0], dtype=torch.float64), dist_reduce_fx="sum")
  105. self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
  106. def update(self, preds: Tensor) -> None:
  107. """Update state with predictions."""
  108. metric_batch = deep_noise_suppression_mean_opinion_score(
  109. preds=preds,
  110. fs=self.fs,
  111. personalized=self.personalized,
  112. device=self.cal_device,
  113. num_threads=self.num_threads,
  114. cache_session=self.cache_sessions,
  115. ).to(self.sum_dnsmos.device)
  116. self.sum_dnsmos += metric_batch.reshape(-1, 4).sum(dim=0)
  117. self.total += metric_batch.reshape(-1, 4).shape[0]
  118. def compute(self) -> Tensor:
  119. """Compute metric."""
  120. return self.sum_dnsmos / self.total
  121. def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
  122. """Plot a single or multiple values from the metric.
  123. Args:
  124. val: Either a single result from calling ``metric.forward`` or ``metric.compute`` or a list of these
  125. results. If no value is provided, will automatically call ``metric.compute`` and plot that result.
  126. ax: A matplotlib axis object. If provided will add plot to that axis
  127. Returns:
  128. Figure and Axes object
  129. Raises:
  130. ModuleNotFoundError:
  131. If ``matplotlib`` is not installed
  132. .. plot::
  133. :scale: 75
  134. >>> # Example plotting a single value
  135. >>> import torch
  136. >>> from torchmetrics.audio import DeepNoiseSuppressionMeanOpinionScore
  137. >>> metric = DeepNoiseSuppressionMeanOpinionScore(8000, False)
  138. >>> metric.update(torch.rand(8000))
  139. >>> fig_, ax_ = metric.plot()
  140. .. plot::
  141. :scale: 75
  142. >>> # Example plotting multiple values
  143. >>> import torch
  144. >>> from torchmetrics.audio import DeepNoiseSuppressionMeanOpinionScore
  145. >>> metric = DeepNoiseSuppressionMeanOpinionScore(8000, False)
  146. >>> values = [ ]
  147. >>> for _ in range(10):
  148. ... values.append(metric(torch.rand(8000)))
  149. >>> fig_, ax_ = metric.plot(values)
  150. """
  151. return self._plot(val, ax)