| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- from typing import Any, Callable, Optional
- from torch import Tensor
- from typing_extensions import Literal
- from torchmetrics.functional.audio.pit import permutation_invariant_training, pit_permutate
- from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio
- from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio
- from torchmetrics.utilities.prints import _deprecated_root_import_func
def _permutation_invariant_training(
    preds: Tensor,
    target: Tensor,
    metric_func: Callable,
    mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise",
    eval_func: Literal["max", "min"] = "max",
    **kwargs: Any,
) -> tuple[Tensor, Tensor]:
    """Deprecated-import shim for ``permutation_invariant_training``.

    Calls ``_deprecated_root_import_func`` with the wrapped function's name and
    ``"audio"``, then forwards every argument by keyword (extra ``kwargs``
    included) and returns the wrapped function's result unchanged.

    >>> from torch import tensor
    >>> preds = tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]])
    >>> target = tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]])
    >>> best_metric, best_perm = _permutation_invariant_training(
    ...     preds, target, _scale_invariant_signal_distortion_ratio)
    >>> best_metric
    tensor([-5.1091])
    >>> best_perm
    tensor([[0, 1]])
    >>> pit_permutate(preds, best_perm)
    tensor([[[-0.0579,  0.3560, -0.9604],
             [-0.1719,  0.3205,  0.2951]]])

    """
    _deprecated_root_import_func("permutation_invariant_training", "audio")
    metric_and_perm = permutation_invariant_training(
        preds=preds,
        target=target,
        metric_func=metric_func,
        mode=mode,
        eval_func=eval_func,
        **kwargs,
    )
    return metric_and_perm
def _pit_permutate(preds: Tensor, perm: Tensor) -> Tensor:
    """Deprecated-import shim for ``pit_permutate``.

    Calls ``_deprecated_root_import_func`` with the wrapped function's name and
    ``"audio"``, then returns ``pit_permutate(preds=preds, perm=perm)``.
    """
    _deprecated_root_import_func("pit_permutate", "audio")
    permuted = pit_permutate(preds=preds, perm=perm)
    return permuted
def _scale_invariant_signal_distortion_ratio(
    preds: Tensor, target: Tensor, zero_mean: bool = False
) -> Tensor:
    """Deprecated-import shim for ``scale_invariant_signal_distortion_ratio``.

    Calls ``_deprecated_root_import_func`` with the wrapped function's name and
    ``"audio"``, then forwards ``preds``, ``target`` and ``zero_mean`` by keyword
    and returns the wrapped function's result unchanged.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> _scale_invariant_signal_distortion_ratio(preds, target)
    tensor(18.4030)

    """
    _deprecated_root_import_func("scale_invariant_signal_distortion_ratio", "audio")
    si_sdr = scale_invariant_signal_distortion_ratio(
        preds=preds,
        target=target,
        zero_mean=zero_mean,
    )
    return si_sdr
def _signal_distortion_ratio(
    preds: Tensor,
    target: Tensor,
    use_cg_iter: Optional[int] = None,
    filter_length: int = 512,
    zero_mean: bool = False,
    load_diag: Optional[float] = None,
) -> Tensor:
    """Deprecated-import shim for ``signal_distortion_ratio``.

    Calls ``_deprecated_root_import_func`` with the wrapped function's name and
    ``"audio"``, then forwards all six arguments by keyword and returns the
    wrapped function's result unchanged.

    >>> from torch import randn
    >>> preds = randn(8000)
    >>> target = randn(8000)
    >>> _signal_distortion_ratio(preds, target)
    tensor(-11.9930)
    >>> # use with permutation_invariant_training
    >>> preds = randn(4, 2, 8000)  # [batch, spk, time]
    >>> target = randn(4, 2, 8000)
    >>> best_metric, best_perm = _permutation_invariant_training(preds, target, _signal_distortion_ratio)
    >>> best_metric
    tensor([-11.7748, -11.7948, -11.7160, -11.6254])
    >>> best_perm
    tensor([[1, 0],
            [1, 0],
            [1, 0],
            [0, 1]])

    """
    _deprecated_root_import_func("signal_distortion_ratio", "audio")
    sdr = signal_distortion_ratio(
        preds=preds, target=target, use_cg_iter=use_cg_iter,
        filter_length=filter_length, zero_mean=zero_mean, load_diag=load_diag,
    )
    return sdr
def _scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor:
    """Deprecated-import shim for ``scale_invariant_signal_noise_ratio``.

    Calls ``_deprecated_root_import_func`` with the wrapped function's name and
    ``"audio"``, then forwards ``preds`` and ``target`` by keyword and returns
    the wrapped function's result unchanged.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> _scale_invariant_signal_noise_ratio(preds, target)
    tensor(15.0918)

    """
    _deprecated_root_import_func("scale_invariant_signal_noise_ratio", "audio")
    si_snr = scale_invariant_signal_noise_ratio(preds=preds, target=target)
    return si_snr
def _signal_noise_ratio(
    preds: Tensor, target: Tensor, zero_mean: bool = False
) -> Tensor:
    """Deprecated-import shim for ``signal_noise_ratio``.

    Calls ``_deprecated_root_import_func`` with the wrapped function's name and
    ``"audio"``, then forwards ``preds``, ``target`` and ``zero_mean`` by keyword
    and returns the wrapped function's result unchanged.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> _signal_noise_ratio(preds, target)
    tensor(16.1805)

    """
    _deprecated_root_import_func("signal_noise_ratio", "audio")
    snr = signal_noise_ratio(
        preds=preds,
        target=target,
        zero_mean=zero_mean,
    )
    return snr
|