| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
- from typing import Optional
- from torch import Tensor
- from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
- from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out
- from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
- from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
- from torchmetrics.functional.retrieval.precision import retrieval_precision
- from torchmetrics.functional.retrieval.precision_recall_curve import retrieval_precision_recall_curve
- from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
- from torchmetrics.functional.retrieval.recall import retrieval_recall
- from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
- from torchmetrics.utilities.prints import _deprecated_root_import_func
def _retrieval_average_precision(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Deprecated alias of ``retrieval_average_precision``.

    Flags the deprecated root-level import via ``_deprecated_root_import_func`` and then
    forwards ``preds``, ``target`` and ``top_k`` unchanged; see ``retrieval_average_precision``
    for the meaning of each argument and of the returned score.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> _retrieval_average_precision(preds, target)
    tensor(0.8333)

    """
    _deprecated_root_import_func("retrieval_average_precision", "retrieval")
    score = retrieval_average_precision(target=target, preds=preds, top_k=top_k)
    return score
def _retrieval_fall_out(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Deprecated alias of ``retrieval_fall_out``.

    Flags the deprecated root-level import via ``_deprecated_root_import_func`` and then
    forwards ``preds``, ``target`` and ``top_k`` unchanged; see ``retrieval_fall_out``
    for the meaning of each argument and of the returned score.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> _retrieval_fall_out(preds, target, top_k=2)
    tensor(1.)

    """
    _deprecated_root_import_func("retrieval_fall_out", "retrieval")
    score = retrieval_fall_out(target=target, preds=preds, top_k=top_k)
    return score
def _retrieval_hit_rate(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Deprecated alias of ``retrieval_hit_rate``.

    Flags the deprecated root-level import via ``_deprecated_root_import_func`` and then
    forwards ``preds``, ``target`` and ``top_k`` unchanged; see ``retrieval_hit_rate``
    for the meaning of each argument and of the returned score.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> _retrieval_hit_rate(preds, target, top_k=2)
    tensor(1.)

    """
    _deprecated_root_import_func("retrieval_hit_rate", "retrieval")
    score = retrieval_hit_rate(target=target, preds=preds, top_k=top_k)
    return score
def _retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Deprecated alias of ``retrieval_normalized_dcg``.

    Flags the deprecated root-level import via ``_deprecated_root_import_func`` and then
    forwards ``preds``, ``target`` and ``top_k`` unchanged; see ``retrieval_normalized_dcg``
    for the meaning of each argument and of the returned score. As the example shows,
    ``target`` may hold graded (non-boolean) relevance values here.

    >>> from torch import tensor
    >>> preds = tensor([.1, .2, .3, 4, 70])
    >>> target = tensor([10, 0, 0, 1, 5])
    >>> _retrieval_normalized_dcg(preds, target)
    tensor(0.6957)

    """
    _deprecated_root_import_func("retrieval_normalized_dcg", "retrieval")
    score = retrieval_normalized_dcg(target=target, preds=preds, top_k=top_k)
    return score
def _retrieval_precision(
    preds: Tensor, target: Tensor, top_k: Optional[int] = None, adaptive_k: bool = False
) -> Tensor:
    """Deprecated alias of ``retrieval_precision``.

    Flags the deprecated root-level import via ``_deprecated_root_import_func`` and then
    forwards ``preds``, ``target``, ``top_k`` and ``adaptive_k`` unchanged; see
    ``retrieval_precision`` for the meaning of each argument and of the returned score.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> _retrieval_precision(preds, target, top_k=2)
    tensor(0.5000)

    """
    _deprecated_root_import_func("retrieval_precision", "retrieval")
    score = retrieval_precision(adaptive_k=adaptive_k, top_k=top_k, target=target, preds=preds)
    return score
def _retrieval_precision_recall_curve(
    preds: Tensor, target: Tensor, max_k: Optional[int] = None, adaptive_k: bool = False
) -> tuple[Tensor, Tensor, Tensor]:
    """Deprecated alias of ``retrieval_precision_recall_curve``.

    Flags the deprecated root-level import via ``_deprecated_root_import_func`` and then
    forwards ``preds``, ``target``, ``max_k`` and ``adaptive_k`` unchanged. The
    ``(precisions, recalls, top_k)`` triple produced by ``retrieval_precision_recall_curve``
    is returned as-is.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> precisions, recalls, top_k = _retrieval_precision_recall_curve(preds, target, max_k=2)
    >>> precisions
    tensor([1.0000, 0.5000])
    >>> recalls
    tensor([0.5000, 0.5000])
    >>> top_k
    tensor([1, 2])

    """
    _deprecated_root_import_func("retrieval_precision_recall_curve", "retrieval")
    curve = retrieval_precision_recall_curve(adaptive_k=adaptive_k, max_k=max_k, target=target, preds=preds)
    return curve
def _retrieval_r_precision(preds: Tensor, target: Tensor) -> Tensor:
    """Deprecated alias of ``retrieval_r_precision``.

    Flags the deprecated root-level import via ``_deprecated_root_import_func`` and then
    forwards ``preds`` and ``target`` unchanged; see ``retrieval_r_precision`` for the
    meaning of each argument and of the returned score.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> _retrieval_r_precision(preds, target)
    tensor(0.5000)

    """
    _deprecated_root_import_func("retrieval_r_precision", "retrieval")
    score = retrieval_r_precision(target=target, preds=preds)
    return score
def _retrieval_recall(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Deprecated alias of ``retrieval_recall``.

    Flags the deprecated root-level import via ``_deprecated_root_import_func`` and then
    forwards ``preds``, ``target`` and ``top_k`` unchanged; see ``retrieval_recall``
    for the meaning of each argument and of the returned score.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> _retrieval_recall(preds, target, top_k=2)
    tensor(0.5000)

    """
    _deprecated_root_import_func("retrieval_recall", "retrieval")
    score = retrieval_recall(target=target, preds=preds, top_k=top_k)
    return score
def _retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
    """Deprecated alias of ``retrieval_reciprocal_rank``.

    Flags the deprecated root-level import via ``_deprecated_root_import_func`` and then
    forwards ``preds`` and ``target`` unchanged; see ``retrieval_reciprocal_rank`` for the
    meaning of each argument and of the returned score.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([False, True, False])
    >>> _retrieval_reciprocal_rank(preds, target)
    tensor(0.5000)

    """
    _deprecated_root_import_func("retrieval_reciprocal_rank", "retrieval")
    score = retrieval_reciprocal_rank(target=target, preds=preds)
    return score
|