nisqa.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor, tensor
  17. from torchmetrics.functional.audio.nisqa import non_intrusive_speech_quality_assessment
  18. from torchmetrics.metric import Metric
  19. from torchmetrics.utilities.imports import (
  20. _LIBROSA_AVAILABLE,
  21. _MATPLOTLIB_AVAILABLE,
  22. _REQUESTS_AVAILABLE,
  23. )
  24. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  25. __doctest_requires__ = {"NonIntrusiveSpeechQualityAssessment": ["librosa", "requests"]}
  26. if not _MATPLOTLIB_AVAILABLE:
  27. __doctest_skip__ = ["NonIntrusiveSpeechQualityAssessment.plot"]
  28. class NonIntrusiveSpeechQualityAssessment(Metric):
  29. """`Non-Intrusive Speech Quality Assessment`_ (NISQA v2.0) [1], [2].
  30. As input to ``forward`` and ``update`` the metric accepts the following input
  31. - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
  32. As output of ``forward`` and ``compute`` the metric returns the following output
  33. - ``nisqa`` (:class:`~torch.Tensor`): float tensor reduced across the batch with shape ``(5,)`` corresponding to
  34. overall MOS, noisiness, discontinuity, coloration and loudness in that order
  35. .. hint::
  36. Using this metric requires you to have ``librosa`` and ``requests`` installed. Install as
  37. ``pip install librosa requests``.
  38. .. caution::
  39. The ``forward`` and ``compute`` methods in this class return values reduced across the batch. To obtain
  40. values for each sample, you may use the functional counterpart
  41. :func:`~torchmetrics.functional.audio.nisqa.non_intrusive_speech_quality_assessment`.
  42. Args:
  43. fs: sampling frequency of input
  44. Raises:
  45. ModuleNotFoundError:
  46. If ``librosa`` or ``requests`` are not installed
  47. Example:
  48. >>> import torch
  49. >>> from torchmetrics.audio import NonIntrusiveSpeechQualityAssessment
  50. >>> _ = torch.manual_seed(42)
  51. >>> preds = torch.randn(16000)
  52. >>> nisqa = NonIntrusiveSpeechQualityAssessment(16000)
  53. >>> nisqa(preds)
  54. tensor([1.0433, 1.9545, 2.6087, 1.3460, 1.7117])
  55. References:
  56. - [1] G. Mittag and S. Möller, "Non-intrusive speech quality assessment for super-wideband speech communication
  57. networks", in Proc. ICASSP, 2019.
  58. - [2] G. Mittag, B. Naderi, A. Chehadi and S. Möller, "NISQA: A deep CNN-self-attention model for
  59. multidimensional speech quality prediction with crowdsourced datasets", in Proc. INTERSPEECH, 2021.
  60. """
  61. sum_nisqa: Tensor
  62. total: Tensor
  63. full_state_update: bool = False
  64. is_differentiable: bool = False
  65. higher_is_better: bool = True
  66. plot_lower_bound: float = 0.0
  67. plot_upper_bound: float = 5.0
  68. def __init__(self, fs: int, **kwargs: Any) -> None:
  69. super().__init__(**kwargs)
  70. if not _LIBROSA_AVAILABLE or not _REQUESTS_AVAILABLE:
  71. raise ModuleNotFoundError(
  72. "NISQA metric requires that librosa and requests are installed. "
  73. "Install as `pip install librosa requests`."
  74. )
  75. if not isinstance(fs, int) or fs <= 0:
  76. raise ValueError(f"Argument `fs` expected to be a positive integer, but got {fs}")
  77. self.fs = fs
  78. self.add_state("sum_nisqa", default=tensor([0.0, 0.0, 0.0, 0.0, 0.0]), dist_reduce_fx="sum")
  79. self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
  80. def update(self, preds: Tensor) -> None:
  81. """Update state with predictions."""
  82. nisqa_batch = non_intrusive_speech_quality_assessment(
  83. preds,
  84. self.fs,
  85. ).to(self.sum_nisqa.device)
  86. nisqa_batch = nisqa_batch.reshape(-1, 5)
  87. self.sum_nisqa += nisqa_batch.sum(dim=0)
  88. self.total += nisqa_batch.shape[0]
  89. def compute(self) -> Tensor:
  90. """Compute metric."""
  91. return self.sum_nisqa / self.total
  92. def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
  93. """Plot a single or multiple values from the metric.
  94. Args:
  95. val: Either a single result from calling ``metric.forward`` or ``metric.compute`` or a list of these
  96. results. If no value is provided, will automatically call ``metric.compute`` and plot that result.
  97. ax: A matplotlib axis object. If provided will add plot to that axis
  98. Returns:
  99. Figure and Axes object
  100. Raises:
  101. ModuleNotFoundError:
  102. If ``matplotlib`` is not installed
  103. .. plot::
  104. :scale: 75
  105. >>> # Example plotting a single value
  106. >>> import torch
  107. >>> from torchmetrics.audio import NonIntrusiveSpeechQualityAssessment
  108. >>> metric = NonIntrusiveSpeechQualityAssessment(16000)
  109. >>> metric.update(torch.randn(16000))
  110. >>> fig_, ax_ = metric.plot()
  111. .. plot::
  112. :scale: 75
  113. >>> # Example plotting multiple values
  114. >>> import torch
  115. >>> from torchmetrics.audio import NonIntrusiveSpeechQualityAssessment
  116. >>> metric = NonIntrusiveSpeechQualityAssessment(16000)
  117. >>> values = []
  118. >>> for _ in range(10):
  119. ... values.append(metric(torch.randn(16000)))
  120. >>> fig_, ax_ = metric.plot(values)
  121. """
  122. return self._plot(val, ax)