precision.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import torch
  16. from torch import Tensor, tensor
  17. from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
  18. def retrieval_precision(preds: Tensor, target: Tensor, top_k: Optional[int] = None, adaptive_k: bool = False) -> Tensor:
  19. """Compute the precision metric for information retrieval.
  20. Precision is the fraction of relevant documents among all the retrieved documents.
  21. ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
  22. ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be ``float``,
  23. otherwise an error is raised. If you want to measure Precision@K, ``top_k`` must be a positive integer.
  24. Args:
  25. preds: estimated probabilities of each document to be relevant.
  26. target: ground truth about each document being relevant or not.
  27. top_k: consider only the top k elements (default: ``None``, which considers them all)
  28. adaptive_k: adjust `k` to `min(k, number of documents)` for each query
  29. Returns:
  30. A single-value tensor with the precision (at ``top_k``) of the predictions ``preds`` w.r.t. the labels
  31. ``target``.
  32. Raises:
  33. ValueError:
  34. If ``top_k`` is not `None` or an integer larger than 0.
  35. ValueError:
  36. If ``adaptive_k`` is not boolean.
  37. Example:
  38. >>> preds = tensor([0.2, 0.3, 0.5])
  39. >>> target = tensor([True, False, True])
  40. >>> retrieval_precision(preds, target, top_k=2)
  41. tensor(0.5000)
  42. """
  43. preds, target = _check_retrieval_functional_inputs(preds, target)
  44. if not isinstance(adaptive_k, bool):
  45. raise ValueError("`adaptive_k` has to be a boolean")
  46. if top_k is None or (adaptive_k and top_k > preds.shape[-1]):
  47. top_k = preds.shape[-1]
  48. if not (isinstance(top_k, int) and top_k > 0):
  49. raise ValueError("`top_k` has to be a positive integer or None")
  50. if not target.sum():
  51. return tensor(0.0, device=preds.device)
  52. target_filtered = torch.where(preds > 0, target, torch.zeros_like(target))
  53. relevant = target_filtered[preds.topk(min(top_k, preds.shape[-1]), dim=-1)[1]].sum().float()
  54. return relevant / top_k