_deprecated.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. from typing import Any, Optional
  2. from torchmetrics.retrieval.average_precision import RetrievalMAP
  3. from torchmetrics.retrieval.fall_out import RetrievalFallOut
  4. from torchmetrics.retrieval.hit_rate import RetrievalHitRate
  5. from torchmetrics.retrieval.ndcg import RetrievalNormalizedDCG
  6. from torchmetrics.retrieval.precision import RetrievalPrecision
  7. from torchmetrics.retrieval.precision_recall_curve import RetrievalPrecisionRecallCurve, RetrievalRecallAtFixedPrecision
  8. from torchmetrics.retrieval.r_precision import RetrievalRPrecision
  9. from torchmetrics.retrieval.recall import RetrievalRecall
  10. from torchmetrics.retrieval.reciprocal_rank import RetrievalMRR
  11. from torchmetrics.utilities.prints import _deprecated_root_import_class
  12. class _RetrievalFallOut(RetrievalFallOut):
  13. """Wrapper for deprecated import.
  14. >>> from torch import tensor
  15. >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
  16. >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
  17. >>> target = tensor([False, False, True, False, True, False, True])
  18. >>> rfo = _RetrievalFallOut(top_k=2)
  19. >>> rfo(preds, target, indexes=indexes)
  20. tensor(0.5000)
  21. """
  22. def __init__(
  23. self,
  24. empty_target_action: str = "pos",
  25. ignore_index: Optional[int] = None,
  26. top_k: Optional[int] = None,
  27. **kwargs: Any,
  28. ) -> None:
  29. _deprecated_root_import_class("RetrievalFallOut", "retrieval")
  30. super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs)
  31. class _RetrievalHitRate(RetrievalHitRate):
  32. """Wrapper for deprecated import.
  33. >>> from torch import tensor
  34. >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
  35. >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
  36. >>> target = tensor([True, False, False, False, True, False, True])
  37. >>> hr2 = _RetrievalHitRate(top_k=2)
  38. >>> hr2(preds, target, indexes=indexes)
  39. tensor(0.5000)
  40. """
  41. def __init__(
  42. self,
  43. empty_target_action: str = "neg",
  44. ignore_index: Optional[int] = None,
  45. top_k: Optional[int] = None,
  46. **kwargs: Any,
  47. ) -> None:
  48. _deprecated_root_import_class("RetrievalHitRate", "retrieval")
  49. super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs)
  50. class _RetrievalMAP(RetrievalMAP):
  51. """Wrapper for deprecated import.
  52. >>> from torch import tensor
  53. >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
  54. >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
  55. >>> target = tensor([False, False, True, False, True, False, True])
  56. >>> rmap = _RetrievalMAP()
  57. >>> rmap(preds, target, indexes=indexes)
  58. tensor(0.7917)
  59. """
  60. def __init__(
  61. self,
  62. empty_target_action: str = "neg",
  63. ignore_index: Optional[int] = None,
  64. top_k: Optional[int] = None,
  65. **kwargs: Any,
  66. ) -> None:
  67. _deprecated_root_import_class("RetrievalMAP", "retrieval")
  68. super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs)
  69. class _RetrievalRecall(RetrievalRecall):
  70. """Wrapper for deprecated import.
  71. >>> from torch import tensor
  72. >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
  73. >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
  74. >>> target = tensor([False, False, True, False, True, False, True])
  75. >>> r2 = _RetrievalRecall(top_k=2)
  76. >>> r2(preds, target, indexes=indexes)
  77. tensor(0.7500)
  78. """
  79. def __init__(
  80. self,
  81. empty_target_action: str = "neg",
  82. ignore_index: Optional[int] = None,
  83. top_k: Optional[int] = None,
  84. **kwargs: Any,
  85. ) -> None:
  86. _deprecated_root_import_class("RetrievalRecall", "retrieval")
  87. super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs)
  88. class _RetrievalRPrecision(RetrievalRPrecision):
  89. """Wrapper for deprecated import.
  90. >>> from torch import tensor
  91. >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
  92. >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
  93. >>> target = tensor([False, False, True, False, True, False, True])
  94. >>> p2 = _RetrievalRPrecision()
  95. >>> p2(preds, target, indexes=indexes)
  96. tensor(0.7500)
  97. """
  98. def __init__(
  99. self,
  100. empty_target_action: str = "neg",
  101. ignore_index: Optional[int] = None,
  102. **kwargs: Any,
  103. ) -> None:
  104. _deprecated_root_import_class("RetrievalRPrecision", "retrieval")
  105. super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, **kwargs)
  106. class _RetrievalNormalizedDCG(RetrievalNormalizedDCG):
  107. """Wrapper for deprecated import.
  108. >>> from torch import tensor
  109. >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
  110. >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
  111. >>> target = tensor([False, False, True, False, True, False, True])
  112. >>> ndcg = _RetrievalNormalizedDCG()
  113. >>> ndcg(preds, target, indexes=indexes)
  114. tensor(0.8467)
  115. """
  116. def __init__(
  117. self,
  118. empty_target_action: str = "neg",
  119. ignore_index: Optional[int] = None,
  120. top_k: Optional[int] = None,
  121. **kwargs: Any,
  122. ) -> None:
  123. _deprecated_root_import_class("RetrievalNormalizedDCG", "retrieval")
  124. super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, top_k=top_k, **kwargs)
  125. class _RetrievalPrecision(RetrievalPrecision):
  126. """Wrapper for deprecated import.
  127. >>> from torch import tensor
  128. >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
  129. >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
  130. >>> target = tensor([False, False, True, False, True, False, True])
  131. >>> p2 = _RetrievalPrecision(top_k=2)
  132. >>> p2(preds, target, indexes=indexes)
  133. tensor(0.5000)
  134. """
  135. def __init__(
  136. self,
  137. empty_target_action: str = "neg",
  138. ignore_index: Optional[int] = None,
  139. top_k: Optional[int] = None,
  140. adaptive_k: bool = False,
  141. **kwargs: Any,
  142. ) -> None:
  143. _deprecated_root_import_class("", "retrieval")
  144. super().__init__(
  145. empty_target_action=empty_target_action,
  146. ignore_index=ignore_index,
  147. top_k=top_k,
  148. adaptive_k=adaptive_k,
  149. **kwargs,
  150. )
  151. class _RetrievalPrecisionRecallCurve(RetrievalPrecisionRecallCurve):
  152. """Wrapper for deprecated import.
  153. >>> from torch import tensor
  154. >>> indexes = tensor([0, 0, 0, 0, 1, 1, 1])
  155. >>> preds = tensor([0.4, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5])
  156. >>> target = tensor([True, False, False, True, True, False, True])
  157. >>> r = _RetrievalPrecisionRecallCurve(max_k=4)
  158. >>> precisions, recalls, top_k = r(preds, target, indexes=indexes)
  159. >>> precisions
  160. tensor([1.0000, 0.5000, 0.6667, 0.5000])
  161. >>> recalls
  162. tensor([0.5000, 0.5000, 1.0000, 1.0000])
  163. >>> top_k
  164. tensor([1, 2, 3, 4])
  165. """
  166. def __init__(
  167. self,
  168. max_k: Optional[int] = None,
  169. adaptive_k: bool = False,
  170. empty_target_action: str = "neg",
  171. ignore_index: Optional[int] = None,
  172. **kwargs: Any,
  173. ) -> None:
  174. _deprecated_root_import_class("", "retrieval")
  175. super().__init__(
  176. max_k=max_k,
  177. adaptive_k=adaptive_k,
  178. empty_target_action=empty_target_action,
  179. ignore_index=ignore_index,
  180. **kwargs,
  181. )
  182. class _RetrievalRecallAtFixedPrecision(RetrievalRecallAtFixedPrecision):
  183. """Wrapper for deprecated import.
  184. >>> from torch import tensor
  185. >>> indexes = tensor([0, 0, 0, 0, 1, 1, 1])
  186. >>> preds = tensor([0.4, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5])
  187. >>> target = tensor([True, False, False, True, True, False, True])
  188. >>> r = _RetrievalRecallAtFixedPrecision(min_precision=0.8)
  189. >>> r(preds, target, indexes=indexes)
  190. (tensor(0.5000), tensor(1))
  191. """
  192. def __init__(
  193. self,
  194. min_precision: float = 0.0,
  195. max_k: Optional[int] = None,
  196. adaptive_k: bool = False,
  197. empty_target_action: str = "neg",
  198. ignore_index: Optional[int] = None,
  199. **kwargs: Any,
  200. ) -> None:
  201. _deprecated_root_import_class("RetrievalRecallAtFixedPrecision", "retrieval")
  202. super().__init__(
  203. min_precision=min_precision,
  204. max_k=max_k,
  205. adaptive_k=adaptive_k,
  206. empty_target_action=empty_target_action,
  207. ignore_index=ignore_index,
  208. **kwargs,
  209. )
  210. class _RetrievalMRR(RetrievalMRR):
  211. """Wrapper for deprecated import.
  212. >>> from torch import tensor
  213. >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
  214. >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
  215. >>> target = tensor([False, False, True, False, True, False, True])
  216. >>> mrr = _RetrievalMRR()
  217. >>> mrr(preds, target, indexes=indexes)
  218. tensor(0.7500)
  219. """
  220. def __init__(
  221. self,
  222. empty_target_action: str = "neg",
  223. ignore_index: Optional[int] = None,
  224. **kwargs: Any,
  225. ) -> None:
  226. _deprecated_root_import_class("", "retrieval")
  227. super().__init__(empty_target_action=empty_target_action, ignore_index=ignore_index, **kwargs)