_deprecated.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from typing import Optional
  2. from torch import Tensor
  3. from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
  4. from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out
  5. from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
  6. from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
  7. from torchmetrics.functional.retrieval.precision import retrieval_precision
  8. from torchmetrics.functional.retrieval.precision_recall_curve import retrieval_precision_recall_curve
  9. from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
  10. from torchmetrics.functional.retrieval.recall import retrieval_recall
  11. from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
  12. from torchmetrics.utilities.prints import _deprecated_root_import_func
  def _retrieval_average_precision(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
      """Deprecated alias for ``retrieval_average_precision``.

      Reports the deprecated import through ``_deprecated_root_import_func`` and then
      forwards every argument unchanged, by keyword, to ``retrieval_average_precision``.

      Args:
          preds: forwarded as ``preds``.
          target: forwarded as ``target``.
          top_k: forwarded as ``top_k`` (``None`` by default).

      Returns:
          The tensor computed by ``retrieval_average_precision``.

      >>> from torch import tensor
      >>> preds = tensor([0.2, 0.3, 0.5])
      >>> target = tensor([True, False, True])
      >>> _retrieval_average_precision(preds, target)
      tensor(0.8333)
      """
      # Report the deprecated import before delegating, so it is reported even if the call below raises.
      _deprecated_root_import_func("retrieval_average_precision", "retrieval")
      score = retrieval_average_precision(preds=preds, target=target, top_k=top_k)
      return score
  def _retrieval_fall_out(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
      """Deprecated alias for ``retrieval_fall_out``.

      Reports the deprecated import through ``_deprecated_root_import_func`` and then
      forwards every argument unchanged, by keyword, to ``retrieval_fall_out``.

      Args:
          preds: forwarded as ``preds``.
          target: forwarded as ``target``.
          top_k: forwarded as ``top_k`` (``None`` by default).

      Returns:
          The tensor computed by ``retrieval_fall_out``.

      >>> from torch import tensor
      >>> preds = tensor([0.2, 0.3, 0.5])
      >>> target = tensor([True, False, True])
      >>> _retrieval_fall_out(preds, target, top_k=2)
      tensor(1.)
      """
      # Report the deprecated import before delegating, so it is reported even if the call below raises.
      _deprecated_root_import_func("retrieval_fall_out", "retrieval")
      score = retrieval_fall_out(preds=preds, target=target, top_k=top_k)
      return score
  def _retrieval_hit_rate(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
      """Deprecated alias for ``retrieval_hit_rate``.

      Reports the deprecated import through ``_deprecated_root_import_func`` and then
      forwards every argument unchanged, by keyword, to ``retrieval_hit_rate``.

      Args:
          preds: forwarded as ``preds``.
          target: forwarded as ``target``.
          top_k: forwarded as ``top_k`` (``None`` by default).

      Returns:
          The tensor computed by ``retrieval_hit_rate``.

      >>> from torch import tensor
      >>> preds = tensor([0.2, 0.3, 0.5])
      >>> target = tensor([True, False, True])
      >>> _retrieval_hit_rate(preds, target, top_k=2)
      tensor(1.)
      """
      # Report the deprecated import before delegating, so it is reported even if the call below raises.
      _deprecated_root_import_func("retrieval_hit_rate", "retrieval")
      score = retrieval_hit_rate(preds=preds, target=target, top_k=top_k)
      return score
  def _retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
      """Deprecated alias for ``retrieval_normalized_dcg``.

      Reports the deprecated import through ``_deprecated_root_import_func`` and then
      forwards every argument unchanged, by keyword, to ``retrieval_normalized_dcg``.

      Args:
          preds: forwarded as ``preds``.
          target: forwarded as ``target`` (the example below uses graded integer targets).
          top_k: forwarded as ``top_k`` (``None`` by default).

      Returns:
          The tensor computed by ``retrieval_normalized_dcg``.

      >>> from torch import tensor
      >>> preds = tensor([.1, .2, .3, 4, 70])
      >>> target = tensor([10, 0, 0, 1, 5])
      >>> _retrieval_normalized_dcg(preds, target)
      tensor(0.6957)
      """
      # Report the deprecated import before delegating, so it is reported even if the call below raises.
      _deprecated_root_import_func("retrieval_normalized_dcg", "retrieval")
      score = retrieval_normalized_dcg(preds=preds, target=target, top_k=top_k)
      return score
  def _retrieval_precision(
      preds: Tensor, target: Tensor, top_k: Optional[int] = None, adaptive_k: bool = False
  ) -> Tensor:
      """Deprecated alias for ``retrieval_precision``.

      Reports the deprecated import through ``_deprecated_root_import_func`` and then
      forwards every argument unchanged, by keyword, to ``retrieval_precision``.

      Args:
          preds: forwarded as ``preds``.
          target: forwarded as ``target``.
          top_k: forwarded as ``top_k`` (``None`` by default).
          adaptive_k: forwarded as ``adaptive_k`` (``False`` by default).

      Returns:
          The tensor computed by ``retrieval_precision``.

      >>> from torch import tensor
      >>> preds = tensor([0.2, 0.3, 0.5])
      >>> target = tensor([True, False, True])
      >>> _retrieval_precision(preds, target, top_k=2)
      tensor(0.5000)
      """
      # Report the deprecated import before delegating, so it is reported even if the call below raises.
      _deprecated_root_import_func("retrieval_precision", "retrieval")
      score = retrieval_precision(preds=preds, target=target, top_k=top_k, adaptive_k=adaptive_k)
      return score
  def _retrieval_precision_recall_curve(
      preds: Tensor, target: Tensor, max_k: Optional[int] = None, adaptive_k: bool = False
  ) -> tuple[Tensor, Tensor, Tensor]:
      """Deprecated alias for ``retrieval_precision_recall_curve``.

      Reports the deprecated import through ``_deprecated_root_import_func`` and then
      forwards every argument unchanged, by keyword, to ``retrieval_precision_recall_curve``.

      Args:
          preds: forwarded as ``preds``.
          target: forwarded as ``target``.
          max_k: forwarded as ``max_k`` (``None`` by default).
          adaptive_k: forwarded as ``adaptive_k`` (``False`` by default).

      Returns:
          The object returned by ``retrieval_precision_recall_curve``, passed through as-is;
          the example below unpacks it as ``(precisions, recalls, top_k)``.

      >>> from torch import tensor
      >>> preds = tensor([0.2, 0.3, 0.5])
      >>> target = tensor([True, False, True])
      >>> precisions, recalls, top_k = _retrieval_precision_recall_curve(preds, target, max_k=2)
      >>> precisions
      tensor([1.0000, 0.5000])
      >>> recalls
      tensor([0.5000, 0.5000])
      >>> top_k
      tensor([1, 2])
      """
      # Report the deprecated import before delegating, so it is reported even if the call below raises.
      _deprecated_root_import_func("retrieval_precision_recall_curve", "retrieval")
      # Return the result object untouched (no unpack/repack) so its type is preserved.
      curve = retrieval_precision_recall_curve(preds=preds, target=target, max_k=max_k, adaptive_k=adaptive_k)
      return curve
  def _retrieval_r_precision(preds: Tensor, target: Tensor) -> Tensor:
      """Deprecated alias for ``retrieval_r_precision``.

      Reports the deprecated import through ``_deprecated_root_import_func`` and then
      forwards both arguments unchanged, by keyword, to ``retrieval_r_precision``.

      Args:
          preds: forwarded as ``preds``.
          target: forwarded as ``target``.

      Returns:
          The tensor computed by ``retrieval_r_precision``.

      >>> from torch import tensor
      >>> preds = tensor([0.2, 0.3, 0.5])
      >>> target = tensor([True, False, True])
      >>> _retrieval_r_precision(preds, target)
      tensor(0.5000)
      """
      # Report the deprecated import before delegating, so it is reported even if the call below raises.
      _deprecated_root_import_func("retrieval_r_precision", "retrieval")
      score = retrieval_r_precision(preds=preds, target=target)
      return score
  def _retrieval_recall(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
      """Deprecated alias for ``retrieval_recall``.

      Reports the deprecated import through ``_deprecated_root_import_func`` and then
      forwards every argument unchanged, by keyword, to ``retrieval_recall``.

      Args:
          preds: forwarded as ``preds``.
          target: forwarded as ``target``.
          top_k: forwarded as ``top_k`` (``None`` by default).

      Returns:
          The tensor computed by ``retrieval_recall``.

      >>> from torch import tensor
      >>> preds = tensor([0.2, 0.3, 0.5])
      >>> target = tensor([True, False, True])
      >>> _retrieval_recall(preds, target, top_k=2)
      tensor(0.5000)
      """
      # Report the deprecated import before delegating, so it is reported even if the call below raises.
      _deprecated_root_import_func("retrieval_recall", "retrieval")
      score = retrieval_recall(preds=preds, target=target, top_k=top_k)
      return score
  def _retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
      """Deprecated alias for ``retrieval_reciprocal_rank``.

      Reports the deprecated import through ``_deprecated_root_import_func`` and then
      forwards both arguments unchanged, by keyword, to ``retrieval_reciprocal_rank``.

      Args:
          preds: forwarded as ``preds``.
          target: forwarded as ``target``.

      Returns:
          The tensor computed by ``retrieval_reciprocal_rank``.

      >>> from torch import tensor
      >>> preds = tensor([0.2, 0.3, 0.5])
      >>> target = tensor([False, True, False])
      >>> _retrieval_reciprocal_rank(preds, target)
      tensor(0.5000)
      """
      # Report the deprecated import before delegating, so it is reported even if the call below raises.
      _deprecated_root_import_func("retrieval_reciprocal_rank", "retrieval")
      score = retrieval_reciprocal_rank(preds=preds, target=target)
      return score