linear.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. from torch import Tensor
  16. from typing_extensions import Literal
  17. from torchmetrics.functional.pairwise.helpers import _check_input, _reduce_distance_matrix
  18. from torchmetrics.utilities.compute import _safe_matmul
  19. def _pairwise_linear_similarity_update(
  20. x: Tensor, y: Optional[Tensor] = None, zero_diagonal: Optional[bool] = None
  21. ) -> Tensor:
  22. """Calculate the pairwise linear similarity matrix.
  23. Args:
  24. x: tensor of shape ``[N,d]``
  25. y: tensor of shape ``[M,d]``
  26. zero_diagonal: determines if the diagonal of the distance matrix should be set to zero
  27. """
  28. x, y, zero_diagonal = _check_input(x, y, zero_diagonal)
  29. distance = _safe_matmul(x, y)
  30. if zero_diagonal:
  31. distance.fill_diagonal_(0)
  32. return distance
  33. def pairwise_linear_similarity(
  34. x: Tensor,
  35. y: Optional[Tensor] = None,
  36. reduction: Literal["mean", "sum", "none", None] = None,
  37. zero_diagonal: Optional[bool] = None,
  38. ) -> Tensor:
  39. r"""Calculate pairwise linear similarity.
  40. .. math::
  41. s_{lin}(x,y) = <x,y> = \sum_{d=1}^D x_d \cdot y_d
  42. If both :math:`x` and :math:`y` are passed in, the calculation will be performed pairwise between
  43. the rows of :math:`x` and :math:`y`.
  44. If only :math:`x` is passed in, the calculation will be performed between the rows of :math:`x`.
  45. Args:
  46. x: Tensor with shape ``[N, d]``
  47. y: Tensor with shape ``[M, d]``, optional
  48. reduction: reduction to apply along the last dimension. Choose between `'mean'`, `'sum'`
  49. (applied along column dimension) or `'none'`, `None` for no reduction
  50. zero_diagonal: if the diagonal of the distance matrix should be set to 0. If only `x` is given
  51. this defaults to `True` else if `y` is also given it defaults to `False`
  52. Returns:
  53. A ``[N,N]`` matrix of distances if only ``x`` is given, else a ``[N,M]`` matrix
  54. Example:
  55. >>> import torch
  56. >>> from torchmetrics.functional.pairwise import pairwise_linear_similarity
  57. >>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
  58. >>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
  59. >>> pairwise_linear_similarity(x, y)
  60. tensor([[ 2., 7.],
  61. [ 3., 11.],
  62. [ 5., 18.]])
  63. >>> pairwise_linear_similarity(x)
  64. tensor([[ 0., 21., 34.],
  65. [21., 0., 55.],
  66. [34., 55., 0.]])
  67. """
  68. distance = _pairwise_linear_similarity_update(x, y, zero_diagonal)
  69. return _reduce_distance_matrix(distance, reduction)