augment.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import warnings
  18. from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
  19. import torch
  20. from kornia.augmentation._2d.base import RigidAffineAugmentationBase2D
  21. from kornia.augmentation._3d.base import AugmentationBase3D, RigidAffineAugmentationBase3D
  22. from kornia.augmentation.base import _AugmentationBase
  23. from kornia.constants import DataKey, Resample
  24. from kornia.core import Module, Tensor
  25. from kornia.geometry.boxes import Boxes, VideoBoxes
  26. from kornia.geometry.keypoints import Keypoints, VideoKeypoints
  27. from kornia.utils import eye_like, is_autocast_enabled
  28. from .base import TransformMatrixMinIn
  29. from .image import ImageSequential
  30. from .ops import AugmentationSequentialOps, DataType
  31. from .params import ParamItem
  32. from .patch import PatchSequential
  33. from .video import VideoSequential
# Public API of this module.
__all__ = ["AugmentationSequential"]

# Groups of data keys; the pre-/post-processing helpers of `AugmentationSequential`
# branch on membership in these sets to decide how each positional input is handled.
_BOXES_OPTIONS = {DataKey.BBOX, DataKey.BBOX_XYXY, DataKey.BBOX_XYWH}
_KEYPOINTS_OPTIONS = {DataKey.KEYPOINTS}
_IMG_OPTIONS = {DataKey.INPUT, DataKey.IMAGE}
_MSK_OPTIONS = {DataKey.MASK}
_CLS_OPTIONS = {DataKey.CLASS, DataKey.LABEL}

