_deprecated.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. from typing import Any, Callable, Optional
  2. from typing_extensions import Literal
  3. from torchmetrics.audio.pit import PermutationInvariantTraining
  4. from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
  5. from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio
  6. from torchmetrics.utilities.prints import _deprecated_root_import_class
  7. class _PermutationInvariantTraining(PermutationInvariantTraining):
  8. """Wrapper for deprecated import.
  9. >>> import torch
  10. >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
  11. >>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
  12. >>> target = torch.randn(3, 2, 5) # [batch, spk, time]
  13. >>> pit = _PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
  14. ... mode="speaker-wise", eval_func="max")
  15. >>> pit(preds, target)
  16. tensor(-2.1065)
  17. """
  18. def __init__(
  19. self,
  20. metric_func: Callable,
  21. mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise",
  22. eval_func: Literal["max", "min"] = "max",
  23. **kwargs: Any,
  24. ) -> None:
  25. _deprecated_root_import_class("PermutationInvariantTraining", "audio")
  26. super().__init__(metric_func=metric_func, mode=mode, eval_func=eval_func, **kwargs)
  27. class _ScaleInvariantSignalDistortionRatio(ScaleInvariantSignalDistortionRatio):
  28. """Wrapper for deprecated import.
  29. >>> from torch import tensor
  30. >>> target = tensor([3.0, -0.5, 2.0, 7.0])
  31. >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
  32. >>> si_sdr = _ScaleInvariantSignalDistortionRatio()
  33. >>> si_sdr(preds, target)
  34. tensor(18.4030)
  35. """
  36. def __init__(
  37. self,
  38. zero_mean: bool = False,
  39. **kwargs: Any,
  40. ) -> None:
  41. _deprecated_root_import_class("ScaleInvariantSignalDistortionRatio", "audio")
  42. super().__init__(zero_mean=zero_mean, **kwargs)
  43. class _ScaleInvariantSignalNoiseRatio(ScaleInvariantSignalNoiseRatio):
  44. """Wrapper for deprecated import.
  45. >>> from torch import tensor
  46. >>> target = tensor([3.0, -0.5, 2.0, 7.0])
  47. >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
  48. >>> si_snr = _ScaleInvariantSignalNoiseRatio()
  49. >>> si_snr(preds, target)
  50. tensor(15.0918)
  51. """
  52. def __init__(
  53. self,
  54. **kwargs: Any,
  55. ) -> None:
  56. _deprecated_root_import_class("ScaleInvariantSignalNoiseRatio", "audio")
  57. super().__init__(**kwargs)
  58. class _SignalDistortionRatio(SignalDistortionRatio):
  59. """Wrapper for deprecated import.
  60. >>> import torch
  61. >>> preds = torch.randn(8000)
  62. >>> target = torch.randn(8000)
  63. >>> sdr = _SignalDistortionRatio()
  64. >>> sdr(preds, target)
  65. tensor(-11.9930)
  66. >>> # use with pit
  67. >>> from torchmetrics.functional import signal_distortion_ratio
  68. >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time]
  69. >>> target = torch.randn(4, 2, 8000)
  70. >>> pit = _PermutationInvariantTraining(signal_distortion_ratio,
  71. ... mode="speaker-wise", eval_func="max")
  72. >>> pit(preds, target)
  73. tensor(-11.7277)
  74. """
  75. def __init__(
  76. self,
  77. use_cg_iter: Optional[int] = None,
  78. filter_length: int = 512,
  79. zero_mean: bool = False,
  80. load_diag: Optional[float] = None,
  81. **kwargs: Any,
  82. ) -> None:
  83. _deprecated_root_import_class("SignalDistortionRatio", "audio")
  84. super().__init__(
  85. use_cg_iter=use_cg_iter, filter_length=filter_length, zero_mean=zero_mean, load_diag=load_diag, **kwargs
  86. )
  87. class _SignalNoiseRatio(SignalNoiseRatio):
  88. """Wrapper for deprecated import.
  89. >>> from torch import tensor
  90. >>> target = tensor([3.0, -0.5, 2.0, 7.0])
  91. >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
  92. >>> snr = _SignalNoiseRatio()
  93. >>> snr(preds, target)
  94. tensor(16.1805)
  95. """
  96. def __init__(
  97. self,
  98. zero_mean: bool = False,
  99. **kwargs: Any,
  100. ) -> None:
  101. _deprecated_root_import_class("SignalNoiseRatio", "audio")
  102. super().__init__(zero_mean=zero_mean, **kwargs)