stoi.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor, tensor
  17. from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility
  18. from torchmetrics.metric import Metric
  19. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYSTOI_AVAILABLE
  20. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  # Every doctest of the STOI metric needs the optional ``pystoi`` backend installed.
  __doctest_requires__ = {"ShortTimeObjectiveIntelligibility": ["pystoi"]}

  # The ``plot`` doctests depend on matplotlib; only register the skip list when it is missing.
  if not _MATPLOTLIB_AVAILABLE:
      __doctest_skip__ = ["ShortTimeObjectiveIntelligibility.plot"]
  class ShortTimeObjectiveIntelligibility(Metric):
      r"""Calculate STOI (Short-Time Objective Intelligibility) metric for evaluating speech signals.

      Intelligibility measure which is highly correlated with the intelligibility of degraded speech signals, e.g., due
      to additive noise, single-/multi-channel noise reduction, binary masking and vocoded speech as in CI simulations.
      The STOI-measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good
      alternative to the speech intelligibility index (SII) or the speech transmission index (STI), when you are
      interested in the effect of nonlinear processing applied to noisy speech, e.g., noise reduction or binary masking
      algorithms, on speech intelligibility. Description taken from `Cees Taal's website`_ and for further details see
      `STOI ref1`_ and `STOI ref2`_.

      This metric is a wrapper for the `pystoi package`_. As the backend implementation only supports calculations on
      CPU, all input will automatically be moved to CPU to perform the metric calculation, and the result is then moved
      to the device on which the metric states live.

      As input to ``forward`` and ``update`` the metric accepts the following input

      - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
      - ``target`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``

      As output of ``forward`` and ``compute`` the metric returns the following output

      - ``stoi`` (:class:`~torch.Tensor`): float scalar tensor

      .. hint::
          Using this metric requires you to have ``pystoi`` installed. Either install as ``pip install
          torchmetrics[audio]`` or ``pip install pystoi``.

      Args:
          fs: sampling frequency (Hz)
          extended: whether to use the extended STOI described in `STOI ref3`_.
          kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

      Raises:
          ModuleNotFoundError:
              If ``pystoi`` package is not installed

      Example:
          >>> from torch import randn
          >>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
          >>> preds = randn(8000)
          >>> target = randn(8000)
          >>> stoi = ShortTimeObjectiveIntelligibility(8000, False)
          >>> stoi(preds, target)
          tensor(-0.084...)

      """

      # Running sum of all STOI scores seen by ``update`` and the number of scores summed.
      sum_stoi: Tensor
      total: Tensor
      full_state_update: bool = False
      is_differentiable: bool = False
      higher_is_better: bool = True
      plot_lower_bound: float = 0.0
      plot_upper_bound: float = 1.0

      def __init__(
          self,
          fs: int,
          extended: bool = False,
          **kwargs: Any,
      ) -> None:
          super().__init__(**kwargs)
          # Fail at construction time instead of on the first ``update`` call.
          if not _PYSTOI_AVAILABLE:
              raise ModuleNotFoundError(
                  "STOI metric requires that `pystoi` is installed."
                  " Either install as `pip install torchmetrics[audio]` or `pip install pystoi`."
              )
          self.fs = fs
          self.extended = extended

          # Both states reduce with "sum" across processes, so ``compute`` yields the global mean.
          self.add_state("sum_stoi", default=tensor(0.0), dist_reduce_fx="sum")
          self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

      def update(self, preds: Tensor, target: Tensor) -> None:
          """Update state with predictions and targets.

          Args:
              preds: degraded (estimated) speech signal with shape ``(..., time)``.
              target: clean reference speech signal with shape ``(..., time)``.

          """
          # keep_same_device=False: instead of returning to ``preds``' device, the scores are moved explicitly to the
          # device of the metric states so the in-place accumulation below never mixes devices.
          stoi_batch = short_time_objective_intelligibility(
              preds, target, fs=self.fs, extended=self.extended, keep_same_device=False
          ).to(self.sum_stoi.device)
          # Accumulate a sum and a count (not a per-batch mean) so every individual score is weighted equally.
          self.sum_stoi += stoi_batch.sum()
          self.total += stoi_batch.numel()

      def compute(self) -> Tensor:
          """Compute the mean STOI over all scores accumulated by ``update``.

          If ``update`` was never called, ``total`` is zero and the result is ``nan``.

          """
          return self.sum_stoi / self.total

      def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
          """Plot a single or multiple values from the metric.

          Args:
              val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                  If no value is provided, will automatically call `metric.compute` and plot that result.
              ax: A matplotlib axis object. If provided will add plot to that axis

          Returns:
              Figure and Axes object

          Raises:
              ModuleNotFoundError:
                  If `matplotlib` is not installed

          .. plot::
              :scale: 75

              >>> # Example plotting a single value
              >>> from torch import randn
              >>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
              >>> preds = randn(8000)
              >>> target = randn(8000)
              >>> metric = ShortTimeObjectiveIntelligibility(8000, False)
              >>> metric.update(preds, target)
              >>> fig_, ax_ = metric.plot()

          .. plot::
              :scale: 75

              >>> # Example plotting multiple values
              >>> from torch import randn
              >>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
              >>> metric = ShortTimeObjectiveIntelligibility(8000, False)
              >>> preds = randn(8000)
              >>> target = randn(8000)
              >>> values = [ ]
              >>> for _ in range(10):
              ...     values.append(metric(preds, target))
              >>> fig_, ax_ = metric.plot(values)

          """
          return self._plot(val, ax)