| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- # Copyright The PyTorch Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import torch
- from torch import Tensor
- from torchmetrics.utilities.checks import _check_same_shape
- from torchmetrics.utilities.exceptions import TorchMetricsUserError
def _minkowski_distance_update(preds: Tensor, targets: Tensor, p: float) -> Tensor:
    """Update and return variables required to compute Minkowski distance.

    Checks that ``preds`` and ``targets`` have the same shape and that ``p`` is a valid exponent, then returns the
    sum of the element-wise absolute differences raised to the power ``p``.

    Args:
        preds: Predicted tensor
        targets: Ground truth tensor
        p: Exponent applied to the absolute errors; must be an ``int`` or ``float`` greater than or equal to 1

    Returns:
        Tensor holding ``sum(|preds - targets| ** p)`` over all elements.

    Raises:
        TorchMetricsUserError:
            If ``p`` is not an ``int``/``float`` or is smaller than 1.

    """
    # Shape mismatch is reported by the shared checker before ``p`` is validated.
    _check_same_shape(preds, targets)
    # ``p == 1`` is accepted (Manhattan distance), so the message must say "greater than or equal to 1".
    if not (isinstance(p, (float, int)) and p >= 1):
        raise TorchMetricsUserError(f"Argument ``p`` must be a float or int greater than or equal to 1, but got {p}")
    difference = torch.abs(preds - targets)
    return torch.sum(torch.pow(difference, p))
def _minkowski_distance_compute(distance: Tensor, p: float) -> Tensor:
    """Compute Minkowski Distance.

    Turns the accumulated sum of ``p``-th powers of errors into the final distance by taking its ``p``-th root,
    i.e. ``distance ** (1 / p)``.

    Args:
        distance: Sum of the p-th powers of errors over all observations
        p: The exponent that was applied to the errors when ``distance`` was accumulated

    Returns:
        The ``p``-th root of ``distance``.

    Example:
        >>> preds = torch.tensor([0., 1, 2, 3])
        >>> target = torch.tensor([0., 2, 3, 1])
        >>> distance_p_sum = _minkowski_distance_update(preds, target, 5)
        >>> _minkowski_distance_compute(distance_p_sum, 5)
        tensor(2.0244)

    """
    root_exponent = 1.0 / p
    return torch.pow(distance, root_exponent)
def minkowski_distance(preds: Tensor, targets: Tensor, p: float) -> Tensor:
    r"""Compute the `Minkowski distance`_.

    .. math:: d_{\text{Minkowski}} = \left(\sum_{i}^N | y_i - \hat{y_i} |^p\right)^\frac{1}{p}

    This metric can be seen as a generalized version of the standard Euclidean distance, which corresponds to the
    Minkowski distance with p=2.

    Args:
        preds: estimated labels of type Tensor
        targets: ground truth labels of type Tensor
        p: int or float greater than or equal to 1, exponent to which the difference between preds and target is
            to be raised

    Return:
        Tensor with the Minkowski distance

    Raises:
        TorchMetricsUserError:
            If ``p`` is not an ``int``/``float`` or is smaller than 1 (raised by the update step).

    Example:
        >>> from torchmetrics.functional.regression import minkowski_distance
        >>> x = torch.tensor([1.0, 2.8, 3.5, 4.5])
        >>> y = torch.tensor([6.1, 2.11, 3.1, 5.6])
        >>> minkowski_distance(x, y, p=3)
        tensor(5.1220)

    """
    # Sum of |preds - targets| ** p; also validates shapes and ``p``.
    minkowski_dist_sum = _minkowski_distance_update(preds, targets, p)
    # The p-th root is applied once to the whole sum, matching the formula above.
    return _minkowski_distance_compute(minkowski_dist_sum, p)
|