patch.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from itertools import cycle, islice
  18. from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
  19. import torch
  20. import kornia.augmentation as K
  21. from kornia.augmentation.base import _AugmentationBase
  22. from kornia.contrib.extract_patches import extract_tensor_patches
  23. from kornia.core import Module, Tensor, concatenate
  24. from kornia.core import pad as fpad
  25. from kornia.geometry.boxes import Boxes
  26. from kornia.geometry.keypoints import Keypoints
  27. from .base import SequentialBase
  28. from .image import ImageSequential
  29. from .ops import InputSequentialOps
  30. from .params import ParamItem, PatchParamItem
  31. __all__ = ["PatchSequential"]
class PatchSequential(ImageSequential):
    r"""Container for performing patch-level image data augmentation.

    .. image:: _static/img/PatchSequential.png

    PatchSequential breaks input images into patches by a given grid size, which will be reassembled
    afterwards.

    Different image processing and augmentation methods will be performed on each patch region as
    in :cite:`lin2021patch`.

    Args:
        *args: a list of processing modules.
        grid_size: controls the grid board separation.
        padding: same or valid padding. If same padding, it will pad to include all pixels if the input
            tensor cannot be divisible by grid_size. If valid padding, the redundant border will be removed.
        same_on_batch: apply the same transformation across the batch.
            If None, it will not overwrite the function-wise settings.
        keepdim: whether to keep the output shape the same as input (True) or broadcast it
            to the batch form (False). If None, it will not overwrite the function-wise settings.
        patchwise_apply: whether the processing modules are applied patch-wise.
            If ``True``, the number of args must be equal to grid number.
            If ``False``, the image processing args will be applied as a sequence to all patches.
        random_apply: randomly select a sublist (order agnostic) of args to
            apply transformation.
            If ``int`` (batchwise mode only), a fixed number of transformations will be selected.
            If ``(a,)`` (batchwise mode only), x number of transformations (a <= x <= len(args)) will be selected.
            If ``(a, b)`` (batchwise mode only), x number of transformations (a <= x <= b) will be selected.
            If ``True``, the whole list of args will be processed in a random order
            (in patch-wise mode the constructor maps ``True`` to the range ``(1, 4)``).
            If ``False`` and not ``patchwise_apply``, the whole list of args will be processed in original order.
            If ``False`` and ``patchwise_apply``, the whole list of args will be processed in original order
            location-wisely.
        random_apply_weights: optional list of weights forwarded to the parent container together with
            ``random_apply`` (see the ``OneOf`` example below).

    .. note::
        Transformation matrix returned only considers the transformation applied in ``kornia.augmentation`` module.
        Those transformations in ``kornia.geometry`` will not be taken into account.

    .. note::
        See a working example `here <https://kornia.github.io/tutorials/nbs/data_patch_sequential.html>`__.

    Examples:
        >>> import kornia.augmentation as K
        >>> input = torch.randn(2, 3, 224, 224)
        >>> seq = PatchSequential(
        ...     ImageSequential(
        ...         K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=0.5),
        ...         K.RandomPerspective(0.2, p=0.5),
        ...         K.RandomSolarize(0.1, 0.1, p=0.5),
        ...     ),
        ...     K.RandomAffine(360, p=1.0),
        ...     ImageSequential(
        ...         K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=0.5),
        ...         K.RandomPerspective(0.2, p=0.5),
        ...         K.RandomSolarize(0.1, 0.1, p=0.5),
        ...     ),
        ...     K.RandomSolarize(0.1, 0.1, p=0.1),
        ...     grid_size=(2,2),
        ...     patchwise_apply=True,
        ...     same_on_batch=True,
        ...     random_apply=False,
        ... )
        >>> out = seq(input)
        >>> out.shape
        torch.Size([2, 3, 224, 224])
        >>> out1 = seq(input, params=seq._params)
        >>> torch.equal(out, out1)
        True

        Perform ``OneOf`` transformation with ``random_apply=1`` and ``random_apply_weights`` in ``PatchSequential``.

        >>> import kornia
        >>> input = torch.randn(2, 3, 224, 224)
        >>> seq = PatchSequential(
        ...     ImageSequential(
        ...         K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=0.5),
        ...         K.RandomPerspective(0.2, p=0.5),
        ...         K.RandomSolarize(0.1, 0.1, p=0.5),
        ...     ),
        ...     K.RandomAffine(360, p=1.0),
        ...     K.RandomSolarize(0.1, 0.1, p=0.1),
        ...     grid_size=(2,2),
        ...     patchwise_apply=False,
        ...     random_apply=1,
        ...     random_apply_weights=[0.5, 0.3, 0.8]
        ... )
        >>> out = seq(input)
        >>> out.shape
        torch.Size([2, 3, 224, 224])
    """
  112. def __init__(
  113. self,
  114. *args: Module,
  115. grid_size: Tuple[int, int] = (4, 4),
  116. padding: str = "same",
  117. same_on_batch: Optional[bool] = None,
  118. keepdim: Optional[bool] = None,
  119. patchwise_apply: bool = True,
  120. random_apply: Union[int, bool, Tuple[int, int]] = False,
  121. random_apply_weights: Optional[List[float]] = None,
  122. ) -> None:
  123. _random_apply: Optional[Union[int, Tuple[int, int]]]
  124. if patchwise_apply and random_apply is True:
  125. # will only apply [1, 4] augmentations per patch
  126. _random_apply = (1, 4)
  127. elif patchwise_apply and random_apply is False:
  128. if len(args) != grid_size[0] * grid_size[1]:
  129. raise ValueError(
  130. "The number of processing modules must be equal with grid size."
  131. f"Got {len(args)} and {grid_size[0] * grid_size[1]}. "
  132. "Please set random_apply = True or patchwise_apply = False."
  133. )
  134. _random_apply = random_apply
  135. elif patchwise_apply and isinstance(random_apply, (int, tuple)):
  136. raise ValueError(f"Only boolean value allowed when `patchwise_apply` is set to True. Got {random_apply}.")
  137. else:
  138. _random_apply = random_apply
  139. super().__init__(
  140. *args,
  141. same_on_batch=same_on_batch,
  142. keepdim=keepdim,
  143. random_apply=_random_apply,
  144. random_apply_weights=random_apply_weights,
  145. )
  146. if padding not in ("same", "valid"):
  147. raise ValueError(f"`padding` must be either `same` or `valid`. Got {padding}.")
  148. self.grid_size = grid_size
  149. self.padding = padding
  150. self.patchwise_apply = patchwise_apply
  151. self._params: Optional[List[PatchParamItem]] # type: ignore[assignment]
  152. def compute_padding(
  153. self, input: Tensor, padding: str, grid_size: Optional[Tuple[int, int]] = None
  154. ) -> Tuple[int, int, int, int]:
  155. if grid_size is None:
  156. grid_size = self.grid_size
  157. if padding == "valid":
  158. ph, pw = input.size(-2) // grid_size[0], input.size(-1) // grid_size[1]
  159. return (-pw // 2, pw // 2 - pw, -ph // 2, ph // 2 - ph)
  160. if padding == "same":
  161. ph = input.size(-2) - input.size(-2) // grid_size[0] * grid_size[0]
  162. pw = input.size(-1) - input.size(-1) // grid_size[1] * grid_size[1]
  163. return (pw // 2, pw - pw // 2, ph // 2, ph - ph // 2)
  164. raise NotImplementedError(f"Expect `padding` as either 'valid' or 'same'. Got {padding}.")
  165. def extract_patches(
  166. self,
  167. input: Tensor,
  168. grid_size: Optional[Tuple[int, int]] = None,
  169. pad: Optional[Tuple[int, int, int, int]] = None,
  170. ) -> Tensor:
  171. """Extract patches from tensor.
  172. Example:
  173. >>> import kornia.augmentation as K
  174. >>> pas = PatchSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0), patchwise_apply=False)
  175. >>> pas.extract_patches(torch.arange(16).view(1, 1, 4, 4), grid_size=(2, 2))
  176. tensor([[[[[ 0, 1],
  177. [ 4, 5]]],
  178. <BLANKLINE>
  179. <BLANKLINE>
  180. [[[ 2, 3],
  181. [ 6, 7]]],
  182. <BLANKLINE>
  183. <BLANKLINE>
  184. [[[ 8, 9],
  185. [12, 13]]],
  186. <BLANKLINE>
  187. <BLANKLINE>
  188. [[[10, 11],
  189. [14, 15]]]]])
  190. >>> pas.extract_patches(torch.arange(54).view(1, 1, 6, 9), grid_size=(2, 2), pad=(-1, -1, -2, -2))
  191. tensor([[[[[19, 20, 21]]],
  192. <BLANKLINE>
  193. <BLANKLINE>
  194. [[[22, 23, 24]]],
  195. <BLANKLINE>
  196. <BLANKLINE>
  197. [[[28, 29, 30]]],
  198. <BLANKLINE>
  199. <BLANKLINE>
  200. [[[31, 32, 33]]]]])
  201. """
  202. if pad is not None:
  203. input = fpad(input, list(pad))
  204. if grid_size is None:
  205. grid_size = self.grid_size
  206. window_size = (input.size(-2) // grid_size[-2], input.size(-1) // grid_size[-1])
  207. stride = window_size
  208. return extract_tensor_patches(input, window_size, stride)
  209. def restore_from_patches(
  210. self, patches: Tensor, grid_size: Tuple[int, int] = (4, 4), pad: Optional[Tuple[int, int, int, int]] = None
  211. ) -> Tensor:
  212. """Restore input from patches.
  213. Example:
  214. >>> import kornia.augmentation as K
  215. >>> pas = PatchSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0), patchwise_apply=False)
  216. >>> out = pas.extract_patches(torch.arange(16).view(1, 1, 4, 4), grid_size=(2, 2))
  217. >>> pas.restore_from_patches(out, grid_size=(2, 2))
  218. tensor([[[[ 0, 1, 2, 3],
  219. [ 4, 5, 6, 7],
  220. [ 8, 9, 10, 11],
  221. [12, 13, 14, 15]]]])
  222. """
  223. if grid_size is None:
  224. grid_size = self.grid_size
  225. patches_tensor = patches.view(-1, grid_size[0], grid_size[1], *patches.shape[-3:])
  226. restored_tensor = concatenate(torch.chunk(patches_tensor, grid_size[0], 1), -2).squeeze(1)
  227. restored_tensor = concatenate(torch.chunk(restored_tensor, grid_size[1], 1), -1).squeeze(1)
  228. if pad is not None:
  229. restored_tensor = fpad(restored_tensor, [-i for i in pad])
  230. return restored_tensor
  231. def forward_parameters(self, batch_shape: torch.Size) -> List[PatchParamItem]: # type: ignore[override]
  232. out_param: List[PatchParamItem] = []
  233. if not self.patchwise_apply:
  234. params = self.generate_parameters(torch.Size([1, batch_shape[0] * batch_shape[1], *batch_shape[2:]]))
  235. indices = torch.arange(0, batch_shape[0] * batch_shape[1])
  236. out_param = [PatchParamItem(indices.tolist(), p) for p, _ in params]
  237. # "append" of "list" does not return a value
  238. elif not self.same_on_batch:
  239. params = self.generate_parameters(torch.Size([batch_shape[0] * batch_shape[1], 1, *batch_shape[2:]]))
  240. out_param = [PatchParamItem([i], p) for p, i in params]
  241. # "append" of "list" does not return a value
  242. else:
  243. params = self.generate_parameters(torch.Size([batch_shape[1], batch_shape[0], *batch_shape[2:]]))
  244. indices = torch.arange(0, batch_shape[0] * batch_shape[1], step=batch_shape[1])
  245. out_param = [PatchParamItem((indices + i).tolist(), p) for p, i in params]
  246. # "append" of "list" does not return a value
  247. return out_param
  248. def generate_parameters(self, batch_shape: torch.Size) -> Iterator[Tuple[ParamItem, int]]:
  249. """Get multiple forward sequence but maximumly one mix augmentation in between.
  250. Args:
  251. batch_shape: 5-dim shape arranged as :math:``(N, B, C, H, W)``, in which N represents
  252. the number of sequence.
  253. """
  254. if not self.same_on_batch and self.random_apply:
  255. # diff_on_batch and random_apply => patch-wise augmentation
  256. with_mix = False
  257. for i in range(batch_shape[0]):
  258. seq, mix_added = self.get_random_forward_sequence(with_mix=with_mix)
  259. with_mix = mix_added
  260. for s in seq:
  261. if isinstance(s[1], (_AugmentationBase, SequentialBase, K.MixAugmentationBaseV2)):
  262. yield ParamItem(s[0], s[1].forward_parameters(torch.Size(batch_shape[1:]))), i
  263. else:
  264. yield ParamItem(s[0], None), i
  265. elif not self.same_on_batch and not self.random_apply:
  266. for i, nchild in enumerate(self.named_children()):
  267. if isinstance(nchild[1], (_AugmentationBase, SequentialBase, K.MixAugmentationBaseV2)):
  268. yield ParamItem(nchild[0], nchild[1].forward_parameters(torch.Size(batch_shape[1:]))), i
  269. else:
  270. yield ParamItem(nchild[0], None), i
  271. elif not self.random_apply:
  272. # same_on_batch + not random_apply => location-wise augmentation
  273. for i, nchild in enumerate(islice(cycle(self.named_children()), batch_shape[0])):
  274. if isinstance(nchild[1], (_AugmentationBase, SequentialBase, K.MixAugmentationBaseV2)):
  275. yield ParamItem(nchild[0], nchild[1].forward_parameters(torch.Size(batch_shape[1:]))), i
  276. else:
  277. yield ParamItem(nchild[0], None), i
  278. else:
  279. # same_on_batch + random_apply => location-wise augmentation
  280. with_mix = False
  281. for i in range(batch_shape[0]):
  282. seq, mix_added = self.get_random_forward_sequence(with_mix=with_mix)
  283. with_mix = mix_added
  284. for s in seq:
  285. if isinstance(s[1], (_AugmentationBase, SequentialBase, K.MixAugmentationBaseV2)):
  286. yield ParamItem(s[0], s[1].forward_parameters(torch.Size(batch_shape[1:]))), i
  287. else:
  288. yield ParamItem(s[0], None), i
  289. def forward_by_params(self, input: Tensor, params: List[PatchParamItem]) -> Tensor:
  290. in_shape = input.shape
  291. input = input.reshape(-1, *in_shape[-3:])
  292. for patch_param in params:
  293. # input, out_param = self.apply_by_param(input, params=patch_param)
  294. module = self.get_submodule(patch_param.param.name)
  295. _input = input[patch_param.indices]
  296. output = InputSequentialOps.transform(_input, module, patch_param.param, extra_args={})
  297. input[patch_param.indices] = output
  298. return input.reshape(in_shape)
  299. def transform_inputs( # type: ignore[override]
  300. self, input: Tensor, params: List[PatchParamItem], extra_args: Optional[Dict[str, Any]] = None
  301. ) -> Tensor:
  302. pad = self.compute_padding(input, self.padding)
  303. input = self.extract_patches(input, self.grid_size, pad)
  304. input = self.forward_by_params(input, params)
  305. input = self.restore_from_patches(input, self.grid_size, pad=pad)
  306. return input
  307. def inverse_inputs( # type: ignore[override]
  308. self, input: Tensor, params: List[PatchParamItem], extra_args: Optional[Dict[str, Any]] = None
  309. ) -> Tensor:
  310. if self.is_intensity_only():
  311. return input
  312. raise NotImplementedError("PatchSequential inverse cannot be used with geometric transformations.")
  313. def transform_masks( # type: ignore[override]
  314. self, input: Tensor, params: List[PatchParamItem], extra_args: Optional[Dict[str, Any]] = None
  315. ) -> Tensor:
  316. if self.is_intensity_only():
  317. return input
  318. raise NotImplementedError("PatchSequential for boxes cannot be used with geometric transformations.")
  319. def inverse_masks( # type: ignore[override]
  320. self, input: Tensor, params: List[PatchParamItem], extra_args: Optional[Dict[str, Any]] = None
  321. ) -> Tensor:
  322. if self.is_intensity_only():
  323. return input
  324. raise NotImplementedError("PatchSequential inverse cannot be used with geometric transformations.")
  325. def transform_boxes( # type: ignore[override]
  326. self, input: Boxes, params: List[PatchParamItem], extra_args: Optional[Dict[str, Any]] = None
  327. ) -> Boxes:
  328. if self.is_intensity_only():
  329. return input
  330. raise NotImplementedError("PatchSequential for boxes cannot be used with geometric transformations.")
  331. def inverse_boxes( # type: ignore[override]
  332. self, input: Boxes, params: List[PatchParamItem], extra_args: Optional[Dict[str, Any]] = None
  333. ) -> Boxes:
  334. if self.is_intensity_only():
  335. return input
  336. raise NotImplementedError("PatchSequential inverse cannot be used with geometric transformations.")
  337. def transform_keypoints( # type: ignore[override]
  338. self, input: Keypoints, params: List[PatchParamItem], extra_args: Optional[Dict[str, Any]] = None
  339. ) -> Keypoints:
  340. if self.is_intensity_only():
  341. return input
  342. raise NotImplementedError("PatchSequential for keypoints cannot be used with geometric transformations.")
  343. def inverse_keypoints( # type: ignore[override]
  344. self, input: Keypoints, params: List[PatchParamItem], extra_args: Optional[Dict[str, Any]] = None
  345. ) -> Keypoints:
  346. if self.is_intensity_only():
  347. return input
  348. raise NotImplementedError("PatchSequential inverse cannot be used with geometric transformations.")
  349. def inverse( # type: ignore[override]
  350. self, input: Tensor, params: Optional[List[PatchParamItem]] = None, extra_args: Optional[Dict[str, Any]] = None
  351. ) -> Tensor:
  352. """Inverse transformation.
  353. Used to inverse a tensor according to the performed transformation by a forward pass, or with respect to
  354. provided parameters.
  355. """
  356. if self.is_intensity_only():
  357. return input
  358. raise NotImplementedError("PatchSequential inverse cannot be used with geometric transformations.")
  359. def forward(self, input: Tensor, params: Optional[List[PatchParamItem]] = None) -> Tensor: # type: ignore[override]
  360. """Input transformation will be returned if input is a tuple."""
  361. # BCHW -> B(patch)CHW
  362. if isinstance(input, (tuple,)):
  363. raise ValueError("tuple input is not currently supported.")
  364. if params is None:
  365. params = self.forward_parameters(input.shape)
  366. output = self.transform_inputs(input, params=params)
  367. self._params = params
  368. return output