pyramid.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. import math
  19. import torch
  20. import torch.nn.functional as F
  21. from kornia.core import Module, Tensor, ones, pad, stack, tensor, zeros
  22. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
  23. from kornia.filters import filter2d, gaussian_blur2d
# Public names exported by this module (what ``from <module> import *`` exposes).
__all__ = [
    "PyrDown",
    "PyrUp",
    "ScalePyramid",
    "build_laplacian_pyramid",
    "build_pyramid",
    "pyrdown",
    "pyrup",
    "upscale_double",
]
  34. def _get_pyramid_gaussian_kernel() -> Tensor:
  35. """Return a pre-computed gaussian kernel."""
  36. return (
  37. tensor(
  38. [
  39. [
  40. [1.0, 4.0, 6.0, 4.0, 1.0],
  41. [4.0, 16.0, 24.0, 16.0, 4.0],
  42. [6.0, 24.0, 36.0, 24.0, 6.0],
  43. [4.0, 16.0, 24.0, 16.0, 4.0],
  44. [1.0, 4.0, 6.0, 4.0, 1.0],
  45. ]
  46. ]
  47. )
  48. / 256.0
  49. )
  50. class PyrDown(Module):
  51. r"""Blur a tensor and downsamples it.
  52. Args:
  53. border_type: the padding mode to be applied before convolving.
  54. The expected modes are: ``'constant'``, ``'reflect'``,
  55. ``'replicate'`` or ``'circular'``.
  56. align_corners: interpolation flag.
  57. factor: the downsampling factor
  58. Return:
  59. the downsampled tensor.
  60. Shape:
  61. - Input: :math:`(B, C, H, W)`
  62. - Output: :math:`(B, C, H / 2, W / 2)`
  63. Examples:
  64. >>> input = torch.rand(1, 2, 4, 4)
  65. >>> output = PyrDown()(input) # 1x2x2x2
  66. """
  67. def __init__(self, border_type: str = "reflect", align_corners: bool = False, factor: float = 2.0) -> None:
  68. super().__init__()
  69. self.border_type: str = border_type
  70. self.align_corners: bool = align_corners
  71. self.factor: float = factor
  72. def forward(self, input: Tensor) -> Tensor:
  73. return pyrdown(input, self.border_type, self.align_corners, self.factor)
  74. class PyrUp(Module):
  75. r"""Upsample a tensor and then blurs it.
  76. Args:
  77. borde_type: the padding mode to be applied before convolving.
  78. The expected modes are: ``'constant'``, ``'reflect'``,
  79. ``'replicate'`` or ``'circular'``.
  80. align_corners: interpolation flag.
  81. Return:
  82. the upsampled tensor.
  83. Shape:
  84. - Input: :math:`(B, C, H, W)`
  85. - Output: :math:`(B, C, H * 2, W * 2)`
  86. Examples:
  87. >>> input = torch.rand(1, 2, 4, 4)
  88. >>> output = PyrUp()(input) # 1x2x8x8
  89. """
  90. def __init__(self, border_type: str = "reflect", align_corners: bool = False) -> None:
  91. super().__init__()
  92. self.border_type: str = border_type
  93. self.align_corners: bool = align_corners
  94. def forward(self, input: Tensor) -> Tensor:
  95. return pyrup(input, self.border_type, self.align_corners)
  96. class ScalePyramid(Module):
  97. r"""Create an scale pyramid of image, usually used for local feature detection.
  98. Images are consequently smoothed with Gaussian blur and downscaled.
  99. Args:
  100. n_levels: number of the levels in octave.
  101. init_sigma: initial blur level.
  102. min_size: the minimum size of the octave in pixels.
  103. double_image: add 2x upscaled image as 1st level of pyramid. OpenCV SIFT does this.
  104. Returns:
  105. 1st output: images
  106. 2nd output: sigmas (coefficients for scale conversion)
  107. 3rd output: pixelDists (coefficients for coordinate conversion)
  108. Shape:
  109. - Input: :math:`(B, C, H, W)`
  110. - Output 1st: :math:`[(B, C, NL, H, W), (B, C, NL, H/2, W/2), ...]`
  111. - Output 2nd: :math:`[(B, NL), (B, NL), (B, NL), ...]`
  112. - Output 3rd: :math:`[(B, NL), (B, NL), (B, NL), ...]`
  113. Examples:
  114. >>> input = torch.rand(2, 4, 100, 100)
  115. >>> sp, sigmas, pds = ScalePyramid(3, 15)(input)
  116. """
  117. def __init__(
  118. self, n_levels: int = 3, init_sigma: float = 1.6, min_size: int = 15, double_image: bool = False
  119. ) -> None:
  120. super().__init__()
  121. # 3 extra levels are needed for DoG nms.
  122. self.n_levels = n_levels
  123. self.extra_levels: int = 3
  124. self.init_sigma = init_sigma
  125. self.min_size = min_size
  126. self.border = min_size // 2 - 1
  127. self.sigma_step = 2 ** (1.0 / float(self.n_levels))
  128. self.double_image = double_image
  129. def __repr__(self) -> str:
  130. return (
  131. f"{self.__class__.__name__}("
  132. f"n_levels={self.n_levels}, "
  133. f"init_sigma={self.init_sigma}, "
  134. f"min_size={self.min_size}, "
  135. f"extra_levels={self.extra_levels}, "
  136. f"border={self.border}, "
  137. f"sigma_step={self.sigma_step}, "
  138. f"double_image={self.double_image})"
  139. )
  140. def get_kernel_size(self, sigma: float) -> int:
  141. ksize = int(2.0 * 4.0 * sigma + 1.0)
  142. # matches OpenCV, but may cause padding problem for small images
  143. # PyTorch does not allow to pad more than original size.
  144. # Therefore there is a hack in forward function
  145. if ksize % 2 == 0:
  146. ksize += 1
  147. return ksize
  148. def get_first_level(self, input: Tensor) -> tuple[Tensor, float, float]:
  149. pixel_distance = 1.0
  150. cur_sigma = 0.5
  151. # Same as in OpenCV up to interpolation difference
  152. if self.double_image:
  153. x = upscale_double(input)
  154. pixel_distance = 0.5
  155. cur_sigma *= 2.0
  156. else:
  157. x = input
  158. if self.init_sigma > cur_sigma:
  159. sigma = max(math.sqrt(self.init_sigma**2 - cur_sigma**2), 0.01)
  160. ksize = self.get_kernel_size(sigma)
  161. cur_level = gaussian_blur2d(x, (ksize, ksize), (sigma, sigma))
  162. cur_sigma = self.init_sigma
  163. else:
  164. cur_level = x
  165. return cur_level, cur_sigma, pixel_distance
  166. def forward(self, x: Tensor) -> tuple[list[Tensor], list[Tensor], list[Tensor]]:
  167. bs, _, _, _ = x.size()
  168. cur_level, cur_sigma, pixel_distance = self.get_first_level(x)
  169. sigmas = [cur_sigma * ones(bs, self.n_levels + self.extra_levels).to(x.device).to(x.dtype)]
  170. pixel_dists = [pixel_distance * ones(bs, self.n_levels + self.extra_levels).to(x.device).to(x.dtype)]
  171. pyr = [[cur_level]]
  172. oct_idx = 0
  173. while True:
  174. cur_level = pyr[-1][0]
  175. for level_idx in range(1, self.n_levels + self.extra_levels):
  176. sigma = cur_sigma * math.sqrt(self.sigma_step**2 - 1.0)
  177. ksize = self.get_kernel_size(sigma)
  178. # Hack, because PyTorch does not allow to pad more than original size.
  179. # But for the huge sigmas, one needs huge kernel and padding...
  180. ksize = min(ksize, cur_level.size(2), cur_level.size(3))
  181. if ksize % 2 == 0:
  182. ksize += 1
  183. cur_level = gaussian_blur2d(cur_level, (ksize, ksize), (sigma, sigma))
  184. cur_sigma *= self.sigma_step
  185. pyr[-1].append(cur_level)
  186. sigmas[-1][:, level_idx] = cur_sigma
  187. pixel_dists[-1][:, level_idx] = pixel_distance
  188. _pyr = pyr[-1][-self.extra_levels]
  189. nextOctaveFirstLevel = _pyr[:, :, ::2, ::2]
  190. pixel_distance *= 2.0
  191. cur_sigma = self.init_sigma
  192. if min(nextOctaveFirstLevel.size(2), nextOctaveFirstLevel.size(3)) <= self.min_size:
  193. break
  194. pyr.append([nextOctaveFirstLevel])
  195. sigmas.append(cur_sigma * torch.ones(bs, self.n_levels + self.extra_levels).to(x.device))
  196. pixel_dists.append(pixel_distance * torch.ones(bs, self.n_levels + self.extra_levels).to(x.device))
  197. oct_idx += 1
  198. output_pyr = [stack(i, 2) for i in pyr]
  199. return output_pyr, sigmas, pixel_dists
  200. def pyrdown(input: Tensor, border_type: str = "reflect", align_corners: bool = False, factor: float = 2.0) -> Tensor:
  201. r"""Blur a tensor and downsamples it.
  202. .. image:: _static/img/pyrdown.png
  203. Args:
  204. input: the tensor to be downsampled.
  205. border_type: the padding mode to be applied before convolving.
  206. The expected modes are: ``'constant'``, ``'reflect'``,
  207. ``'replicate'`` or ``'circular'``.
  208. align_corners: interpolation flag.
  209. factor: the downsampling factor
  210. Return:
  211. the downsampled tensor.
  212. Examples:
  213. >>> input = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
  214. >>> pyrdown(input, align_corners=True)
  215. tensor([[[[ 3.7500, 5.2500],
  216. [ 9.7500, 11.2500]]]])
  217. """
  218. KORNIA_CHECK_SHAPE(input, ["B", "C", "H", "W"])
  219. kernel: Tensor = _get_pyramid_gaussian_kernel()
  220. _, _, height, width = input.shape
  221. # blur image
  222. x_blur: Tensor = filter2d(input, kernel, border_type)
  223. # TODO: use kornia.geometry.resize/rescale
  224. # downsample.
  225. out: Tensor = F.interpolate(
  226. x_blur,
  227. size=(int(float(height) / factor), int(float(width) // factor)),
  228. mode="bilinear",
  229. align_corners=align_corners,
  230. )
  231. return out
  232. def pyrup(input: Tensor, border_type: str = "reflect", align_corners: bool = False) -> Tensor:
  233. r"""Upsample a tensor and then blurs it.
  234. .. image:: _static/img/pyrup.png
  235. Args:
  236. input: the tensor to be downsampled.
  237. border_type: the padding mode to be applied before convolving.
  238. The expected modes are: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
  239. align_corners: interpolation flag.
  240. Return:
  241. the downsampled tensor.
  242. Examples:
  243. >>> input = torch.arange(4, dtype=torch.float32).reshape(1, 1, 2, 2)
  244. >>> pyrup(input, align_corners=True)
  245. tensor([[[[0.7500, 0.8750, 1.1250, 1.2500],
  246. [1.0000, 1.1250, 1.3750, 1.5000],
  247. [1.5000, 1.6250, 1.8750, 2.0000],
  248. [1.7500, 1.8750, 2.1250, 2.2500]]]])
  249. """
  250. KORNIA_CHECK_SHAPE(input, ["B", "C", "H", "W"])
  251. kernel: Tensor = _get_pyramid_gaussian_kernel()
  252. # upsample tensor
  253. _, _, height, width = input.shape
  254. # TODO: use kornia.geometry.resize/rescale
  255. x_up: Tensor = F.interpolate(input, size=(height * 2, width * 2), mode="bilinear", align_corners=align_corners)
  256. # blurs upsampled tensor
  257. x_blur: Tensor = filter2d(x_up, kernel, border_type)
  258. return x_blur
  259. def build_pyramid(
  260. input: Tensor, max_level: int, border_type: str = "reflect", align_corners: bool = False
  261. ) -> list[Tensor]:
  262. r"""Construct the Gaussian pyramid for a tensor image.
  263. .. image:: _static/img/build_pyramid.png
  264. The function constructs a vector of images and builds the Gaussian pyramid
  265. by recursively applying pyrDown to the previously built pyramid layers.
  266. Args:
  267. input : the tensor to be used to construct the pyramid.
  268. max_level: 0-based index of the last (the smallest) pyramid layer.
  269. It must be non-negative.
  270. border_type: the padding mode to be applied before convolving.
  271. The expected modes are: ``'constant'``, ``'reflect'``,
  272. ``'replicate'`` or ``'circular'``.
  273. align_corners: interpolation flag.
  274. Shape:
  275. - Input: :math:`(B, C, H, W)`
  276. - Output :math:`[(B, C, H, W), (B, C, H/2, W/2), ...]`
  277. """
  278. KORNIA_CHECK_SHAPE(input, ["B", "C", "H", "W"])
  279. KORNIA_CHECK(
  280. isinstance(max_level, int) or max_level < 0,
  281. f"Invalid max_level, it must be a positive integer. Got: {max_level}",
  282. )
  283. # create empty list and append the original image
  284. pyramid: list[Tensor] = []
  285. pyramid.append(input)
  286. # iterate and downsample
  287. for _ in range(max_level - 1):
  288. img_curr: Tensor = pyramid[-1]
  289. img_down: Tensor = pyrdown(img_curr, border_type, align_corners)
  290. pyramid.append(img_down)
  291. return pyramid
  292. def is_powerof_two(x: int) -> bool:
  293. # check if number x is a power of two
  294. return bool(x) and (not (x & (x - 1)))
  295. def find_next_powerof_two(x: int) -> int:
  296. return 1 << (x - 1).bit_length()
  297. def build_laplacian_pyramid(
  298. input: Tensor, max_level: int, border_type: str = "reflect", align_corners: bool = False
  299. ) -> list[Tensor]:
  300. r"""Construct the Laplacian pyramid for a tensor image.
  301. The function constructs a vector of images and builds the Laplacian pyramid
  302. by recursively computing the difference after applying
  303. pyrUp to the adjacent layer in its Gaussian pyramid.
  304. See :cite:`burt1987laplacian` for more details.
  305. Args:
  306. input : the tensor to be used to construct the pyramid with shape :math:`(B, C, H, W)`.
  307. max_level: 0-based index of the last (the smallest) pyramid layer.
  308. It must be non-negative.
  309. border_type: the padding mode to be applied before convolving.
  310. The expected modes are: ``'constant'``, ``'reflect'``,
  311. ``'replicate'`` or ``'circular'``.
  312. align_corners: interpolation flag.
  313. Return:
  314. Output: :math:`[(B, C, H, W), (B, C, H/2, W/2), ...]`
  315. """
  316. KORNIA_CHECK_SHAPE(input, ["B", "C", "H", "W"])
  317. KORNIA_CHECK(
  318. isinstance(max_level, int) or max_level < 0,
  319. f"Invalid max_level, it must be a positive integer. Got: {max_level}",
  320. )
  321. h = input.size()[2]
  322. w = input.size()[3]
  323. require_padding = not (is_powerof_two(w) or is_powerof_two(h))
  324. if require_padding:
  325. # in case of arbitrary shape tensor image need to be padded.
  326. # Reference: https://stackoverflow.com/a/29967555
  327. padding = (0, find_next_powerof_two(w) - w, 0, find_next_powerof_two(h) - h)
  328. input = pad(input, padding, "reflect")
  329. # create gaussian pyramid
  330. gaussian_pyramid: list[Tensor] = build_pyramid(input, max_level, border_type, align_corners)
  331. # create empty list
  332. laplacian_pyramid: list[Tensor] = []
  333. # iterate and compute difference of adjacent layers in a gaussian pyramid
  334. for i in range(max_level - 1):
  335. img_expand: Tensor = pyrup(gaussian_pyramid[i + 1], border_type, align_corners)
  336. laplacian: Tensor = gaussian_pyramid[i] - img_expand
  337. laplacian_pyramid.append(laplacian)
  338. laplacian_pyramid.append(gaussian_pyramid[-1])
  339. return laplacian_pyramid
  340. def upscale_double(x: Tensor) -> Tensor:
  341. r"""Upscale image by the factor of 2, even indices maps to original indices.
  342. Odd indices are linearly interpolated from the even ones.
  343. Args:
  344. x: input image.
  345. Shape:
  346. - Input: :math:`(*, H, W)`
  347. - Output :math:`(*, H, W)`
  348. """
  349. KORNIA_CHECK_IS_TENSOR(x)
  350. KORNIA_CHECK_SHAPE(x, ["*", "H", "W"])
  351. double_shape = x.shape[:-2] + (x.shape[-2] * 2, x.shape[-1] * 2)
  352. upscaled = zeros(double_shape, device=x.device, dtype=x.dtype)
  353. upscaled[..., ::2, ::2] = x
  354. upscaled[..., ::2, 1::2][..., :-1] = (upscaled[..., ::2, ::2][..., :-1] + upscaled[..., ::2, 2::2]) / 2
  355. upscaled[..., ::2, -1] = upscaled[..., ::2, -2]
  356. upscaled[..., 1::2, :][..., :-1, :] = (upscaled[..., ::2, :][..., :-1, :] + upscaled[..., 2::2, :]) / 2
  357. upscaled[..., -1, :] = upscaled[..., -2, :]
  358. return upscaled