image.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
  18. import torch
  19. import kornia.augmentation as K
  20. from kornia.augmentation.base import _AugmentationBase
  21. from kornia.augmentation.utils import override_parameters
  22. from kornia.core import ImageModule, Module, Tensor, as_tensor
  23. from kornia.core.module import ImageModuleMixIn
  24. from kornia.utils import eye_like
  25. from .base import ImageSequentialBase
  26. from .params import ParamItem
  27. __all__ = ["ImageSequential"]
  28. class ImageModuleForSequentialMixIn(ImageModuleMixIn):
  29. _disable_features: bool = False
  30. @property
  31. def disable_features(self) -> bool:
  32. return self._disable_features
  33. @disable_features.setter
  34. def disable_features(self, value: bool = True) -> None:
  35. self._disable_features = value
  36. def disable_item_features(self, *args: Module) -> None:
  37. for arg in args:
  38. if isinstance(arg, (ImageModule,)):
  39. arg.disable_features = True
  40. class ImageSequential(ImageSequentialBase, ImageModuleForSequentialMixIn):
  41. r"""Sequential for creating kornia image processing pipeline.
  42. Args:
  43. *args : a list of kornia augmentation and image operation modules.
  44. same_on_batch: apply the same transformation across the batch.
  45. If None, it will not overwrite the function-wise settings.
  46. keepdim: whether to keep the output shape the same as input (True) or broadcast it
  47. to the batch form (False). If None, it will not overwrite the function-wise settings.
  48. random_apply: randomly select a sublist (order agnostic) of args to
  49. apply transformation. The selection probability aligns to the ``random_apply_weights``.
  50. If int, a fixed number of transformations will be selected.
  51. If (a,), x number of transformations (a <= x <= len(args)) will be selected.
  52. If (a, b), x number of transformations (a <= x <= b) will be selected.
  53. If True, the whole list of args will be processed as a sequence in a random order.
  54. If False, the whole list of args will be processed as a sequence in original order.
  55. random_apply_weights: a list of selection weights for each operation. The length shall be as
  56. same as the number of operations. By default, operations are sampled uniformly.
  57. .. note::
  58. Transformation matrix returned only considers the transformation applied in ``kornia.augmentation`` module.
  59. Those transformations in ``kornia.geometry`` will not be taken into account.
  60. Examples:
  61. >>> _ = torch.manual_seed(77)
  62. >>> import kornia
  63. >>> input = torch.randn(2, 3, 5, 6)
  64. >>> aug_list = ImageSequential(
  65. ... kornia.color.BgrToRgb(),
  66. ... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  67. ... kornia.filters.MedianBlur((3, 3)),
  68. ... kornia.augmentation.RandomAffine(360, p=1.0),
  69. ... kornia.enhance.Invert(),
  70. ... kornia.augmentation.RandomMixUpV2(p=1.0),
  71. ... same_on_batch=True,
  72. ... random_apply=10,
  73. ... )
  74. >>> out = aug_list(input)
  75. >>> out.shape
  76. torch.Size([2, 3, 5, 6])
  77. Reproduce with provided params.
  78. >>> out2 = aug_list(input, params=aug_list._params)
  79. >>> torch.equal(out, out2)
  80. True
  81. Perform ``OneOf`` transformation with ``random_apply=1`` and ``random_apply_weights`` in ``ImageSequential``.
  82. >>> import kornia
  83. >>> input = torch.randn(2, 3, 5, 6)
  84. >>> aug_list = ImageSequential(
  85. ... kornia.color.BgrToRgb(),
  86. ... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  87. ... kornia.filters.MedianBlur((3, 3)),
  88. ... kornia.augmentation.RandomAffine(360, p=1.0),
  89. ... random_apply=1,
  90. ... random_apply_weights=[0.5, 0.3, 0.2, 0.5]
  91. ... )
  92. >>> out= aug_list(input)
  93. >>> out.shape
  94. torch.Size([2, 3, 5, 6])
  95. """
  96. def __init__(
  97. self,
  98. *args: Module,
  99. same_on_batch: Optional[bool] = None,
  100. keepdim: Optional[bool] = None,
  101. random_apply: Union[int, bool, Tuple[int, int]] = False,
  102. random_apply_weights: Optional[List[float]] = None,
  103. if_unsupported_ops: str = "raise",
  104. disable_item_features: bool = True,
  105. disable_sequential_features: bool = False,
  106. ) -> None:
  107. if disable_item_features:
  108. self.disable_item_features(*args)
  109. if disable_sequential_features:
  110. self.disable_features = True
  111. super().__init__(*args, same_on_batch=same_on_batch, keepdim=keepdim)
  112. self.random_apply = self._read_random_apply(random_apply, len(args))
  113. if random_apply_weights is not None and len(random_apply_weights) != len(self):
  114. raise ValueError(
  115. "The length of `random_apply_weights` must be as same as the number of operations."
  116. f"Got {len(random_apply_weights)} and {len(self)}."
  117. )
  118. self.random_apply_weights = as_tensor(random_apply_weights or torch.ones((len(self),)))
  119. self.if_unsupported_ops = if_unsupported_ops
  120. def _read_random_apply(
  121. self, random_apply: Union[int, bool, Tuple[int, int]], max_length: int
  122. ) -> Union[Tuple[int, int], bool]:
  123. """Process the scenarios for random apply."""
  124. if isinstance(random_apply, (bool,)) and random_apply is False:
  125. random_apply = False
  126. elif isinstance(random_apply, (bool,)) and random_apply is True:
  127. random_apply = (max_length, max_length + 1)
  128. elif isinstance(random_apply, (int,)):
  129. random_apply = (random_apply, random_apply + 1)
  130. elif (
  131. isinstance(random_apply, (tuple,))
  132. and len(random_apply) == 2
  133. and isinstance(random_apply[0], (int,))
  134. and isinstance(random_apply[1], (int,))
  135. ):
  136. random_apply = (random_apply[0], random_apply[1] + 1)
  137. elif isinstance(random_apply, (tuple,)) and len(random_apply) == 1 and isinstance(random_apply[0], (int,)):
  138. random_apply = (random_apply[0], max_length + 1)
  139. else:
  140. raise ValueError(f"Non-readable random_apply. Got {random_apply}.")
  141. if random_apply is not False and not (
  142. isinstance(random_apply, (tuple,))
  143. and len(random_apply) == 2
  144. and isinstance(random_apply[0], (int,))
  145. and isinstance(random_apply[0], (int,))
  146. ):
  147. raise AssertionError(f"Expect a tuple of (int, int). Got {random_apply}.")
  148. return random_apply
  149. def get_random_forward_sequence(self, with_mix: bool = True) -> Tuple[Iterator[Tuple[str, Module]], bool]:
  150. """Get a forward sequence when random apply is in need.
  151. Args:
  152. with_mix: if to require a mix augmentation for the sequence.
  153. Note:
  154. Mix augmentations (e.g. RandomMixUp) will be only applied once even in a random forward.
  155. """
  156. if isinstance(self.random_apply, tuple):
  157. num_samples = int(torch.randint(*self.random_apply, (1,)).item())
  158. else:
  159. raise TypeError(f"random apply should be a tuple. Gotcha {type(self.random_apply)}")
  160. multinomial_weights = self.random_apply_weights.clone()
  161. # Mix augmentation can only be applied once per forward
  162. mix_indices = self.get_mix_augmentation_indices(self.named_children())
  163. # kick out the mix augmentations
  164. multinomial_weights[mix_indices] = 0
  165. indices = torch.multinomial(
  166. multinomial_weights,
  167. num_samples,
  168. # enable replacement if non-mix augmentation is less than required
  169. replacement=num_samples > multinomial_weights.sum().item(),
  170. )
  171. mix_added = False
  172. if with_mix and len(mix_indices) != 0:
  173. # Make the selection fair.
  174. if (torch.rand(1) < ((len(mix_indices) + len(indices)) / len(self))).item():
  175. indices[-1] = torch.multinomial((~multinomial_weights.bool()).float(), 1)
  176. indices = indices[torch.randperm(len(indices))]
  177. mix_added = True
  178. return self.get_children_by_indices(indices), mix_added
  179. def get_mix_augmentation_indices(self, named_modules: Iterator[Tuple[str, Module]]) -> List[int]:
  180. """Get all the mix augmentations since they are label-involved.
  181. Special operations needed for label-involved augmentations.
  182. """
  183. # NOTE: MixV2 will not be a special op in the future.
  184. return [idx for idx, (_, child) in enumerate(named_modules) if isinstance(child, K.MixAugmentationBaseV2)]
  185. def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iterator[Tuple[str, Module]]:
  186. if params is None:
  187. # Mix augmentation can only be applied once per forward
  188. mix_indices = self.get_mix_augmentation_indices(self.named_children())
  189. if self.random_apply:
  190. return self.get_random_forward_sequence()[0]
  191. if len(mix_indices) > 1:
  192. raise ValueError(
  193. "Multiple mix augmentation is prohibited without enabling random_apply."
  194. f"Detected {len(mix_indices)} mix augmentations."
  195. )
  196. return self.named_children()
  197. return self.get_children_by_params(params)
  198. def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
  199. named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence()
  200. params: List[ParamItem] = []
  201. mod_param: Union[Dict[str, Tensor], List[ParamItem]]
  202. for name, module in named_modules:
  203. if isinstance(module, (_AugmentationBase, K.MixAugmentationBaseV2, ImageSequentialBase)):
  204. mod_param = module.forward_parameters(batch_shape)
  205. param = ParamItem(name, mod_param)
  206. else:
  207. param = ParamItem(name, None)
  208. batch_shape = _get_new_batch_shape(param, batch_shape)
  209. params.append(param)
  210. return params
  211. def identity_matrix(self, input: Tensor) -> Tensor:
  212. """Return identity matrix."""
  213. return eye_like(3, input)
  214. def get_transformation_matrix(
  215. self,
  216. input: Tensor,
  217. params: Optional[List[ParamItem]] = None,
  218. recompute: bool = False,
  219. extra_args: Optional[Dict[str, Any]] = None,
  220. ) -> Optional[Tensor]:
  221. """Compute the transformation matrix according to the provided parameters.
  222. Args:
  223. input: the input tensor.
  224. params: params for the sequence.
  225. recompute: if to recompute the transformation matrix according to the params.
  226. default: False.
  227. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  228. """
  229. if params is None:
  230. raise NotImplementedError("requires params to be provided.")
  231. if extra_args is None:
  232. extra_args = {}
  233. named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence(params)
  234. # Define as 1 for broadcasting
  235. res_mat: Optional[Tensor] = None
  236. for (_, module), param in zip(named_modules, params if params is not None else []):
  237. if isinstance(module, (K.GeometricAugmentationBase2D,)) and isinstance(param.data, dict):
  238. ori_shape = input.shape
  239. try:
  240. input = module.transform_tensor(input)
  241. except ValueError:
  242. # Ignore error for 5-dim video
  243. pass
  244. # Standardize shape
  245. if recompute:
  246. flags = override_parameters(module.flags, extra_args, in_place=False)
  247. mat = module.generate_transformation_matrix(input, param.data, flags)
  248. elif module._transform_matrix is not None:
  249. mat = as_tensor(module._transform_matrix, device=input.device, dtype=input.dtype)
  250. else:
  251. raise RuntimeError(f"{module}._transform_matrix is None while `recompute=False`.")
  252. res_mat = mat if res_mat is None else mat @ res_mat
  253. input = module.transform_output_tensor(input, ori_shape)
  254. if module.keepdim and ori_shape != input.shape:
  255. res_mat = res_mat.squeeze()
  256. elif isinstance(module, (ImageSequentialBase,)):
  257. # If not augmentationSequential
  258. if isinstance(module, (K.AugmentationSequential,)) and not recompute:
  259. mat = as_tensor(module._transform_matrix, device=input.device, dtype=input.dtype)
  260. else:
  261. maybe_param_data = cast(Optional[List[ParamItem]], param.data)
  262. _mat = module.get_transformation_matrix(
  263. input, maybe_param_data, recompute=recompute, extra_args=extra_args
  264. )
  265. mat = module.identity_matrix(input) if _mat is None else _mat
  266. res_mat = mat if res_mat is None else mat @ res_mat
  267. return res_mat
  268. # TODO: Make this as a class property to avoid running every time.
  269. def is_intensity_only(self, strict: bool = True) -> bool:
  270. """Check if all transformations are intensity-based.
  271. Args:
  272. strict: if strict is False, it will allow non-augmentation Modules to be passed.
  273. e.g. `kornia.enhance.AdjustBrightness` will be recognized as non-intensity module
  274. if strict is set to True.
  275. Note: patch processing would break the continuity of labels (e.g. bbounding boxes, masks).
  276. """
  277. for arg in self.children():
  278. if isinstance(arg, (ImageSequential,)) and not arg.is_intensity_only(strict):
  279. return False
  280. elif isinstance(arg, (ImageSequential,)):
  281. pass
  282. elif isinstance(arg, K.IntensityAugmentationBase2D):
  283. pass
  284. elif strict:
  285. # disallow non-registered ops if in strict mode
  286. # TODO: add an ops register module
  287. return False
  288. return True
  289. def __call__(
  290. self,
  291. *inputs: Any,
  292. input_names_to_handle: Optional[List[Any]] = None,
  293. output_type: str = "tensor",
  294. **kwargs: Any,
  295. ) -> Any:
  296. """Overwrite the __call__ function to handle various inputs.
  297. Args:
  298. inputs: Inputs to operate on.
  299. input_names_to_handle: List of input names to convert, if None, handle all inputs.
  300. output_type: Desired output type ('tensor', 'numpy', or 'pil').
  301. kwargs: Additional arguments.
  302. Returns:
  303. Callable: Decorated function with converted input and output types.
  304. """
  305. # Wrap the forward method with the decorator
  306. if not self._disable_features:
  307. decorated_forward = self.convert_input_output(
  308. input_names_to_handle=input_names_to_handle, output_type=output_type
  309. )(super().__call__)
  310. _output_image = decorated_forward(*inputs, **kwargs)
  311. if output_type == "tensor":
  312. self._output_image = self._detach_tensor_to_cpu(_output_image)
  313. else:
  314. self._output_image = _output_image
  315. else:
  316. _output_image = super().__call__(*inputs, **kwargs)
  317. return _output_image
  318. def _get_new_batch_shape(param: ParamItem, batch_shape: torch.Size) -> torch.Size:
  319. """Get the new batch shape if the augmentation changes the image size.
  320. Note:
  321. Augmentations that change the image size must provide the parameter `output_size`.
  322. """
  323. data = param.data
  324. if data is None:
  325. return batch_shape
  326. # If data is a list, process all subitems (exit early if all subitems are None)
  327. if isinstance(data, list):
  328. for p in data:
  329. batch_shape = _get_new_batch_shape(p, batch_shape)
  330. return batch_shape
  331. # Carefully avoid evaluating expression multiple times; batch_prob is often a 1-element tensor
  332. if "output_size" in data:
  333. # Inline check for common PyTorch float tensor case
  334. batch_prob = data.get("batch_prob", None)
  335. if batch_prob is not None:
  336. # Avoid repeated indexing, always fetch scalar efficiently
  337. prob = batch_prob.item() if batch_prob.numel() == 1 else batch_prob[0].item()
  338. if prob <= 0.5:
  339. return batch_shape
  340. else:
  341. # batch_prob missing, fallback do not update shape
  342. return batch_shape
  343. # Mutate only last two dims
  344. new_batch_shape = list(batch_shape)
  345. new_batch_shape[-2:] = data["output_size"][0]
  346. return torch.Size(new_batch_shape)
  347. return batch_shape