# gradients.py
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import torch
  15. from torch import Tensor
  16. def _image_gradients_validate(img: Tensor) -> None:
  17. """Validate whether img is a 4D torch Tensor."""
  18. if not isinstance(img, Tensor):
  19. raise TypeError(f"The `img` expects a value of <Tensor> type but got {type(img)}")
  20. if img.ndim != 4:
  21. raise RuntimeError(f"The `img` expects a 4D tensor but got {img.ndim}D tensor")
  22. def _compute_image_gradients(img: Tensor) -> tuple[Tensor, Tensor]:
  23. """Compute image gradients (dy/dx) for a given image."""
  24. batch_size, channels, height, width = img.shape
  25. dy = img[..., 1:, :] - img[..., :-1, :]
  26. dx = img[..., :, 1:] - img[..., :, :-1]
  27. shapey = [batch_size, channels, 1, width]
  28. dy = torch.cat([dy, torch.zeros(shapey, device=img.device, dtype=img.dtype)], dim=2)
  29. dy = dy.view(img.shape)
  30. shapex = [batch_size, channels, height, 1]
  31. dx = torch.cat([dx, torch.zeros(shapex, device=img.device, dtype=img.dtype)], dim=3)
  32. dx = dx.view(img.shape)
  33. return dy, dx
  34. def image_gradients(img: Tensor) -> tuple[Tensor, Tensor]:
  35. """Compute `Gradient Computation of Image`_ of a given image using finite difference.
  36. Args:
  37. img: An ``(N, C, H, W)`` input tensor where ``C`` is the number of image channels
  38. Return:
  39. Tuple of ``(dy, dx)`` with each gradient of shape ``[N, C, H, W]``
  40. Raises:
  41. TypeError:
  42. If ``img`` is not of the type :class:`~torch.Tensor`.
  43. RuntimeError:
  44. If ``img`` is not a 4D tensor.
  45. Example:
  46. >>> from torchmetrics.functional.image import image_gradients
  47. >>> image = torch.arange(0, 1*1*5*5, dtype=torch.float32)
  48. >>> image = torch.reshape(image, (1, 1, 5, 5))
  49. >>> dy, dx = image_gradients(image)
  50. >>> dy[0, 0, :, :]
  51. tensor([[5., 5., 5., 5., 5.],
  52. [5., 5., 5., 5., 5.],
  53. [5., 5., 5., 5., 5.],
  54. [5., 5., 5., 5., 5.],
  55. [0., 0., 0., 0., 0.]])
  56. .. note::
  57. The implementation follows the 1-step finite difference method as followed
  58. by the TF implementation. The values are organized such that the gradient of
  59. [I(x+1, y)-[I(x, y)]] are at the (x, y) location
  60. """
  61. _image_gradients_validate(img)
  62. return _compute_image_gradients(img)