__init__.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from torchmetrics.audio.pit import PermutationInvariantTraining
  15. from torchmetrics.audio.sdr import (
  16. ScaleInvariantSignalDistortionRatio,
  17. SignalDistortionRatio,
  18. SourceAggregatedSignalDistortionRatio,
  19. )
  20. from torchmetrics.audio.snr import (
  21. ComplexScaleInvariantSignalNoiseRatio,
  22. ScaleInvariantSignalNoiseRatio,
  23. SignalNoiseRatio,
  24. )
  25. from torchmetrics.utilities.imports import (
  26. _GAMMATONE_AVAILABLE,
  27. _LIBROSA_AVAILABLE,
  28. _ONNXRUNTIME_AVAILABLE,
  29. _PESQ_AVAILABLE,
  30. _PYSTOI_AVAILABLE,
  31. _REQUESTS_AVAILABLE,
  32. _SCIPI_AVAILABLE,
  33. _TORCHAUDIO_AVAILABLE,
  34. )
  35. if _SCIPI_AVAILABLE:
  36. import scipy.signal
  37. # back compatibility patch due to SMRMpy using scipy.signal.hamming
  38. if not hasattr(scipy.signal, "hamming"):
  39. scipy.signal.hamming = scipy.signal.windows.hamming
  40. __all__ = [
  41. "ComplexScaleInvariantSignalNoiseRatio",
  42. "PermutationInvariantTraining",
  43. "ScaleInvariantSignalDistortionRatio",
  44. "ScaleInvariantSignalNoiseRatio",
  45. "SignalDistortionRatio",
  46. "SignalNoiseRatio",
  47. "SourceAggregatedSignalDistortionRatio",
  48. ]
  49. if _PESQ_AVAILABLE:
  50. from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
  51. __all__ += ["PerceptualEvaluationSpeechQuality"]
  52. if _PYSTOI_AVAILABLE:
  53. from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
  54. __all__ += ["ShortTimeObjectiveIntelligibility"]
  55. if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE:
  56. from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio
  57. __all__ += ["SpeechReverberationModulationEnergyRatio"]
  58. if _LIBROSA_AVAILABLE and _ONNXRUNTIME_AVAILABLE:
  59. from torchmetrics.audio.dnsmos import DeepNoiseSuppressionMeanOpinionScore
  60. __all__ += ["DeepNoiseSuppressionMeanOpinionScore"]
  61. if _LIBROSA_AVAILABLE and _REQUESTS_AVAILABLE:
  62. from torchmetrics.audio.nisqa import NonIntrusiveSpeechQualityAssessment
  63. __all__ += ["NonIntrusiveSpeechQualityAssessment"]