minkowski.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright The PyTorch Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from torch import Tensor
  16. from torchmetrics.utilities.checks import _check_same_shape
  17. from torchmetrics.utilities.exceptions import TorchMetricsUserError
  18. def _minkowski_distance_update(preds: Tensor, targets: Tensor, p: float) -> Tensor:
  19. """Update and return variables required to compute Minkowski distance.
  20. Checks for same shape of input tensors.
  21. Args:
  22. preds: Predicted tensor
  23. targets: Ground truth tensor
  24. p: Non-negative number acting as the p to the errors
  25. """
  26. _check_same_shape(preds, targets)
  27. if not (isinstance(p, (float, int)) and p >= 1):
  28. raise TorchMetricsUserError(f"Argument ``p`` must be a float or int greater than 1, but got {p}")
  29. difference = torch.abs(preds - targets)
  30. return torch.sum(torch.pow(difference, p))
  31. def _minkowski_distance_compute(distance: Tensor, p: float) -> Tensor:
  32. """Compute Minkowski Distance.
  33. Args:
  34. distance: Sum of the p-th powers of errors over all observations
  35. p: The non-negative numeric power the errors are to be raised to
  36. Example:
  37. >>> preds = torch.tensor([0., 1, 2, 3])
  38. >>> target = torch.tensor([0., 2, 3, 1])
  39. >>> distance_p_sum = _minkowski_distance_update(preds, target, 5)
  40. >>> _minkowski_distance_compute(distance_p_sum, 5)
  41. tensor(2.0244)
  42. """
  43. return torch.pow(distance, 1.0 / p)
  44. def minkowski_distance(preds: Tensor, targets: Tensor, p: float) -> Tensor:
  45. r"""Compute the `Minkowski distance`_.
  46. .. math:: d_{\text{Minkowski}} = \\sum_{i}^N (| y_i - \\hat{y_i} |^p)^\frac{1}{p}
  47. This metric can be seen as generalized version of the standard euclidean distance which corresponds to minkowski
  48. distance with p=2.
  49. Args:
  50. preds: estimated labels of type Tensor
  51. targets: ground truth labels of type Tensor
  52. p: int or float larger than 1, exponent to which the difference between preds and target is to be raised
  53. Return:
  54. Tensor with the Minkowski distance
  55. Example:
  56. >>> from torchmetrics.functional.regression import minkowski_distance
  57. >>> x = torch.tensor([1.0, 2.8, 3.5, 4.5])
  58. >>> y = torch.tensor([6.1, 2.11, 3.1, 5.6])
  59. >>> minkowski_distance(x, y, p=3)
  60. tensor(5.1220)
  61. """
  62. minkowski_dist_sum = _minkowski_distance_update(preds, targets, p)
  63. return _minkowski_distance_compute(minkowski_dist_sum, p)