minkowski.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright The PyTorch Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.pairwise.helpers import _check_input, _reduce_distance_matrix
  19. from torchmetrics.utilities.exceptions import TorchMetricsUserError
  def _pairwise_minkowski_distance_update(
      x: Tensor, y: Optional[Tensor] = None, exponent: float = 2, zero_diagonal: Optional[bool] = None
  ) -> Tensor:
      """Calculate the pairwise minkowski distance matrix.

      Args:
          x: tensor of shape ``[N,d]``
          y: tensor of shape ``[M,d]``
          exponent: int or float greater than or equal to 1, the order ``p`` of the norm applied to the difference
              between rows of ``x`` and ``y``. ``float("inf")`` gives the Chebyshev (maximum absolute difference)
              distance, i.e. the limit of the p-norm as ``p`` goes to infinity.
          zero_diagonal: determines if the diagonal of the distance matrix should be set to zero

      Returns:
          A ``[N,M]`` distance matrix cast back to the dtype of ``x``.

      Raises:
          TorchMetricsUserError: If ``exponent`` is not an int or float, or is smaller than 1.
      """
      x, y, zero_diagonal = _check_input(x, y, zero_diagonal)
      # NaN fails the ``>= 1`` comparison, so it is rejected here as well
      if not (isinstance(exponent, (float, int)) and exponent >= 1):
          raise TorchMetricsUserError(
              f"Argument ``exponent`` must be a float or int greater than or equal to 1, but got {exponent}"
          )
      # upcast to float64 to prevent precision issues
      _orig_dtype = x.dtype
      x = x.to(torch.float64)
      y = y.to(torch.float64)
      # broadcast [N,1,d] - [1,M,d] -> [N,M,d] elementwise absolute differences
      abs_diff = (x.unsqueeze(1) - y.unsqueeze(0)).abs()
      if exponent == float("inf"):
          # The generic formula degenerates for p = inf (|d|**inf is 0/1/inf and the final power is 1/inf = 0,
          # giving all ones), so use the limit of the p-norm directly: the largest absolute difference.
          distance = abs_diff.amax(-1)
      else:
          distance = abs_diff.pow(exponent).sum(-1).pow(1.0 / exponent)
      if zero_diagonal:
          distance.fill_diagonal_(0)
      return distance.to(_orig_dtype)
  def pairwise_minkowski_distance(
      x: Tensor,
      y: Optional[Tensor] = None,
      exponent: float = 2,
      reduction: Literal["mean", "sum", "none", None] = None,
      zero_diagonal: Optional[bool] = None,
  ) -> Tensor:
      r"""Calculate pairwise minkowski distances.

      .. math::
          d_{minkowski}(x,y,p) = ||x - y||_p = \sqrt[p]{\sum_{d=1}^D |x_d - y_d|^p}

      If both :math:`x` and :math:`y` are passed in, the calculation will be performed pairwise between the rows of
      :math:`x` and :math:`y`. If only :math:`x` is passed in, the calculation will be performed between the rows
      of :math:`x`.

      Args:
          x: Tensor with shape ``[N, d]``
          y: Tensor with shape ``[M, d]``, optional
          exponent: int or float greater than or equal to 1, the order :math:`p` of the norm applied to the
              difference between the rows of ``x`` and ``y``
          reduction: reduction to apply along the last dimension. Choose between `'mean'`, `'sum'`
              (applied along column dimension) or `'none'`, `None` for no reduction
          zero_diagonal: if the diagonal of the distance matrix should be set to 0. If only `x` is given
              this defaults to `True` else if `y` is also given it defaults to `False`

      Returns:
          A ``[N,N]`` matrix of distances if only ``x`` is given, else a ``[N,M]`` matrix. When ``reduction`` is
          `'mean'` or `'sum'`, that matrix is reduced along its last dimension.

      Raises:
          TorchMetricsUserError: If ``exponent`` is not an int or float, or is smaller than 1.

      Example:
          >>> import torch
          >>> from torchmetrics.functional.pairwise import pairwise_minkowski_distance
          >>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
          >>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
          >>> pairwise_minkowski_distance(x, y, exponent=4)
          tensor([[3.0092, 2.0000],
                  [5.0317, 4.0039],
                  [8.1222, 7.0583]])
          >>> pairwise_minkowski_distance(x, exponent=4)
          tensor([[0.0000, 2.0305, 5.1547],
                  [2.0305, 0.0000, 3.1383],
                  [5.1547, 3.1383, 0.0000]])

      """
      distance = _pairwise_minkowski_distance_update(x, y, exponent, zero_diagonal)
      return _reduce_distance_matrix(distance, reduction)