| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- from typing import Any, Callable, Optional
- from typing_extensions import Literal
- from torchmetrics.audio.pit import PermutationInvariantTraining
- from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
- from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio
- from torchmetrics.utilities.prints import _deprecated_root_import_class
class _PermutationInvariantTraining(PermutationInvariantTraining):
    """Deprecated-import shim around ``PermutationInvariantTraining``.

    Constructing this class first calls ``_deprecated_root_import_class`` with
    the public class name and the ``"audio"`` domain, and then hands every
    argument, unchanged, to the parent constructor.

    NOTE(review): the example draws its inputs from ``torch.randn`` without
    seeding, so the printed value depends on the RNG state in effect when the
    doctest runs -- confirm the doctest runner fixes the seed.

    >>> import torch
    >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
    >>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
    >>> target = torch.randn(3, 2, 5) # [batch, spk, time]
    >>> pit = _PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
    ...     mode="speaker-wise", eval_func="max")
    >>> pit(preds, target)
    tensor(-2.1065)
    """

    def __init__(
        self,
        metric_func: Callable,
        mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise",
        eval_func: Literal["max", "min"] = "max",
        **kwargs: Any,
    ) -> None:
        """Register the deprecated import, then initialise the parent metric.

        Args:
            metric_func: forwarded to the parent as ``metric_func``.
            mode: forwarded to the parent as ``mode``.
            eval_func: forwarded to the parent as ``eval_func``.
            kwargs: any further keyword arguments, forwarded untouched.
        """
        # The deprecation hook must run before the parent is initialised,
        # exactly as in the original call order.
        _deprecated_root_import_class("PermutationInvariantTraining", "audio")
        parent_options = {
            "metric_func": metric_func,
            "mode": mode,
            "eval_func": eval_func,
            **kwargs,
        }
        super().__init__(**parent_options)
class _ScaleInvariantSignalDistortionRatio(ScaleInvariantSignalDistortionRatio):
    """Deprecated-import shim around ``ScaleInvariantSignalDistortionRatio``.

    Construction calls ``_deprecated_root_import_class`` with the public class
    name and the ``"audio"`` domain before delegating to the parent class.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> si_sdr = _ScaleInvariantSignalDistortionRatio()
    >>> si_sdr(preds, target)
    tensor(18.4030)
    """

    def __init__(self, zero_mean: bool = False, **kwargs: Any) -> None:
        """Register the deprecated import, then initialise the parent metric.

        Args:
            zero_mean: forwarded to the parent as ``zero_mean``.
            kwargs: any further keyword arguments, forwarded untouched.
        """
        _deprecated_root_import_class("ScaleInvariantSignalDistortionRatio", "audio")
        parent_options = {"zero_mean": zero_mean, **kwargs}
        super().__init__(**parent_options)
class _ScaleInvariantSignalNoiseRatio(ScaleInvariantSignalNoiseRatio):
    """Deprecated-import shim around ``ScaleInvariantSignalNoiseRatio``.

    Construction calls ``_deprecated_root_import_class`` with the public class
    name and the ``"audio"`` domain before delegating to the parent class.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> si_snr = _ScaleInvariantSignalNoiseRatio()
    >>> si_snr(preds, target)
    tensor(15.0918)
    """

    def __init__(self, **kwargs: Any) -> None:
        """Register the deprecated import, then initialise the parent metric.

        Args:
            kwargs: keyword arguments forwarded untouched to the parent.
        """
        _deprecated_root_import_class("ScaleInvariantSignalNoiseRatio", "audio")
        super().__init__(**kwargs)
class _SignalDistortionRatio(SignalDistortionRatio):
    """Deprecated-import shim around ``SignalDistortionRatio``.

    Construction calls ``_deprecated_root_import_class`` with the public class
    name and the ``"audio"`` domain before delegating to the parent class.

    NOTE(review): both examples draw their inputs from ``torch.randn`` without
    seeding, so the printed values depend on the RNG state in effect when the
    doctest runs -- confirm the doctest runner fixes the seed.

    >>> import torch
    >>> preds = torch.randn(8000)
    >>> target = torch.randn(8000)
    >>> sdr = _SignalDistortionRatio()
    >>> sdr(preds, target)
    tensor(-11.9930)
    >>> # use with pit
    >>> from torchmetrics.functional import signal_distortion_ratio
    >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time]
    >>> target = torch.randn(4, 2, 8000)
    >>> pit = _PermutationInvariantTraining(signal_distortion_ratio,
    ...     mode="speaker-wise", eval_func="max")
    >>> pit(preds, target)
    tensor(-11.7277)
    """

    def __init__(
        self,
        use_cg_iter: Optional[int] = None,
        filter_length: int = 512,
        zero_mean: bool = False,
        load_diag: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """Register the deprecated import, then initialise the parent metric.

        Args:
            use_cg_iter: forwarded to the parent as ``use_cg_iter``.
            filter_length: forwarded to the parent as ``filter_length``.
            zero_mean: forwarded to the parent as ``zero_mean``.
            load_diag: forwarded to the parent as ``load_diag``.
            kwargs: any further keyword arguments, forwarded untouched.
        """
        # The deprecation hook must run before the parent is initialised,
        # exactly as in the original call order.
        _deprecated_root_import_class("SignalDistortionRatio", "audio")
        parent_options = {
            "use_cg_iter": use_cg_iter,
            "filter_length": filter_length,
            "zero_mean": zero_mean,
            "load_diag": load_diag,
            **kwargs,
        }
        super().__init__(**parent_options)
class _SignalNoiseRatio(SignalNoiseRatio):
    """Deprecated-import shim around ``SignalNoiseRatio``.

    Construction calls ``_deprecated_root_import_class`` with the public class
    name and the ``"audio"`` domain before delegating to the parent class.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> snr = _SignalNoiseRatio()
    >>> snr(preds, target)
    tensor(16.1805)
    """

    def __init__(self, zero_mean: bool = False, **kwargs: Any) -> None:
        """Register the deprecated import, then initialise the parent metric.

        Args:
            zero_mean: forwarded to the parent as ``zero_mean``.
            kwargs: any further keyword arguments, forwarded untouched.
        """
        _deprecated_root_import_class("SignalNoiseRatio", "audio")
        parent_options = {"zero_mean": zero_mean, **kwargs}
        super().__init__(**parent_options)
|