extract_patches.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from math import ceil
  18. from typing import Optional, Tuple, Union, cast
  19. from warnings import warn
  20. import torch
  21. import torch.nn.functional as F
  22. from torch.nn.modules.utils import _pair
  23. from kornia.core import Module, Tensor, pad
  24. FullPadType = Tuple[int, int, int, int]
  25. TuplePadType = Union[Tuple[int, int], FullPadType]
  26. PadType = Union[int, TuplePadType]
  27. def create_padding_tuple(padding: PadType, unpadding: bool = False) -> FullPadType:
  28. """Create argument for padding op."""
  29. padding = cast(TuplePadType, _pair(padding))
  30. if len(padding) not in [2, 4]:
  31. raise AssertionError(
  32. f"{'Unpadding' if unpadding else 'Padding'} must be either an int, tuple of two ints or tuple of four ints"
  33. )
  34. if len(padding) == 2:
  35. pad_vert = _pair(padding[0])
  36. pad_horz = _pair(padding[1])
  37. else:
  38. pad_vert = padding[:2]
  39. pad_horz = padding[2:]
  40. padding = cast(FullPadType, pad_horz + pad_vert)
  41. return padding
  42. def compute_padding(
  43. original_size: Union[int, Tuple[int, int]],
  44. window_size: Union[int, Tuple[int, int]],
  45. stride: Optional[Union[int, Tuple[int, int]]] = None,
  46. ) -> FullPadType:
  47. r"""Compute required padding to ensure chaining of :func:`extract_tensor_patches` and
  48. :func:`combine_tensor_patches` produces expected result.
  49. Args:
  50. original_size: the size of the original tensor.
  51. window_size: the size of the sliding window used while extracting patches.
  52. stride: The stride of the sliding window. Optional: if not specified, window_size will be used.
  53. Return:
  54. The required padding as a tuple of four ints: (top, bottom, left, right)
  55. Example:
  56. >>> image = torch.arange(12).view(1, 1, 4, 3)
  57. >>> padding = compute_padding((4,3), (3,3))
  58. >>> out = extract_tensor_patches(image, window_size=(3, 3), stride=(3, 3), padding=padding)
  59. >>> combine_tensor_patches(out, original_size=(4, 3), window_size=(3, 3), stride=(3, 3), unpadding=padding)
  60. tensor([[[[ 0, 1, 2],
  61. [ 3, 4, 5],
  62. [ 6, 7, 8],
  63. [ 9, 10, 11]]]])
  64. .. note::
  65. This function will be implicitly used in :func:`extract_tensor_patches` and :func:`combine_tensor_patches` if
  66. `allow_auto_(un)padding` is set to True.
  67. """ # noqa: D205
  68. original_size = cast(Tuple[int, int], _pair(original_size))
  69. window_size = cast(Tuple[int, int], _pair(window_size))
  70. if stride is None:
  71. stride = window_size
  72. stride = cast(Tuple[int, int], _pair(stride))
  73. remainder_vertical = (original_size[0] - window_size[0]) % stride[0]
  74. remainder_horizontal = (original_size[1] - window_size[1]) % stride[1]
  75. # it might be best to apply padding only to the far edges (right, bottom), so
  76. # that fewer patches are affected by the padding.
  77. # For now, just use the default padding
  78. if remainder_vertical != 0:
  79. vertical_padding = stride[0] - remainder_vertical
  80. else:
  81. vertical_padding = 0
  82. if remainder_horizontal != 0:
  83. horizontal_padding = stride[1] - remainder_horizontal
  84. else:
  85. horizontal_padding = 0
  86. if vertical_padding % 2 == 0:
  87. top_padding = bottom_padding = vertical_padding // 2
  88. else:
  89. top_padding = vertical_padding // 2
  90. bottom_padding = ceil(vertical_padding / 2)
  91. if horizontal_padding % 2 == 0:
  92. left_padding = right_padding = horizontal_padding // 2
  93. else:
  94. left_padding = horizontal_padding // 2
  95. right_padding = ceil(horizontal_padding / 2)
  96. # the new implementation with unfolding requires symmetric padding
  97. padding = int(top_padding), int(bottom_padding), int(left_padding), int(right_padding)
  98. return padding
  99. class ExtractTensorPatches(Module):
  100. r"""Module that extract patches from tensors and stack them.
  101. In the simplest case, the output value of the operator with input size
  102. :math:`(B, C, H, W)` is :math:`(B, N, C, H_{out}, W_{out})`.
  103. where
  104. - :math:`B` is the batch size.
  105. - :math:`N` denotes the total number of extracted patches stacked in
  106. - :math:`C` denotes the number of input channels.
  107. - :math:`H`, :math:`W` the input height and width of the input in pixels.
  108. - :math:`H_{out}`, :math:`W_{out}` denote to denote to the patch size
  109. defined in the function signature.
  110. left-right and top-bottom order.
  111. * :attr:`window_size` is the size of the sliding window and controls the
  112. shape of the output tensor and defines the shape of the output patch.
  113. * :attr:`stride` controls the stride to apply to the sliding window and
  114. regulates the overlapping between the extracted patches.
  115. * :attr:`padding` controls the amount of implicit zeros-paddings on both
  116. sizes at each dimension.
  117. * :attr:`allow_auto_padding` allows automatic calculation of the padding required
  118. to fit the window and stride into the image.
  119. The parameters :attr:`window_size`, :attr:`stride` and :attr:`padding` can
  120. be either:
  121. - a single ``int`` -- in which case the same value is used for the
  122. height and width dimension.
  123. - a ``tuple`` of two ints -- in which case, the first `int` is used for
  124. the height dimension, and the second `int` for the width dimension.
  125. :attr:`padding` can also be a ``tuple`` of four ints -- in which case, the
  126. first two ints are for the height dimension while the last two ints are for
  127. the width dimension.
  128. Args:
  129. input: tensor image where to extract the patches with shape :math:`(B, C, H, W)`.
  130. window_size: the size of the sliding window and the output patch size.
  131. stride: stride of the sliding window.
  132. padding: Zero-padding added to both side of the input.
  133. allow_auto_adding: whether to allow automatic padding if the window and stride do not fit into the image.
  134. Shape:
  135. - Input: :math:`(B, C, H, W)`
  136. - Output: :math:`(B, N, C, H_{out}, W_{out})`
  137. Returns:
  138. the tensor with the extracted patches.
  139. Examples:
  140. >>> input = torch.arange(9.).view(1, 1, 3, 3)
  141. >>> patches = extract_tensor_patches(input, (2, 3))
  142. >>> input
  143. tensor([[[[0., 1., 2.],
  144. [3., 4., 5.],
  145. [6., 7., 8.]]]])
  146. >>> patches[:, -1]
  147. tensor([[[[3., 4., 5.],
  148. [6., 7., 8.]]]])
  149. """
  150. def __init__(
  151. self,
  152. window_size: Union[int, Tuple[int, int]],
  153. stride: Union[int, Tuple[int, int]] = 1,
  154. padding: PadType = 0,
  155. allow_auto_padding: bool = False,
  156. ) -> None:
  157. super().__init__()
  158. self.window_size: Union[int, Tuple[int, int]] = window_size
  159. self.stride: Union[int, Tuple[int, int]] = stride
  160. self.padding: PadType = padding
  161. self.allow_auto_padding: bool = allow_auto_padding
  162. def forward(self, input: Tensor) -> Tensor:
  163. return extract_tensor_patches(
  164. input,
  165. self.window_size,
  166. stride=self.stride,
  167. padding=self.padding,
  168. allow_auto_padding=self.allow_auto_padding,
  169. )
  170. class CombineTensorPatches(Module):
  171. r"""Module that combines patches back into full tensors.
  172. In the simplest case, the output value of the operator with input size
  173. :math:`(B, N, C, H_{out}, W_{out})` is :math:`(B, C, H, W)`.
  174. where
  175. - :math:`B` is the batch size.
  176. - :math:`N` denotes the total number of extracted patches stacked in
  177. - :math:`C` denotes the number of input channels.
  178. - :math:`H`, :math:`W` the input height and width of the input in pixels.
  179. - :math:`H_{out}`, :math:`W_{out}` denote to denote to the patch size
  180. defined in the function signature.
  181. left-right and top-bottom order.
  182. * :attr:`original_size` is the size of the original image prior to
  183. extracting tensor patches and defines the shape of the output patch.
  184. * :attr:`window_size` is the size of the sliding window used while
  185. extracting tensor patches.
  186. * :attr:`stride` controls the stride to apply to the sliding window and
  187. regulates the overlapping between the extracted patches.
  188. * :attr:`unpadding` is the amount of padding to be removed. If specified,
  189. this value must be the same as padding used while extracting tensor patches.
  190. * :attr:`allow_auto_unpadding` allows automatic calculation of the padding required
  191. to fit the window and stride into the image. This must be used if the
  192. `allow_auto_padding` flag was used for extracting the patches.
  193. The parameters :attr:`original_size`, :attr:`window_size`, :attr:`stride`, and :attr:`unpadding` can
  194. be either:
  195. - a single ``int`` -- in which case the same value is used for the
  196. height and width dimension.
  197. - a ``tuple`` of two ints -- in which case, the first `int` is used for
  198. the height dimension, and the second `int` for the width dimension.
  199. :attr:`unpadding` can also be a ``tuple`` of four ints -- in which case, the
  200. first two ints are for the height dimension while the last two ints are for
  201. the width dimension.
  202. Args:
  203. patches: patched tensor with shape :math:`(B, N, C, H_{out}, W_{out})`.
  204. original_size: the size of the original tensor and the output size.
  205. window_size: the size of the sliding window used while extracting patches.
  206. stride: stride of the sliding window.
  207. unpadding: remove the padding added to both side of the input.
  208. allow_auto_unpadding: whether to allow automatic unpadding of the input
  209. if the window and stride do not fit into the original_size.
  210. eps: small value used to prevent division by zero.
  211. Shape:
  212. - Input: :math:`(B, N, C, H_{out}, W_{out})`
  213. - Output: :math:`(B, C, H, W)`
  214. Example:
  215. >>> out = extract_tensor_patches(torch.arange(16).view(1, 1, 4, 4), window_size=(2, 2), stride=(2, 2))
  216. >>> combine_tensor_patches(out, original_size=(4, 4), window_size=(2, 2), stride=(2, 2))
  217. tensor([[[[ 0, 1, 2, 3],
  218. [ 4, 5, 6, 7],
  219. [ 8, 9, 10, 11],
  220. [12, 13, 14, 15]]]])
  221. .. note::
  222. This function is supposed to be used in conjunction with :class:`ExtractTensorPatches`.
  223. """
  224. def __init__(
  225. self,
  226. original_size: Tuple[int, int],
  227. window_size: Union[int, Tuple[int, int]],
  228. stride: Optional[Union[int, Tuple[int, int]]] = None,
  229. unpadding: PadType = 0,
  230. allow_auto_unpadding: bool = False,
  231. ) -> None:
  232. super().__init__()
  233. self.original_size: Tuple[int, int] = original_size
  234. self.window_size: Union[int, Tuple[int, int]] = window_size
  235. self.stride: Union[int, Tuple[int, int]] = stride if stride is not None else window_size
  236. self.unpadding: PadType = unpadding
  237. self.allow_auto_unpadding: bool = allow_auto_unpadding
  238. def forward(self, input: Tensor) -> Tensor:
  239. return combine_tensor_patches(
  240. input,
  241. self.original_size,
  242. self.window_size,
  243. stride=self.stride,
  244. unpadding=self.unpadding,
  245. allow_auto_unpadding=self.allow_auto_unpadding,
  246. )
  247. def _check_patch_fit(original_size: Tuple[int, int], window_size: Tuple[int, int], stride: Tuple[int, int]) -> bool:
  248. remainder_vertical = (original_size[0] - window_size[0]) % stride[0]
  249. remainder_horizontal = (original_size[1] - window_size[1]) % stride[1]
  250. # the remainder takes into account half a window on each side,
  251. # the rest of the image is divided based on the stride, not the window
  252. # size
  253. if (remainder_horizontal != 0) or (remainder_vertical != 0):
  254. # needs padding to fit
  255. return False
  256. # we can fit a full number of patches in, based on the stride
  257. return True
  258. def combine_tensor_patches(
  259. patches: Tensor,
  260. original_size: Union[int, Tuple[int, int]],
  261. window_size: Union[int, Tuple[int, int]],
  262. stride: Union[int, Tuple[int, int]],
  263. allow_auto_unpadding: bool = False,
  264. unpadding: PadType = 0,
  265. eps: float = 1e-8,
  266. ) -> Tensor:
  267. r"""Restore input from patches.
  268. See :class:`~kornia.contrib.CombineTensorPatches` for details.
  269. Args:
  270. patches: patched tensor with shape :math:`(B, N, C, H_{out}, W_{out})`.
  271. original_size: the size of the original tensor and the output size.
  272. window_size: the size of the sliding window used while extracting patches.
  273. stride: stride of the sliding window.
  274. unpadding: remove the padding added to both side of the input.
  275. allow_auto_unpadding: whether to allow automatic unpadding of the input
  276. if the window and stride do not fit into the original_size.
  277. eps: small value used to prevent division by zero.
  278. Return:
  279. The combined patches in an image tensor with shape :math:`(B, C, H, W)`.
  280. Example:
  281. >>> out = extract_tensor_patches(torch.arange(16).view(1, 1, 4, 4), window_size=(2, 2), stride=(2, 2))
  282. >>> combine_tensor_patches(out, original_size=(4, 4), window_size=(2, 2), stride=(2, 2))
  283. tensor([[[[ 0, 1, 2, 3],
  284. [ 4, 5, 6, 7],
  285. [ 8, 9, 10, 11],
  286. [12, 13, 14, 15]]]])
  287. .. note::
  288. This function is supposed to be used in conjunction with :func:`extract_tensor_patches`.
  289. """
  290. if patches.ndim != 5:
  291. raise ValueError(f"Invalid input shape, we expect BxNxCxHxW. Got: {patches.shape}")
  292. original_size = cast(Tuple[int, int], _pair(original_size))
  293. window_size = cast(Tuple[int, int], _pair(window_size))
  294. stride = cast(Tuple[int, int], _pair(stride))
  295. if (stride[0] > window_size[0]) | (stride[1] > window_size[1]):
  296. raise AssertionError(
  297. f"Stride={stride} should be less than or equal to Window size={window_size}, information is missing"
  298. )
  299. if not unpadding:
  300. # if padding is specified, we leave it up to the user to ensure it fits
  301. # otherwise we check here if it will fit and offer to calculate padding
  302. if not _check_patch_fit(original_size, window_size, stride):
  303. if not allow_auto_unpadding:
  304. warn(
  305. f"The window will not fit into the image. \nWindow size: {window_size}\nStride: {stride}\n"
  306. f"Image size: {original_size}\n"
  307. "This means we probably cannot correctly recombine patches. By enabling `allow_auto_unpadding`, "
  308. "the input will be unpadded to fit the window and stride.\n"
  309. "If the patches have been obtained through `extract_tensor_patches` with the correct padding or "
  310. "the argument `allow_auto_padding`, this will result in a correct reconstruction.",
  311. stacklevel=1,
  312. )
  313. else:
  314. unpadding = compute_padding(original_size=original_size, window_size=window_size, stride=stride)
  315. # TODO: Can't we just do actual size minus original size to get padding?
  316. if unpadding:
  317. unpadding = create_padding_tuple(unpadding)
  318. ones = torch.ones(
  319. patches.shape[0],
  320. patches.shape[2],
  321. original_size[0],
  322. original_size[1],
  323. device=patches.device,
  324. dtype=patches.dtype,
  325. )
  326. if unpadding:
  327. ones = pad(ones, pad=unpadding)
  328. restored_size = ones.shape[2:]
  329. patches = patches.permute(0, 2, 3, 4, 1)
  330. patches = patches.reshape(patches.shape[0], -1, patches.shape[-1])
  331. int_flag = 0
  332. if not torch.is_floating_point(patches):
  333. int_flag = 1
  334. dtype = patches.dtype
  335. patches = patches.float()
  336. ones = ones.float()
  337. # Calculate normalization map
  338. unfold_ones = F.unfold(ones, kernel_size=window_size, stride=stride)
  339. norm_map = F.fold(input=unfold_ones, output_size=restored_size, kernel_size=window_size, stride=stride)
  340. if unpadding:
  341. norm_map = pad(norm_map, [-i for i in unpadding])
  342. # Restored tensor
  343. saturated_restored_tensor = F.fold(input=patches, output_size=restored_size, kernel_size=window_size, stride=stride)
  344. if unpadding:
  345. saturated_restored_tensor = pad(saturated_restored_tensor, [-i for i in unpadding])
  346. # Remove satuation effect due to multiple summations
  347. restored_tensor = saturated_restored_tensor / (norm_map + eps)
  348. if int_flag:
  349. restored_tensor = restored_tensor.to(dtype)
  350. return restored_tensor
  351. def _extract_tensor_patchesnd(input: Tensor, window_sizes: Tuple[int, ...], strides: Tuple[int, ...]) -> Tensor:
  352. batch_size, num_channels = input.size()[:2]
  353. dims = range(2, input.dim())
  354. for dim, patch_size, stride in zip(dims, window_sizes, strides):
  355. input = input.unfold(dim, patch_size, stride)
  356. input = input.permute(0, *dims, 1, *(dim + len(dims) for dim in dims)).contiguous()
  357. return input.view(batch_size, -1, num_channels, *window_sizes)
  358. def extract_tensor_patches(
  359. input: Tensor,
  360. window_size: Union[int, Tuple[int, int]],
  361. stride: Union[int, Tuple[int, int]] = 1,
  362. padding: PadType = 0,
  363. allow_auto_padding: bool = False,
  364. ) -> Tensor:
  365. r"""Extract patches from tensors and stacks them.
  366. See :class:`~kornia.contrib.ExtractTensorPatches` for details.
  367. Args:
  368. input: tensor image where to extract the patches with shape :math:`(B, C, H, W)`.
  369. window_size: the size of the sliding window and the output patch size.
  370. stride: stride of the sliding window.
  371. padding: Zero-padding added to both side of the input.
  372. allow_auto_padding: whether to allow automatic padding if the window and stride do not fit into the image.
  373. Returns:
  374. the tensor with the extracted patches with shape :math:`(B, N, C, H_{out}, W_{out})`.
  375. Examples:
  376. >>> input = torch.arange(9.).view(1, 1, 3, 3)
  377. >>> patches = extract_tensor_patches(input, (2, 3))
  378. >>> input
  379. tensor([[[[0., 1., 2.],
  380. [3., 4., 5.],
  381. [6., 7., 8.]]]])
  382. >>> patches[:, -1]
  383. tensor([[[[3., 4., 5.],
  384. [6., 7., 8.]]]])
  385. """
  386. if not torch.is_tensor(input):
  387. raise TypeError(f"Input input type is not a Tensor. Got {type(input)}")
  388. if len(input.shape) != 4:
  389. raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
  390. # check if the window sliding over the image will fit into the image
  391. # torch's unfold drops the final patches that don't fit
  392. window_size = cast(Tuple[int, int], _pair(window_size))
  393. stride = cast(Tuple[int, int], _pair(stride))
  394. original_size = (input.shape[-2], input.shape[-1])
  395. if not padding:
  396. # if padding is specified, we leave it up to the user to ensure it fits
  397. # otherwise we check here if it will fit and offer to calculate padding
  398. if not _check_patch_fit(original_size, window_size, stride):
  399. if not allow_auto_padding:
  400. warn(
  401. f"The window will not fit into the image. \nWindow size: {window_size}\nStride: {stride}\n"
  402. f"Image size: {original_size}\n"
  403. "This means that the final incomplete patches will be dropped. By enabling `allow_auto_padding`, "
  404. "the input will be padded to fit the window and stride.",
  405. stacklevel=1,
  406. )
  407. else:
  408. padding = compute_padding(original_size=original_size, window_size=window_size, stride=stride)
  409. if padding:
  410. padding = create_padding_tuple(padding)
  411. input = pad(input, padding)
  412. return _extract_tensor_patchesnd(input, window_size, stride)