_deprecated.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. from typing import Any, Callable, Optional
  2. from torch import Tensor
  3. from typing_extensions import Literal
  4. from torchmetrics.functional.audio.pit import permutation_invariant_training, pit_permutate
  5. from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio
  6. from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio
  7. from torchmetrics.utilities.prints import _deprecated_root_import_func
  8. def _permutation_invariant_training(
  9. preds: Tensor,
  10. target: Tensor,
  11. metric_func: Callable,
  12. mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise",
  13. eval_func: Literal["max", "min"] = "max",
  14. **kwargs: Any,
  15. ) -> tuple[Tensor, Tensor]:
  16. """Wrapper for deprecated import.
  17. >>> from torch import tensor
  18. >>> preds = tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]])
  19. >>> target = tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]])
  20. >>> best_metric, best_perm = _permutation_invariant_training(
  21. ... preds, target, _scale_invariant_signal_distortion_ratio)
  22. >>> best_metric
  23. tensor([-5.1091])
  24. >>> best_perm
  25. tensor([[0, 1]])
  26. >>> pit_permutate(preds, best_perm)
  27. tensor([[[-0.0579, 0.3560, -0.9604],
  28. [-0.1719, 0.3205, 0.2951]]])
  29. """
  30. _deprecated_root_import_func("permutation_invariant_training", "audio")
  31. return permutation_invariant_training(
  32. preds=preds, target=target, metric_func=metric_func, mode=mode, eval_func=eval_func, **kwargs
  33. )
  34. def _pit_permutate(preds: Tensor, perm: Tensor) -> Tensor:
  35. """Wrapper for deprecated import."""
  36. _deprecated_root_import_func("pit_permutate", "audio")
  37. return pit_permutate(preds=preds, perm=perm)
  38. def _scale_invariant_signal_distortion_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
  39. """Wrapper for deprecated import.
  40. >>> from torch import tensor
  41. >>> target = tensor([3.0, -0.5, 2.0, 7.0])
  42. >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
  43. >>> _scale_invariant_signal_distortion_ratio(preds, target)
  44. tensor(18.4030)
  45. """
  46. _deprecated_root_import_func("scale_invariant_signal_distortion_ratio", "audio")
  47. return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=zero_mean)
  48. def _signal_distortion_ratio(
  49. preds: Tensor,
  50. target: Tensor,
  51. use_cg_iter: Optional[int] = None,
  52. filter_length: int = 512,
  53. zero_mean: bool = False,
  54. load_diag: Optional[float] = None,
  55. ) -> Tensor:
  56. """Wrapper for deprecated import.
  57. >>> from torch import randn
  58. >>> preds = randn(8000)
  59. >>> target = randn(8000)
  60. >>> _signal_distortion_ratio(preds, target)
  61. tensor(-11.9930)
  62. >>> # use with permutation_invariant_training
  63. >>> preds = randn(4, 2, 8000) # [batch, spk, time]
  64. >>> target = randn(4, 2, 8000)
  65. >>> best_metric, best_perm = _permutation_invariant_training(preds, target, _signal_distortion_ratio)
  66. >>> best_metric
  67. tensor([-11.7748, -11.7948, -11.7160, -11.6254])
  68. >>> best_perm
  69. tensor([[1, 0],
  70. [1, 0],
  71. [1, 0],
  72. [0, 1]])
  73. """
  74. _deprecated_root_import_func("signal_distortion_ratio", "audio")
  75. return signal_distortion_ratio(
  76. preds=preds,
  77. target=target,
  78. use_cg_iter=use_cg_iter,
  79. filter_length=filter_length,
  80. zero_mean=zero_mean,
  81. load_diag=load_diag,
  82. )
  83. def _scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor:
  84. """Wrapper for deprecated import.
  85. >>> from torch import tensor
  86. >>> target = tensor([3.0, -0.5, 2.0, 7.0])
  87. >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
  88. >>> _scale_invariant_signal_noise_ratio(preds, target)
  89. tensor(15.0918)
  90. """
  91. _deprecated_root_import_func("scale_invariant_signal_noise_ratio", "audio")
  92. return scale_invariant_signal_noise_ratio(preds=preds, target=target)
  93. def _signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
  94. """Wrapper for deprecated import.
  95. >>> from torch import tensor
  96. >>> target = tensor([3.0, -0.5, 2.0, 7.0])
  97. >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
  98. >>> _signal_noise_ratio(preds, target)
  99. tensor(16.1805)
  100. """
  101. _deprecated_root_import_func("signal_noise_ratio", "audio")
  102. return signal_noise_ratio(preds=preds, target=target, zero_mean=zero_mean)