utils.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. from collections.abc import Sequence
  2. from typing import Union
  3. import torch
  4. from torch import Tensor
  5. from torch.nn import functional as F # noqa: N812
  6. def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: Union[torch.device, str]) -> Tensor:
  7. """Compute 1D gaussian kernel.
  8. Args:
  9. kernel_size: size of the gaussian kernel
  10. sigma: Standard deviation of the gaussian kernel
  11. dtype: data type of the output tensor
  12. device: device of the output tensor
  13. Example:
  14. >>> _gaussian(3, 1, torch.float, 'cpu')
  15. tensor([[0.2741, 0.4519, 0.2741]])
  16. """
  17. dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device)
  18. gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2)
  19. return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)
  20. def _gaussian_kernel_2d(
  21. channel: int,
  22. kernel_size: Sequence[int],
  23. sigma: Sequence[float],
  24. dtype: torch.dtype,
  25. device: Union[torch.device, str],
  26. ) -> Tensor:
  27. """Compute 2D gaussian kernel.
  28. Args:
  29. channel: number of channels in the image
  30. kernel_size: size of the gaussian kernel as a tuple (h, w)
  31. sigma: Standard deviation of the gaussian kernel
  32. dtype: data type of the output tensor
  33. device: device of the output tensor
  34. Example:
  35. >>> _gaussian_kernel_2d(1, (5,5), (1,1), torch.float, "cpu")
  36. tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030],
  37. [0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
  38. [0.0219, 0.0983, 0.1621, 0.0983, 0.0219],
  39. [0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
  40. [0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]])
  41. """
  42. gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
  43. gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
  44. kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)
  45. return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])
  46. def _uniform_weight_bias_conv2d(inputs: Tensor, window_size: int) -> tuple[Tensor, Tensor]:
  47. """Construct uniform weight and bias for a 2d convolution.
  48. Args:
  49. inputs: Input image
  50. window_size: size of convolutional kernel
  51. Return:
  52. The weight and bias for 2d convolution
  53. """
  54. kernel_weight = torch.ones(1, 1, window_size, window_size, dtype=inputs.dtype, device=inputs.device)
  55. kernel_weight /= window_size**2
  56. kernel_bias = torch.zeros(1, dtype=inputs.dtype, device=inputs.device)
  57. return kernel_weight, kernel_bias
  58. def _single_dimension_pad(inputs: Tensor, dim: int, pad: int, outer_pad: int = 0) -> Tensor:
  59. """Apply single-dimension reflection padding to match scipy implementation.
  60. Args:
  61. inputs: Input image
  62. dim: A dimension the image should be padded over
  63. pad: Number of pads
  64. outer_pad: Number of outer pads
  65. Return:
  66. Image padded over a single dimension
  67. """
  68. _max = inputs.shape[dim]
  69. x = torch.index_select(inputs, dim, torch.arange(pad - 1, -1, -1).to(inputs.device))
  70. y = torch.index_select(inputs, dim, torch.arange(_max - 1, _max - pad - outer_pad, -1).to(inputs.device))
  71. return torch.cat((x, inputs, y), dim)
  72. def _reflection_pad_2d(inputs: Tensor, pad: int, outer_pad: int = 0) -> Tensor:
  73. """Apply reflection padding to the input image.
  74. Args:
  75. inputs: Input image
  76. pad: Number of pads
  77. outer_pad: Number of outer pads
  78. Return:
  79. Padded image
  80. """
  81. for dim in [2, 3]:
  82. inputs = _single_dimension_pad(inputs, dim, pad, outer_pad)
  83. return inputs
  84. def _uniform_filter(inputs: Tensor, window_size: int) -> Tensor:
  85. """Apply uniform filter with a window of a given size over the input image.
  86. Args:
  87. inputs: Input image
  88. window_size: Sliding window used for rmse calculation
  89. Return:
  90. Image transformed with the uniform input
  91. """
  92. inputs = _reflection_pad_2d(inputs, window_size // 2, window_size % 2)
  93. kernel_weight, kernel_bias = _uniform_weight_bias_conv2d(inputs, window_size)
  94. # Iterate over channels
  95. return torch.cat(
  96. [
  97. F.conv2d(inputs[:, channel].unsqueeze(1), kernel_weight, kernel_bias, padding=0)
  98. for channel in range(inputs.shape[1])
  99. ],
  100. dim=1,
  101. )
  102. def _gaussian_kernel_3d(
  103. channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
  104. ) -> Tensor:
  105. """Compute 3D gaussian kernel.
  106. Args:
  107. channel: number of channels in the image
  108. kernel_size: size of the gaussian kernel as a tuple (h, w, d)
  109. sigma: Standard deviation of the gaussian kernel
  110. dtype: data type of the output tensor
  111. device: device of the output tensor
  112. """
  113. gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
  114. gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
  115. gaussian_kernel_z = _gaussian(kernel_size[2], sigma[2], dtype, device)
  116. kernel_xy = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)
  117. kernel = torch.mul(
  118. kernel_xy.unsqueeze(-1).repeat(1, 1, kernel_size[2]),
  119. gaussian_kernel_z.expand(kernel_size[0], kernel_size[1], kernel_size[2]),
  120. )
  121. return kernel.expand(channel, 1, kernel_size[0], kernel_size[1], kernel_size[2])
  122. def _reflection_pad_3d(inputs: Tensor, pad_h: int, pad_w: int, pad_d: int) -> Tensor:
  123. """Reflective padding of 3d input.
  124. Args:
  125. inputs: tensor to pad, should be a 3D tensor of shape ``[N, C, H, W, D]``
  126. pad_w: amount of padding in the height dimension
  127. pad_h: amount of padding in the width dimension
  128. pad_d: amount of padding in the depth dimension
  129. Returns:
  130. padded input tensor
  131. """
  132. return F.pad(inputs, (pad_h, pad_h, pad_w, pad_w, pad_d, pad_d), mode="reflect")