# A mask input is either a single tensor or a list of tensors.
MaskDataType = Union[Tensor, List[Tensor]]
  41. class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
  42. r"""AugmentationSequential for handling multiple input types like inputs, masks, keypoints at once.
  43. .. image:: _static/img/AugmentationSequential.png
  44. Args:
  45. *args: a list of kornia augmentation modules.
  46. data_keys: the input type sequential for applying augmentations. Accepts "input", "image", "mask",
  47. "bbox", "bbox_xyxy", "bbox_xywh", "keypoints", "class", "label".
  48. same_on_batch: apply the same transformation across the batch. If None, it will not overwrite the function-wise
  49. settings.
  50. keepdim: whether to keep the output shape the same as input (True) or broadcast it to the batch form (False).
  51. If None, it will not overwrite the function-wise settings.
  52. random_apply: randomly select a sublist (order agnostic) of args to apply transformation.
  53. If int, a fixed number of transformations will be selected.
  54. If (a,), x number of transformations (a <= x <= len(args)) will be selected.
  55. If (a, b), x number of transformations (a <= x <= b) will be selected.
  56. If True, the whole list of args will be processed as a sequence in a random order.
  57. If False, the whole list of args will be processed as a sequence in original order.
  58. transformation_matrix_mode: computation mode for the chained transformation matrix, via `.transform_matrix`
  59. attribute.
  60. If `silent`, transformation matrix will be computed silently and the non-rigid
  61. modules will be ignored as identity transformations.
  62. If `rigid`, transformation matrix will be computed silently and the non-rigid
  63. modules will trigger errors.
  64. If `skip`, transformation matrix will be totally ignored.
  65. extra_args: to control the behaviour for each datakeys. By default, masks are handled by nearest interpolation
  66. strategies.
  67. .. note::
  68. Mix augmentations (e.g. RandomMixUp, RandomCutMix) only work with the "input"/"image" data keys.
  69. It is not clear how to deal with the conversions of masks, bounding boxes and keypoints.
  70. .. note::
  71. See a working example `here <https://kornia.github.io/tutorials/nbs/data_augmentation_sequential.html>`__.
  72. Examples:
  73. >>> import kornia
  74. >>> input = torch.randn(2, 3, 5, 6)
  75. >>> mask = torch.ones(2, 3, 5, 6)
  76. >>> bbox = torch.tensor([[
  77. ... [1., 1.],
  78. ... [2., 1.],
  79. ... [2., 2.],
  80. ... [1., 2.],
  81. ... ]]).expand(2, 1, -1, -1)
  82. >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)
  83. >>> aug_list = AugmentationSequential(
  84. ... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  85. ... kornia.augmentation.RandomAffine(360, p=1.0),
  86. ... data_keys=["input", "mask", "bbox", "keypoints"],
  87. ... same_on_batch=False,
  88. ... random_apply=10,
  89. ... )
  90. >>> out = aug_list(input, mask, bbox, points)
  91. >>> [o.shape for o in out]
  92. [torch.Size([2, 3, 5, 6]), torch.Size([2, 3, 5, 6]), torch.Size([2, 1, 4, 2]), torch.Size([2, 1, 2])]
  93. >>> # apply the exact augmentation again.
  94. >>> out_rep = aug_list(input, mask, bbox, points, params=aug_list._params)
  95. >>> [(o == o_rep).all() for o, o_rep in zip(out, out_rep)]
  96. [tensor(True), tensor(True), tensor(True), tensor(True)]
  97. >>> # inverse the augmentations
  98. >>> out_inv = aug_list.inverse(*out)
  99. >>> [o.shape for o in out_inv]
  100. [torch.Size([2, 3, 5, 6]), torch.Size([2, 3, 5, 6]), torch.Size([2, 1, 4, 2]), torch.Size([2, 1, 2])]
  101. This example demonstrates the integration of VideoSequential and AugmentationSequential.
  102. >>> import kornia
  103. >>> input = torch.randn(2, 3, 5, 6)[None]
  104. >>> mask = torch.ones(2, 3, 5, 6)[None]
  105. >>> bbox = torch.tensor([[
  106. ... [1., 1.],
  107. ... [2., 1.],
  108. ... [2., 2.],
  109. ... [1., 2.],
  110. ... ]]).expand(2, 1, -1, -1)[None]
  111. >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)[None]
  112. >>> aug_list = AugmentationSequential(
  113. ... VideoSequential(
  114. ... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  115. ... kornia.augmentation.RandomAffine(360, p=1.0),
  116. ... ),
  117. ... data_keys=["input", "mask", "bbox", "keypoints"]
  118. ... )
  119. >>> out = aug_list(input, mask, bbox, points)
  120. >>> [o.shape for o in out] # doctest: +ELLIPSIS
  121. [torch.Size([1, 2, 3, 5, 6]), torch.Size([1, 2, 3, 5, 6]), ...([1, 2, 1, 4, 2]), torch.Size([1, 2, 1, 2])]
  122. Perform ``OneOf`` transformation with ``random_apply=1`` and ``random_apply_weights``
  123. in ``AugmentationSequential``.
  124. >>> import kornia
  125. >>> input = torch.randn(2, 3, 5, 6)[None]
  126. >>> mask = torch.ones(2, 3, 5, 6)[None]
  127. >>> bbox = torch.tensor([[
  128. ... [1., 1.],
  129. ... [2., 1.],
  130. ... [2., 2.],
  131. ... [1., 2.],
  132. ... ]]).expand(2, 1, -1, -1)[None]
  133. >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)[None]
  134. >>> aug_list = AugmentationSequential(
  135. ... VideoSequential(
  136. ... kornia.augmentation.RandomAffine(360, p=1.0),
  137. ... ),
  138. ... VideoSequential(
  139. ... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  140. ... ),
  141. ... data_keys=["input", "mask", "bbox", "keypoints"],
  142. ... random_apply=1,
  143. ... random_apply_weights=[0.5, 0.3]
  144. ... )
  145. >>> out = aug_list(input, mask, bbox, points)
  146. >>> [o.shape for o in out] # doctest: +ELLIPSIS
  147. [torch.Size([1, 2, 3, 5, 6]), torch.Size([1, 2, 3, 5, 6]), ...([1, 2, 1, 4, 2]), torch.Size([1, 2, 1, 2])]
  148. This example shows how to use a list of masks and boxes within AugmentationSequential
  149. >>> import kornia.augmentation as K
  150. >>> input = torch.randn(2, 3, 256, 256)
  151. >>> mask = [torch.ones(1, 3, 256, 256), torch.ones(1, 2, 256, 256)]
  152. >>> bbox = [
  153. ... torch.tensor([[28.0, 53.0, 143.0, 164.0], [254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]]),
  154. ... torch.tensor([[254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]])
  155. ... ]
  156. >>> bbox = [Boxes.from_tensor(i).data for i in bbox]
  157. >>> aug_list = K.AugmentationSequential(
  158. ... K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  159. ... K.RandomHorizontalFlip(p=1.0),
  160. ... K.ImageSequential(K.RandomHorizontalFlip(p=1.0)),
  161. ... K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0)),
  162. ... data_keys=["input", "mask", "bbox"],
  163. ... same_on_batch=False,
  164. ... random_apply=10,
  165. ... )
  166. >>> out = aug_list(input, mask, bbox)
  167. How to use a dictionary as input with AugmentationSequential? The dictionary keys that start with
  168. one of the available datakeys will be augmented accordingly. Otherwise, the dictionary item is passed
  169. without any augmentation.
  170. >>> import kornia.augmentation as K
  171. >>> img = torch.randn(1, 3, 256, 256)
  172. >>> mask = [torch.ones(1, 3, 256, 256), torch.ones(1, 2, 256, 256)]
  173. >>> bbox = [
  174. ... torch.tensor([[28.0, 53.0, 143.0, 164.0], [254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]]),
  175. ... torch.tensor([[254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]])
  176. ... ]
  177. >>> bbox = [Boxes.from_tensor(i).data for i in bbox]
  178. >>> aug_dict = K.AugmentationSequential(
  179. ... K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  180. ... K.RandomHorizontalFlip(p=1.0),
  181. ... K.ImageSequential(K.RandomHorizontalFlip(p=1.0)),
  182. ... K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0)),
  183. ... data_keys=None,
  184. ... same_on_batch=False,
  185. ... random_apply=10,
  186. ... )
  187. >>> data = {'image': img, 'mask': mask[0], 'mask-b': mask[1], 'bbox': bbox[0], 'bbox-other':bbox[1]}
  188. >>> out = aug_dict(data)
  189. >>> out.keys()
  190. dict_keys(['image', 'mask', 'mask-b', 'bbox', 'bbox-other'])
  191. """
    # dtype of the most recent image input recorded by `_arguments_preproc`;
    # `_preproc_mask` casts masks to it (falling back to float when unset).
    input_dtype = None
    # dtype of the most recent mask input recorded by `_arguments_preproc`;
    # `_postproc_mask` casts masks back to it (falling back to float when unset).
    mask_dtype = None
  194. def __init__(
  195. self,
  196. *args: Union[_AugmentationBase, ImageSequential],
  197. data_keys: Optional[Union[Sequence[str], Sequence[int], Sequence[DataKey]]] = (DataKey.INPUT,),
  198. same_on_batch: Optional[bool] = None,
  199. keepdim: Optional[bool] = None,
  200. random_apply: Union[int, bool, Tuple[int, int]] = False,
  201. random_apply_weights: Optional[List[float]] = None,
  202. transformation_matrix_mode: str = "silent",
  203. extra_args: Optional[Dict[DataKey, Dict[str, Any]]] = None,
  204. ) -> None:
  205. self._transform_matrix: Optional[Tensor]
  206. self._transform_matrices: List[Optional[Tensor]] = []
  207. super().__init__(
  208. *args,
  209. same_on_batch=same_on_batch,
  210. keepdim=keepdim,
  211. random_apply=random_apply,
  212. random_apply_weights=random_apply_weights,
  213. )
  214. self._parse_transformation_matrix_mode(transformation_matrix_mode)
  215. self._valid_ops_for_transform_computation: Tuple[Any, ...] = (
  216. RigidAffineAugmentationBase2D,
  217. RigidAffineAugmentationBase3D,
  218. AugmentationSequential,
  219. )
  220. self.data_keys: Optional[List[DataKey]]
  221. if data_keys is not None:
  222. self.data_keys = [DataKey.get(inp) for inp in data_keys]
  223. else:
  224. self.data_keys = data_keys
  225. if self.data_keys:
  226. if any(in_type not in DataKey for in_type in self.data_keys):
  227. raise AssertionError(f"`data_keys` must be in {DataKey}. Got {self.data_keys}.")
  228. if self.data_keys[0] != DataKey.INPUT:
  229. raise NotImplementedError(f"The first input must be {DataKey.INPUT}.")
  230. self.transform_op = AugmentationSequentialOps(self.data_keys)
  231. self.contains_video_sequential: bool = False
  232. self.contains_3d_augmentation: bool = False
  233. for arg in args:
  234. if isinstance(arg, PatchSequential) and not arg.is_intensity_only():
  235. warnings.warn(
  236. "Geometric transformation detected in PatchSeqeuntial, which would break bbox, mask.", stacklevel=1
  237. )
  238. if isinstance(arg, VideoSequential):
  239. self.contains_video_sequential = True
  240. # NOTE: only for images are supported for 3D.
  241. if isinstance(arg, AugmentationBase3D):
  242. self.contains_3d_augmentation = True
  243. self._transform_matrix = None
  244. self.extra_args = extra_args or {DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}}
  245. def clear_state(self) -> None:
  246. self._reset_transform_matrix_state()
  247. return super().clear_state()
  248. def _update_transform_matrix_for_valid_op(self, module: Module) -> None:
  249. self._transform_matrices.append(module.transform_matrix)
  250. def identity_matrix(self, input: Tensor) -> Tensor:
  251. """Return identity matrix."""
  252. if self.contains_3d_augmentation:
  253. return eye_like(4, input)
  254. else:
  255. return eye_like(3, input)
  256. def inverse( # type: ignore[override]
  257. self,
  258. *args: Union[DataType, Dict[str, DataType]],
  259. params: Optional[List[ParamItem]] = None,
  260. data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
  261. ) -> Union[DataType, List[DataType], Dict[str, DataType]]:
  262. """Reverse the transformation applied.
  263. Number of input tensors must align with the number of``data_keys``. If ``data_keys`` is not set, use
  264. ``self.data_keys`` by default.
  265. """
  266. original_keys = None
  267. if len(args) == 1 and isinstance(args[0], dict):
  268. original_keys, data_keys, args, invalid_data = self._preproc_dict_data(args[0])
  269. # args here should already be `DataType`
  270. # NOTE: how to right type to: unpacked args <-> tuple of args to unpack
  271. # issue with `self._preproc_dict_data` return args type
  272. self.transform_op.data_keys = self.transform_op.preproc_datakeys(data_keys)
  273. self._validate_args_datakeys(*args, data_keys=self.transform_op.data_keys) # type: ignore
  274. in_args = self._arguments_preproc(*args, data_keys=self.transform_op.data_keys) # type: ignore
  275. if params is None:
  276. if self._params is None:
  277. raise ValueError(
  278. "No parameters available for inversing, please run a forward pass first "
  279. "or passing valid params into this function."
  280. )
  281. params = self._params
  282. outputs: List[DataType] = in_args
  283. for param in params[::-1]:
  284. module = self.get_submodule(param.name)
  285. outputs = self.transform_op.inverse( # type: ignore
  286. *outputs, module=module, param=param, extra_args=self.extra_args
  287. )
  288. if not isinstance(outputs, (list, tuple)):
  289. # Make sure we are unpacking a list whilst post-proc
  290. outputs = [outputs]
  291. outputs = self._arguments_postproc(args, outputs, data_keys=self.transform_op.data_keys) # type: ignore
  292. if isinstance(original_keys, tuple):
  293. result = {k: v for v, k in zip(outputs, original_keys)}
  294. if invalid_data:
  295. result.update(invalid_data)
  296. return result
  297. if len(outputs) == 1 and isinstance(outputs, list):
  298. return outputs[0]
  299. return outputs
  300. def _validate_args_datakeys(self, *args: DataType, data_keys: List[DataKey]) -> None:
  301. if len(args) != len(data_keys):
  302. raise AssertionError(
  303. f"The number of inputs must align with the number of data_keys. Got {len(args)} and {len(data_keys)}."
  304. )
  305. # TODO: validate args batching, and its consistency
  306. def _arguments_preproc(self, *args: DataType, data_keys: List[DataKey]) -> List[DataType]:
  307. inp: List[DataType] = []
  308. for arg, dcate in zip(args, data_keys):
  309. if DataKey.get(dcate) in _IMG_OPTIONS:
  310. arg = cast(Tensor, arg)
  311. self.input_dtype = arg.dtype
  312. inp.append(arg)
  313. elif DataKey.get(dcate) in _MSK_OPTIONS:
  314. if isinstance(inp, list):
  315. arg = cast(List[Tensor], arg)
  316. self.mask_dtype = arg[0].dtype
  317. else:
  318. arg = cast(Tensor, arg)
  319. self.mask_dtype = arg.dtype
  320. inp.append(self._preproc_mask(arg))
  321. elif DataKey.get(dcate) in _KEYPOINTS_OPTIONS:
  322. inp.append(self._preproc_keypoints(arg, dcate))
  323. elif DataKey.get(dcate) in _BOXES_OPTIONS:
  324. inp.append(self._preproc_boxes(arg, dcate))
  325. elif DataKey.get(dcate) in _CLS_OPTIONS:
  326. inp.append(arg)
  327. else:
  328. raise NotImplementedError(f"input type of {dcate} is not implemented.")
  329. return inp
  330. def _arguments_postproc(
  331. self, in_args: List[DataType], out_args: List[DataType], data_keys: List[DataKey]
  332. ) -> List[DataType]:
  333. out: List[DataType] = []
  334. for in_arg, out_arg, dcate in zip(in_args, out_args, data_keys):
  335. if DataKey.get(dcate) in _IMG_OPTIONS:
  336. # It is tensor type already.
  337. out.append(out_arg)
  338. # TODO: may add the float to integer (for masks), etc.
  339. elif DataKey.get(dcate) in _MSK_OPTIONS:
  340. _out_m = self._postproc_mask(cast(MaskDataType, out_arg))
  341. out.append(_out_m)
  342. elif DataKey.get(dcate) in _KEYPOINTS_OPTIONS:
  343. _out_k = self._postproc_keypoint(in_arg, cast(Keypoints, out_arg), dcate)
  344. if is_autocast_enabled() and isinstance(in_arg, (Tensor, Keypoints)):
  345. if isinstance(_out_k, list):
  346. _out_k = [i.type(in_arg.dtype) for i in _out_k]
  347. else:
  348. _out_k = _out_k.type(in_arg.dtype)
  349. out.append(_out_k)
  350. elif DataKey.get(dcate) in _BOXES_OPTIONS:
  351. _out_b = self._postproc_boxes(in_arg, cast(Boxes, out_arg), dcate)
  352. if is_autocast_enabled() and isinstance(in_arg, (Tensor, Boxes)):
  353. if isinstance(_out_b, list):
  354. _out_b = [i.type(in_arg.dtype) for i in _out_b]
  355. else:
  356. _out_b = _out_b.type(in_arg.dtype)
  357. out.append(_out_b)
  358. elif DataKey.get(dcate) in _CLS_OPTIONS:
  359. out.append(out_arg)
  360. else:
  361. raise NotImplementedError(f"input type of {dcate} is not implemented.")
  362. return out
  363. def forward( # type: ignore[override]
  364. self,
  365. *args: Union[DataType, Dict[str, DataType]],
  366. params: Optional[List[ParamItem]] = None,
  367. data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
  368. ) -> Union[DataType, List[DataType], Dict[str, DataType]]:
  369. """Compute multiple tensors simultaneously according to ``self.data_keys``."""
  370. self.clear_state()
  371. # Unpack/handle dictionary args
  372. original_keys = None
  373. if len(args) == 1 and isinstance(args[0], dict):
  374. original_keys, data_keys, args, invalid_data = self._preproc_dict_data(args[0])
  375. self.transform_op.data_keys = self.transform_op.preproc_datakeys(data_keys)
  376. self._validate_args_datakeys(*args, data_keys=self.transform_op.data_keys) # type: ignore
  377. in_args = self._arguments_preproc(*args, data_keys=self.transform_op.data_keys) # type: ignore
  378. if params is None:
  379. # image data must exist if params is not provided.
  380. if DataKey.INPUT in self.transform_op.data_keys:
  381. inp = in_args[self.transform_op.data_keys.index(DataKey.INPUT)]
  382. if not isinstance(inp, (Tensor,)):
  383. raise ValueError(f"`INPUT` should be a tensor but `{type(inp)}` received.")
  384. # A video input shall be BCDHW while an image input shall be BCHW
  385. if self.contains_video_sequential or self.contains_3d_augmentation:
  386. _, out_shape = self.autofill_dim(inp, dim_range=(3, 5))
  387. else:
  388. _, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
  389. params = self.forward_parameters(out_shape)
  390. else:
  391. raise ValueError("`params` must be provided whilst INPUT is not in data_keys.")
  392. outputs: Union[Tensor, List[DataType]] = in_args
  393. for param in params:
  394. module = self.get_submodule(param.name)
  395. outputs = self.transform_op.transform( # type: ignore
  396. *outputs, module=module, param=param, extra_args=self.extra_args
  397. )
  398. if not isinstance(outputs, (list, tuple)):
  399. # Make sure we are unpacking a list whilst post-proc
  400. outputs = [outputs]
  401. self._update_transform_matrix_by_module(module)
  402. outputs = self._arguments_postproc(args, outputs, data_keys=self.transform_op.data_keys) # type: ignore
  403. # Restore it back
  404. self.transform_op.data_keys = self.data_keys
  405. self._params = params
  406. if isinstance(original_keys, tuple):
  407. result = {k: v for v, k in zip(outputs, original_keys)}
  408. if invalid_data:
  409. result.update(invalid_data)
  410. return result
  411. if len(outputs) == 1 and isinstance(outputs, list):
  412. return outputs[0]
  413. return outputs
    def __call__(
        self,
        *inputs: Any,
        input_names_to_handle: Optional[List[Any]] = None,
        output_type: str = "tensor",
        **kwargs: Any,
    ) -> Any:
        """Overwrite the __call__ function to handle various inputs.

        When ``self._disable_features`` is falsy, the parent ``__call__`` is wrapped with
        ``self.convert_input_output`` before being invoked, and the result is also stored on
        ``self._output_image``. Otherwise the parent ``__call__`` is invoked directly.

        Args:
            inputs: Inputs to operate on.
            input_names_to_handle: List of input names to convert, if None, handle all inputs.
            output_type: Desired output type ('tensor', 'numpy', or 'pil').
            kwargs: Additional arguments.

        Returns:
            The value returned by the (possibly decorated) parent ``__call__``.
        """
        # Wrap the forward method with the decorator
        if not self._disable_features:
            # TODO: Some more behaviour for AugmentationSequential needs to be revisited later
            # e.g. We convert only images, etc.
            decorated_forward = self.convert_input_output(
                input_names_to_handle=input_names_to_handle, output_type=output_type
            )(super(ImageSequential, self).__call__)
            _output_image = decorated_forward(*inputs, **kwargs)

            # Work out which data keys were in effect for this call, so the INPUT entry can be located.
            in_data_keys: Optional[List[DataKey]]
            if len(inputs) == 1 and isinstance(inputs[0], dict):
                # NOTE(review): this is a second pass over the same dictionary that `forward`
                # already processed (and stripped of unrecognised keys) - confirm this is intended.
                original_keys, in_data_keys, inputs, _invalid_data = self._preproc_dict_data(inputs[0])
            else:
                in_data_keys = kwargs.get("data_keys", self.data_keys)
            data_keys = self.transform_op.preproc_datakeys(in_data_keys)

            if len(data_keys) > 1 and DataKey.INPUT in data_keys:
                # NOTE: we may update it later for more supports of drawing boxes, etc.
                idx = data_keys.index(DataKey.INPUT)
                if output_type == "tensor":
                    self._output_image = _output_image
                    # `self._output_image` is `_output_image` here, so the item assignments below
                    # write an entry back onto the same object.
                    if isinstance(_output_image, dict):
                        self._output_image[original_keys[idx]] = _output_image[original_keys[idx]]
                    else:
                        self._output_image[idx] = _output_image[idx]
                # For other output types only the INPUT entry of the existing `self._output_image`
                # is overwritten. NOTE(review): assumes `self._output_image` was set by an earlier call - confirm.
                elif isinstance(_output_image, dict):
                    self._output_image[original_keys[idx]] = _output_image[original_keys[idx]]
                else:
                    self._output_image[idx] = _output_image[idx]
            else:
                self._output_image = _output_image
        else:
            _output_image = super(ImageSequential, self).__call__(*inputs, **kwargs)
        return _output_image
  462. def _preproc_dict_data(
  463. self, data: Dict[str, DataType]
  464. ) -> Tuple[Tuple[str, ...], List[DataKey], Tuple[DataType, ...], Optional[Dict[str, Any]]]:
  465. if self.data_keys is not None:
  466. raise ValueError("If you are using a dictionary as input, the data_keys should be None.")
  467. keys = tuple(data.keys())
  468. data_keys, invalid_keys = self._read_datakeys_from_dict(keys)
  469. invalid_data = {i: data.pop(i) for i in invalid_keys} if invalid_keys else None
  470. keys = tuple(k for k in keys if k not in invalid_keys) if invalid_keys else keys
  471. data_unpacked = tuple(data.values())
  472. return keys, data_keys, data_unpacked, invalid_data
  473. def _read_datakeys_from_dict(self, keys: Sequence[str]) -> Tuple[List[DataKey], Optional[List[str]]]:
  474. def retrieve_key(key: str) -> DataKey:
  475. """Try to retrieve the datakey value by matching `<datakey>*`."""
  476. # Alias cases, like INPUT, will not be get by the enum iterator.
  477. if key.upper().startswith("INPUT"):
  478. return DataKey.INPUT
  479. for dk in DataKey:
  480. if key.upper() in {"BBOX_XYXY", "BBOX_XYWH"}:
  481. return DataKey.get(key.upper())
  482. if key.upper().startswith(dk.name):
  483. return DataKey.get(dk.name)
  484. allowed_dk = " | ".join(f"`{d.name}`" for d in DataKey)
  485. raise ValueError(
  486. f"Your input data dictionary keys should start with some of datakey values: {allowed_dk}. Got `{key}`"
  487. )
  488. valid_data_keys = []
  489. invalid_keys = []
  490. for k in keys:
  491. try:
  492. valid_data_keys.append(DataKey.get(retrieve_key(k)))
  493. except ValueError:
  494. invalid_keys.append(k)
  495. return valid_data_keys, invalid_keys
  496. def _preproc_mask(self, arg: MaskDataType) -> MaskDataType:
  497. if isinstance(arg, list):
  498. new_arg = []
  499. for a in arg:
  500. a_new = a.to(self.input_dtype) if self.input_dtype else a.to(torch.float)
  501. new_arg.append(a_new)
  502. return new_arg
  503. else:
  504. arg = arg.to(self.input_dtype) if self.input_dtype else arg.to(torch.float)
  505. return arg
  506. def _postproc_mask(self, arg: MaskDataType) -> MaskDataType:
  507. if isinstance(arg, list):
  508. new_arg = []
  509. for a in arg:
  510. a_new = a.to(self.mask_dtype) if self.mask_dtype else a.to(torch.float)
  511. new_arg.append(a_new)
  512. return new_arg
  513. else:
  514. arg = arg.to(self.mask_dtype) if self.mask_dtype else arg.to(torch.float)
  515. return arg
  516. def _preproc_boxes(self, arg: DataType, dcate: DataKey) -> Boxes:
  517. if DataKey.get(dcate) in [DataKey.BBOX]:
  518. mode = "vertices_plus"
  519. elif DataKey.get(dcate) in [DataKey.BBOX_XYXY]:
  520. mode = "xyxy_plus"
  521. elif DataKey.get(dcate) in [DataKey.BBOX_XYWH]:
  522. mode = "xywh"
  523. else:
  524. raise ValueError(f"Unsupported mode `{DataKey.get(dcate).name}`.")
  525. if isinstance(arg, (Boxes,)):
  526. return arg
  527. elif self.contains_video_sequential:
  528. arg = cast(Tensor, arg)
  529. return VideoBoxes.from_tensor(arg)
  530. elif self.contains_3d_augmentation:
  531. raise NotImplementedError("3D box handlers are not yet supported.")
  532. else:
  533. arg = cast(Tensor, arg)
  534. return Boxes.from_tensor(arg, mode=mode)
  535. def _postproc_boxes(self, in_arg: DataType, out_arg: Boxes, dcate: DataKey) -> Union[Tensor, List[Tensor], Boxes]:
  536. if DataKey.get(dcate) in [DataKey.BBOX]:
  537. mode = "vertices_plus"
  538. elif DataKey.get(dcate) in [DataKey.BBOX_XYXY]:
  539. mode = "xyxy_plus"
  540. elif DataKey.get(dcate) in [DataKey.BBOX_XYWH]:
  541. mode = "xywh"
  542. else:
  543. raise ValueError(f"Unsupported mode `{DataKey.get(dcate).name}`.")
  544. # TODO: handle 3d scenarios
  545. if isinstance(in_arg, (Boxes,)):
  546. return out_arg
  547. else:
  548. return out_arg.to_tensor(mode=mode)
  549. def _preproc_keypoints(self, arg: DataType, dcate: DataKey) -> Keypoints:
  550. dtype = None
  551. if self.contains_video_sequential:
  552. arg = cast(Union[Tensor, List[Tensor]], arg)
  553. if isinstance(arg, list):
  554. if not torch.is_floating_point(arg[0]):
  555. dtype = arg[0].dtype
  556. arg = [a.float() for a in arg]
  557. elif not torch.is_floating_point(arg):
  558. dtype = arg.dtype
  559. arg = arg.float()
  560. video_result = VideoKeypoints.from_tensor(arg)
  561. return video_result.type(dtype) if dtype else video_result
  562. elif self.contains_3d_augmentation:
  563. raise NotImplementedError("3D keypoint handlers are not yet supported.")
  564. elif isinstance(arg, (Keypoints,)):
  565. return arg
  566. else:
  567. arg = cast(Tensor, arg)
  568. if not torch.is_floating_point(arg):
  569. dtype = arg.dtype
  570. arg = arg.float()
  571. # TODO: Add List[Tensor] in the future.
  572. result = Keypoints.from_tensor(arg)
  573. return result.type(dtype) if dtype else result
  574. def _postproc_keypoint(
  575. self, in_arg: DataType, out_arg: Keypoints, dcate: DataKey
  576. ) -> Union[Tensor, List[Tensor], Keypoints]:
  577. if isinstance(in_arg, (Keypoints,)):
  578. return out_arg
  579. else:
  580. return out_arg.to_tensor()