rmse_sw.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # Copyright The PyTorch Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional, Union
  15. import torch
  16. from torch import Tensor
  17. from torchmetrics.functional.image.utils import _uniform_filter
  18. from torchmetrics.utilities.checks import _check_same_shape
  19. def _rmse_sw_update(
  20. preds: Tensor,
  21. target: Tensor,
  22. window_size: int,
  23. rmse_val_sum: Optional[Tensor],
  24. rmse_map: Optional[Tensor],
  25. total_images: Optional[Tensor],
  26. ) -> tuple[Tensor, Tensor, Tensor]:
  27. """Calculate the sum of RMSE values and RMSE map for the batch of examples and update intermediate states.
  28. Args:
  29. preds: Deformed image
  30. target: Ground truth image
  31. window_size: Sliding window used for rmse calculation
  32. rmse_val_sum: Sum of RMSE over all examples per individual channels
  33. rmse_map: Sum of RMSE map values over all examples
  34. total_images: Total number of images
  35. Return:
  36. (Optionally) Intermediate state of RMSE (using sliding window) over the accumulated examples.
  37. (Optionally) Intermediate state of RMSE map
  38. Updated total number of already processed images
  39. Raises:
  40. ValueError: If ``preds`` and ``target`` do not have the same data type.
  41. ValueError: If ``preds`` and ``target`` do not have ``BxCxWxH`` shape.
  42. ValueError: If ``round(window_size / 2)`` is greater or equal to width or height of the image.
  43. """
  44. if preds.dtype != target.dtype:
  45. raise TypeError(
  46. f"Expected `preds` and `target` to have the same data type. But got {preds.dtype} and {target.dtype}."
  47. )
  48. _check_same_shape(preds, target)
  49. if len(preds.shape) != 4:
  50. raise ValueError(f"Expected `preds` and `target` to have BxCxHxW shape. But got {preds.shape}.")
  51. if round(window_size / 2) >= target.shape[2] or round(window_size / 2) >= target.shape[3]:
  52. raise ValueError(
  53. f"Parameter `round(window_size / 2)` is expected to be smaller than {min(target.shape[2], target.shape[3])}"
  54. f" but got {round(window_size / 2)}."
  55. )
  56. if total_images is not None:
  57. total_images += target.shape[0]
  58. else:
  59. total_images = torch.tensor(target.shape[0], device=target.device)
  60. error = (target - preds) ** 2
  61. error = _uniform_filter(error, window_size)
  62. _rmse_map = torch.sqrt(error)
  63. crop_slide = round(window_size / 2)
  64. if rmse_val_sum is not None:
  65. rmse_val = _rmse_map[:, :, crop_slide:-crop_slide, crop_slide:-crop_slide]
  66. rmse_val_sum += rmse_val.sum(0).mean()
  67. else:
  68. rmse_val_sum = _rmse_map[:, :, crop_slide:-crop_slide, crop_slide:-crop_slide].sum(0).mean()
  69. if rmse_map is not None:
  70. rmse_map += _rmse_map.sum(0)
  71. else:
  72. rmse_map = _rmse_map.sum(0)
  73. return rmse_val_sum, rmse_map, total_images
  74. def _rmse_sw_compute(
  75. rmse_val_sum: Optional[Tensor], rmse_map: Tensor, total_images: Tensor
  76. ) -> tuple[Optional[Tensor], Tensor]:
  77. """Compute RMSE from the aggregated RMSE value. Optionally also computes the mean value for RMSE map.
  78. Args:
  79. rmse_val_sum: Sum of RMSE over all examples
  80. rmse_map: Sum of RMSE map values over all examples
  81. total_images: Total number of images
  82. Return:
  83. RMSE using sliding window
  84. (Optionally) RMSE map
  85. """
  86. rmse = rmse_val_sum / total_images if rmse_val_sum is not None else None
  87. if rmse_map is not None:
  88. # prevent overwrite the inputs
  89. rmse_map = rmse_map / total_images
  90. return rmse, rmse_map
  91. def root_mean_squared_error_using_sliding_window(
  92. preds: Tensor, target: Tensor, window_size: int = 8, return_rmse_map: bool = False
  93. ) -> Union[Optional[Tensor], tuple[Optional[Tensor], Tensor]]:
  94. """Compute Root Mean Squared Error (RMSE) using sliding window.
  95. Args:
  96. preds: Deformed image
  97. target: Ground truth image
  98. window_size: Sliding window used for rmse calculation
  99. return_rmse_map: An indication whether the full rmse reduced image should be returned.
  100. Return:
  101. RMSE using sliding window
  102. (Optionally) RMSE map
  103. Example:
  104. >>> from torch import rand
  105. >>> from torchmetrics.functional.image import root_mean_squared_error_using_sliding_window
  106. >>> preds = rand(4, 3, 16, 16)
  107. >>> target = rand(4, 3, 16, 16)
  108. >>> root_mean_squared_error_using_sliding_window(preds, target)
  109. tensor(0.4158)
  110. Raises:
  111. ValueError: If ``window_size`` is not a positive integer.
  112. """
  113. if not isinstance(window_size, int) or (isinstance(window_size, int) and window_size < 1):
  114. raise ValueError("Argument `window_size` is expected to be a positive integer.")
  115. rmse_val_sum, rmse_map, total_images = _rmse_sw_update(
  116. preds, target, window_size, rmse_val_sum=None, rmse_map=None, total_images=None
  117. )
  118. rmse, rmse_map = _rmse_sw_compute(rmse_val_sum, rmse_map, total_images)
  119. if return_rmse_map:
  120. return rmse, rmse_map
  121. return rmse