base.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from collections import OrderedDict
  18. from itertools import zip_longest
  19. from typing import Any, Dict, Iterator, List, Optional, Tuple
  20. import torch
  21. from torch import nn
  22. import kornia.augmentation as K
  23. from kornia.augmentation.base import _AugmentationBase
  24. from kornia.core import Module, Tensor
  25. from kornia.geometry.boxes import Boxes
  26. from kornia.geometry.keypoints import Keypoints
  27. from .ops import BoxSequentialOps, InputSequentialOps, KeypointSequentialOps, MaskSequentialOps
  28. from .params import ParamItem
  29. __all__ = ["BasicSequentialBase", "ImageSequentialBase", "SequentialBase"]
  30. class BasicSequentialBase(nn.Sequential):
  31. r"""BasicSequential for creating kornia modulized processing pipeline.
  32. Args:
  33. *args : a list of kornia augmentation and image operation modules.
  34. """
  35. def __init__(self, *args: Module) -> None:
  36. # To name the modules properly
  37. _args = OrderedDict()
  38. for idx, mod in enumerate(args):
  39. if not isinstance(mod, Module):
  40. raise NotImplementedError(f"Only Module are supported at this moment. Got {mod}.")
  41. _args.update({f"{mod.__class__.__name__}_{idx}": mod})
  42. super().__init__(_args)
  43. self._params: Optional[List[ParamItem]] = None
  44. def get_submodule(self, target: str) -> Module:
  45. """Get submodule.
  46. This code is taken from torch 1.9.0 since it is not introduced
  47. back to torch 1.7.1. We included this for maintaining more
  48. backward torch versions.
  49. Args:
  50. target: The fully-qualified string name of the submodule
  51. to look for. (See above example for how to specify a
  52. fully-qualified string.)
  53. Returns:
  54. Module: The submodule referenced by ``target``
  55. Raises:
  56. AttributeError: If the target string references an invalid
  57. path or resolves to something that is not an
  58. ``Module``
  59. """
  60. if len(target) == 0:
  61. return self
  62. atoms: List[str] = target.split(".")
  63. mod = self
  64. for item in atoms:
  65. if not hasattr(mod, item):
  66. raise AttributeError(mod._get_name() + " has no attribute `" + item + "`")
  67. mod = getattr(mod, item)
  68. if not isinstance(mod, Module):
  69. raise AttributeError("`" + item + "` is not an Module")
  70. return mod
  71. def clear_state(self) -> None:
  72. """Reset self._params state to None."""
  73. self._params = None
  74. # TODO: Implement this for all submodules.
  75. def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
  76. raise NotImplementedError
  77. def get_children_by_indices(self, indices: Tensor) -> Iterator[Tuple[str, Module]]:
  78. modules = list(self.named_children())
  79. for idx in indices:
  80. yield modules[idx]
  81. def get_children_by_params(self, params: List[ParamItem]) -> Iterator[Tuple[str, Module]]:
  82. modules = list(self.named_children())
  83. # TODO: Wrong params passed here when nested ImageSequential
  84. for param in params:
  85. yield modules[list(dict(self.named_children()).keys()).index(param.name)]
  86. def get_params_by_module(self, named_modules: Iterator[Tuple[str, Module]]) -> Iterator[ParamItem]:
  87. # This will not take module._params
  88. for name, _ in named_modules:
  89. yield ParamItem(name, None)
  90. class SequentialBase(BasicSequentialBase):
  91. r"""SequentialBase for creating kornia modulized processing pipeline.
  92. Args:
  93. *args : a list of kornia augmentation and image operation modules.
  94. same_on_batch: apply the same transformation across the batch.
  95. If None, it will not overwrite the function-wise settings.
  96. return_transform: if ``True`` return the matrix describing the transformation
  97. applied to each. If None, it will not overwrite the function-wise settings.
  98. keepdim: whether to keep the output shape the same as input (True) or broadcast it
  99. to the batch form (False). If None, it will not overwrite the function-wise settings.
  100. """
  101. def __init__(self, *args: Module, same_on_batch: Optional[bool] = None, keepdim: Optional[bool] = None) -> None:
  102. # To name the modules properly
  103. super().__init__(*args)
  104. self._same_on_batch = same_on_batch
  105. self._keepdim = keepdim
  106. self.update_attribute(same_on_batch, keepdim=keepdim)
  107. def update_attribute(
  108. self,
  109. same_on_batch: Optional[bool] = None,
  110. return_transform: Optional[bool] = None,
  111. keepdim: Optional[bool] = None,
  112. ) -> None:
  113. for mod in self.children():
  114. # MixAugmentation does not have return transform
  115. if isinstance(mod, (_AugmentationBase, K.MixAugmentationBaseV2)):
  116. if same_on_batch is not None:
  117. mod.same_on_batch = same_on_batch
  118. if keepdim is not None:
  119. mod.keepdim = keepdim
  120. if isinstance(mod, SequentialBase):
  121. mod.update_attribute(same_on_batch, return_transform, keepdim)
  122. @property
  123. def same_on_batch(self) -> Optional[bool]:
  124. return self._same_on_batch
  125. @same_on_batch.setter
  126. def same_on_batch(self, same_on_batch: Optional[bool]) -> None:
  127. self._same_on_batch = same_on_batch
  128. self.update_attribute(same_on_batch=same_on_batch)
  129. @property
  130. def keepdim(self) -> Optional[bool]:
  131. return self._keepdim
  132. @keepdim.setter
  133. def keepdim(self, keepdim: Optional[bool]) -> None:
  134. self._keepdim = keepdim
  135. self.update_attribute(keepdim=keepdim)
  136. def autofill_dim(self, input: Tensor, dim_range: Tuple[int, int] = (2, 4)) -> Tuple[torch.Size, torch.Size]:
  137. """Fill tensor dim to the upper bound of dim_range.
  138. If input tensor dim is smaller than the lower bound of dim_range, an error will be thrown out.
  139. """
  140. ori_shape = input.shape
  141. if len(ori_shape) < dim_range[0] or len(ori_shape) > dim_range[1]:
  142. raise RuntimeError(f"input shape expected to be in {dim_range} while got {ori_shape}.")
  143. while len(input.shape) < dim_range[1]:
  144. input = input[None]
  145. return ori_shape, input.shape
  146. class ImageSequentialBase(SequentialBase):
  147. def identity_matrix(self, input: Tensor) -> Tensor:
  148. """Return identity matrix."""
  149. raise NotImplementedError
  150. def get_transformation_matrix(
  151. self,
  152. input: Tensor,
  153. params: Optional[List[ParamItem]] = None,
  154. recompute: bool = False,
  155. extra_args: Optional[Dict[str, Any]] = None,
  156. ) -> Optional[Tensor]:
  157. """Compute the transformation matrix according to the provided parameters.
  158. Args:
  159. input: the input tensor.
  160. params: params for the sequence.
  161. recompute: if to recompute the transformation matrix according to the params.
  162. default: False.
  163. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  164. """
  165. raise NotImplementedError
  166. def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
  167. raise NotImplementedError
  168. def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iterator[Tuple[str, Module]]:
  169. """Get module sequence by input params."""
  170. raise NotImplementedError
  171. def transform_inputs(
  172. self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  173. ) -> Tensor:
  174. for param in params:
  175. module = self.get_submodule(param.name)
  176. input = InputSequentialOps.transform(input, module=module, param=param, extra_args=extra_args)
  177. return input
  178. def inverse_inputs(
  179. self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  180. ) -> Tensor:
  181. for (_, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]):
  182. input = InputSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args)
  183. return input
  184. def transform_masks(
  185. self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  186. ) -> Tensor:
  187. for param in params:
  188. module = self.get_submodule(param.name)
  189. input = MaskSequentialOps.transform(input, module=module, param=param, extra_args=extra_args)
  190. return input
  191. def inverse_masks(
  192. self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  193. ) -> Tensor:
  194. for (_, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]):
  195. input = MaskSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args)
  196. return input
  197. def transform_boxes(
  198. self, input: Boxes, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  199. ) -> Boxes:
  200. for param in params:
  201. module = self.get_submodule(param.name)
  202. input = BoxSequentialOps.transform(input, module=module, param=param, extra_args=extra_args)
  203. return input
  204. def inverse_boxes(
  205. self, input: Boxes, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  206. ) -> Boxes:
  207. for (_, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]):
  208. input = BoxSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args)
  209. return input
  210. def transform_keypoints(
  211. self, input: Keypoints, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  212. ) -> Keypoints:
  213. for param in params:
  214. module = self.get_submodule(param.name)
  215. input = KeypointSequentialOps.transform(input, module=module, param=param, extra_args=extra_args)
  216. return input
  217. def inverse_keypoints(
  218. self, input: Keypoints, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  219. ) -> Keypoints:
  220. for (_, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]):
  221. input = KeypointSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args)
  222. return input
  223. def inverse(
  224. self, input: Tensor, params: Optional[List[ParamItem]] = None, extra_args: Optional[Dict[str, Any]] = None
  225. ) -> Tensor:
  226. """Inverse transformation.
  227. Used to inverse a tensor according to the performed transformation by a forward pass, or with respect to
  228. provided parameters.
  229. """
  230. if params is None:
  231. if self._params is None:
  232. raise ValueError(
  233. "No parameters available for inversing, please run a forward pass first "
  234. "or passing valid params into this function."
  235. )
  236. params = self._params
  237. input = self.inverse_inputs(input, params, extra_args=extra_args)
  238. return input
  239. def forward(
  240. self, input: Tensor, params: Optional[List[ParamItem]] = None, extra_args: Optional[Dict[str, Any]] = None
  241. ) -> Tensor:
  242. self.clear_state()
  243. if params is None:
  244. inp = input
  245. _, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
  246. params = self.forward_parameters(out_shape)
  247. input = self.transform_inputs(input, params=params, extra_args=extra_args)
  248. self._params = params
  249. return input
  250. class TransformMatrixMinIn:
  251. """Enables computation matrix computation."""
  252. _valid_ops_for_transform_computation: Tuple[Any, ...] = ()
  253. _transformation_matrix_arg: str = "silent"
  254. def __init__(self, *args, **kwargs) -> None: # type:ignore
  255. super().__init__(*args, **kwargs)
  256. self._transform_matrix: Optional[Tensor] = None
  257. self._transform_matrices: List[Optional[Tensor]] = []
  258. def _parse_transformation_matrix_mode(self, transformation_matrix_mode: str) -> None:
  259. _valid_transformation_matrix_args = {"silence", "silent", "rigid", "skip"}
  260. if transformation_matrix_mode not in _valid_transformation_matrix_args:
  261. raise ValueError(
  262. f"`transformation_matrix` has to be one of {_valid_transformation_matrix_args}. "
  263. f"Got {transformation_matrix_mode}."
  264. )
  265. self._transformation_matrix_arg = transformation_matrix_mode
  266. @property
  267. def transform_matrix(self) -> Optional[Tensor]:
  268. # In AugmentationSequential, the parent class is accessed first.
  269. # So that it was None in the beginning. We hereby use lazy computation here.
  270. if self._transform_matrix is None and len(self._transform_matrices) != 0:
  271. self._transform_matrix = self._transform_matrices[0]
  272. for mat in self._transform_matrices[1:]:
  273. self._update_transform_matrix(mat)
  274. return self._transform_matrix
  275. def _update_transform_matrix_for_valid_op(self, module: Module) -> None:
  276. raise NotImplementedError(module)
  277. def _update_transform_matrix_by_module(self, module: Module) -> None:
  278. if self._transformation_matrix_arg == "skip":
  279. return
  280. if isinstance(module, self._valid_ops_for_transform_computation):
  281. self._update_transform_matrix_for_valid_op(module)
  282. elif self._transformation_matrix_arg == "rigid":
  283. raise RuntimeError(
  284. f"Non-rigid module `{module}` is not supported under `rigid` computation mode. "
  285. "Please either update the module or change the `transformation_matrix` argument."
  286. )
  287. def _update_transform_matrix(self, transform_matrix: Optional[Tensor]) -> None:
  288. if self._transform_matrix is None:
  289. self._transform_matrix = transform_matrix
  290. else:
  291. self._transform_matrix = transform_matrix @ self._transform_matrix
  292. def _reset_transform_matrix_state(self) -> None:
  293. self._transform_matrix = None
  294. self._transform_matrices = []