cosine.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.pairwise.helpers import _check_input, _reduce_distance_matrix
  19. from torchmetrics.utilities.compute import _safe_matmul
  def _normalize_rows(t: Tensor) -> Tensor:
      """Scale every row of ``t`` to unit L2 norm, leaving rows whose norm is exactly zero unchanged."""
      norm = torch.norm(t, p=2, dim=1)
      # An all-zero row has no direction. Dividing it by 1 keeps it a zero vector (the same convention
      # as scikit-learn's ``normalize``) instead of computing 0/0 = NaN, which would turn every entry
      # of the corresponding row/column of the similarity matrix into NaN. Nonzero norms are untouched.
      norm = norm.masked_fill(norm == 0, 1)
      return t / norm.unsqueeze(1)


  def _pairwise_cosine_similarity_update(
      x: Tensor, y: Optional[Tensor] = None, zero_diagonal: Optional[bool] = None
  ) -> Tensor:
      """Calculate the pairwise cosine similarity matrix.

      Both inputs are first passed through ``_check_input``. Each row is then scaled to unit L2 norm,
      and the normalized rows are multiplied with ``_safe_matmul`` (assumed to compute ``x @ y.T`` --
      confirm against its definition). Rows with zero norm are kept as zero vectors, so their
      similarity to any other row is ``0`` rather than ``NaN``.

      Args:
          x: tensor of shape ``[N,d]``
          y: tensor of shape ``[M,d]``; optional -- when omitted the rows of ``x`` are compared with
              each other (the substitution is done by ``_check_input``)
          zero_diagonal: determines if the diagonal of the similarity matrix should be set to zero;
              ``None`` lets ``_check_input`` choose the default

      Returns:
          The pairwise similarity matrix produced by ``_safe_matmul``.

      """
      x, y, zero_diagonal = _check_input(x, y, zero_diagonal)
      x = _normalize_rows(x)
      y = _normalize_rows(y)
      similarity = _safe_matmul(x, y)
      if zero_diagonal:
          # In-place is fine here: ``similarity`` is a freshly computed tensor owned by this call.
          similarity.fill_diagonal_(0)
      return similarity
  def pairwise_cosine_similarity(
      x: Tensor,
      y: Optional[Tensor] = None,
      reduction: Literal["mean", "sum", "none", None] = None,
      zero_diagonal: Optional[bool] = None,
  ) -> Tensor:
      r"""Calculate pairwise cosine similarity.

      .. math::
          s_{cos}(x,y) = \frac{\langle x, y \rangle}{\|x\| \cdot \|y\|}
                       = \frac{\sum_{d=1}^D x_d \cdot y_d}{\sqrt{\sum_{d=1}^D x_d^2} \cdot \sqrt{\sum_{d=1}^D y_d^2}}

      If both :math:`x` and :math:`y` are passed in, the calculation will be performed pairwise
      between the rows of :math:`x` and :math:`y`.
      If only :math:`x` is passed in, the calculation will be performed between the rows of :math:`x`.

      Args:
          x: Tensor with shape ``[N, d]``
          y: Tensor with shape ``[M, d]``, optional
          reduction: reduction to apply along the last dimension. Choose between `'mean'`, `'sum'`
              (applied along column dimension) or `'none'`, `None` for no reduction
          zero_diagonal: if the diagonal of the similarity matrix should be set to 0. If only :math:`x` is given
              this defaults to ``True`` else if :math:`y` is also given it defaults to ``False``

      Returns:
          A ``[N,N]`` matrix of similarities if only ``x`` is given, else a ``[N,M]`` matrix
          (before any ``reduction`` is applied along the last dimension)

      Example:
          >>> import torch
          >>> from torchmetrics.functional.pairwise import pairwise_cosine_similarity
          >>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
          >>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
          >>> pairwise_cosine_similarity(x, y)
          tensor([[0.5547, 0.8682],
                  [0.5145, 0.8437],
                  [0.5300, 0.8533]])
          >>> pairwise_cosine_similarity(x)
          tensor([[0.0000, 0.9989, 0.9996],
                  [0.9989, 0.0000, 0.9998],
                  [0.9996, 0.9998, 0.0000]])

      """
      similarity = _pairwise_cosine_similarity_update(x, y, zero_diagonal)
      return _reduce_distance_matrix(similarity, reduction)