log_mse.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Union
  15. import torch
  16. from torch import Tensor
  17. from torchmetrics.utilities.checks import _check_same_shape
  18. def _mean_squared_log_error_update(preds: Tensor, target: Tensor) -> tuple[Tensor, int]:
  19. """Return variables required to compute Mean Squared Log Error. Checks for same shape of tensors.
  20. Args:
  21. preds: Predicted tensor
  22. target: Ground truth tensor
  23. """
  24. _check_same_shape(preds, target)
  25. sum_squared_log_error = torch.sum(torch.pow(torch.log1p(preds) - torch.log1p(target), 2))
  26. return sum_squared_log_error, target.numel()
  27. def _mean_squared_log_error_compute(sum_squared_log_error: Tensor, num_obs: Union[int, Tensor]) -> Tensor:
  28. """Compute Mean Squared Log Error.
  29. Args:
  30. sum_squared_log_error:
  31. Sum of square of log errors over all observations ``(log error = log(target) - log(prediction))``
  32. num_obs: Number of predictions or observations
  33. Example:
  34. >>> preds = torch.tensor([0., 1, 2, 3])
  35. >>> target = torch.tensor([0., 1, 2, 2])
  36. >>> sum_squared_log_error, num_obs = _mean_squared_log_error_update(preds, target)
  37. >>> _mean_squared_log_error_compute(sum_squared_log_error, num_obs)
  38. tensor(0.0207)
  39. """
  40. return sum_squared_log_error / num_obs
  41. def mean_squared_log_error(preds: Tensor, target: Tensor) -> Tensor:
  42. """Compute mean squared log error.
  43. Args:
  44. preds: estimated labels
  45. target: ground truth labels
  46. Return:
  47. Tensor with RMSLE
  48. Example:
  49. >>> from torchmetrics.functional.regression import mean_squared_log_error
  50. >>> x = torch.tensor([0., 1, 2, 3])
  51. >>> y = torch.tensor([0., 1, 2, 2])
  52. >>> mean_squared_log_error(x, y)
  53. tensor(0.0207)
  54. .. attention::
  55. Half precision is only support on GPU for this metric.
  56. """
  57. sum_squared_log_error, num_obs = _mean_squared_log_error_update(preds, target)
  58. return _mean_squared_log_error_compute(sum_squared_log_error, num_obs)