ops.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615
# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  17. import copy
  18. from abc import ABCMeta, abstractmethod
  19. from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union
  20. from typing_extensions import ParamSpec
  21. import kornia.augmentation as K
  22. from kornia.augmentation.base import _AugmentationBase
  23. from kornia.constants import DataKey
  24. from kornia.core import Module, Tensor
  25. from kornia.geometry.boxes import Boxes
  26. from kornia.geometry.keypoints import Keypoints
  27. from .params import ParamItem
  28. DataType = Union[Tensor, List[Tensor], Boxes, Keypoints]
  29. # NOTE: shouldn't this SequenceDataType alias be equals to List[DataType]?
  30. SequenceDataType = Union[List[Tensor], List[List[Tensor]], List[Boxes], List[Keypoints]]
  31. T = TypeVar("T")
  32. class SequentialOpsInterface(Generic[T], metaclass=ABCMeta):
  33. """Abstract interface for applying and inversing transformations."""
  34. @classmethod
  35. def get_instance_module_param(cls, param: ParamItem) -> Dict[str, Tensor]:
  36. if isinstance(param, ParamItem) and isinstance(param.data, dict):
  37. _params = param.data
  38. else:
  39. raise TypeError(f"Expected param (ParamItem.data) be a dictionary. Gotcha {param}.")
  40. return _params
  41. @classmethod
  42. def get_sequential_module_param(cls, param: ParamItem) -> List[ParamItem]:
  43. if isinstance(param, ParamItem) and isinstance(param.data, list):
  44. _params = param.data
  45. else:
  46. raise TypeError(f"Expected param (ParamItem.data) be a list. Gotcha {param}.")
  47. return _params
  48. @classmethod
  49. @abstractmethod
  50. def transform(cls, input: T, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None) -> T:
  51. """Apply a transformation with respect to the parameters.
  52. Args:
  53. input: the input tensor.
  54. module: any torch Module but only kornia augmentation modules will count
  55. to apply transformations.
  56. param: the corresponding parameters to the module.
  57. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  58. """
  59. raise NotImplementedError
  60. @classmethod
  61. @abstractmethod
  62. def inverse(cls, input: T, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None) -> T:
  63. """Inverse a transformation with respect to the parameters.
  64. Args:
  65. input: the input tensor.
  66. module: any torch Module but only kornia augmentation modules will count
  67. to apply transformations.
  68. param: the corresponding parameters to the module.
  69. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  70. """
  71. raise NotImplementedError
  72. class AugmentationSequentialOps:
  73. def __init__(self, data_keys: Optional[List[DataKey]]) -> None:
  74. self._data_keys = data_keys
  75. @property
  76. def data_keys(self) -> Optional[List[DataKey]]:
  77. return self._data_keys
  78. @data_keys.setter
  79. def data_keys(self, data_keys: Optional[Union[List[DataKey], List[str], List[int]]]) -> None:
  80. if data_keys:
  81. self._data_keys = [DataKey.get(inp) for inp in data_keys]
  82. else:
  83. self._data_keys = None
  84. def preproc_datakeys(self, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None) -> List[DataKey]:
  85. if data_keys is None:
  86. if isinstance(self.data_keys, list):
  87. return self.data_keys
  88. raise ValueError("Sequential ops needs data keys to be able to process.")
  89. else:
  90. return [DataKey.get(inp) for inp in data_keys]
  91. def _get_op(self, data_key: DataKey) -> Type[SequentialOpsInterface[Any]]:
  92. """Return the corresponding operation given a data key."""
  93. if data_key == DataKey.INPUT:
  94. return InputSequentialOps
  95. if data_key == DataKey.MASK:
  96. return MaskSequentialOps
  97. if data_key in {DataKey.BBOX, DataKey.BBOX_XYWH, DataKey.BBOX_XYXY}:
  98. return BoxSequentialOps
  99. if data_key == DataKey.KEYPOINTS:
  100. return KeypointSequentialOps
  101. if data_key == DataKey.CLASS:
  102. return ClassSequentialOps
  103. raise RuntimeError(f"Operation for `{data_key.name}` is not found.")
  104. def transform(
  105. self,
  106. *arg: DataType,
  107. module: Module,
  108. param: ParamItem,
  109. extra_args: Dict[DataKey, Dict[str, Any]],
  110. data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
  111. ) -> Union[DataType, SequenceDataType]:
  112. _data_keys = self.preproc_datakeys(data_keys)
  113. if isinstance(module, K.RandomTransplantation):
  114. # For transforms which require the full input to calculate the parameters (e.g. RandomTransplantation)
  115. param = ParamItem(
  116. name=param.name,
  117. data=module.params_from_input(
  118. *arg, # type: ignore[arg-type]
  119. data_keys=_data_keys,
  120. params=param.data, # type: ignore[arg-type]
  121. extra_args=extra_args,
  122. ),
  123. )
  124. outputs = []
  125. for inp, dcate in zip(arg, _data_keys):
  126. op = self._get_op(dcate)
  127. extra_arg = extra_args.get(dcate, {})
  128. if dcate.name == "MASK" and isinstance(inp, list):
  129. outputs.append(MaskSequentialOps.transform_list(inp, module, param=param, extra_args=extra_arg))
  130. else:
  131. outputs.append(op.transform(inp, module, param=param, extra_args=extra_arg))
  132. if len(outputs) == 1 and isinstance(outputs, (list, tuple)):
  133. return outputs[0]
  134. return outputs
  135. def inverse(
  136. self,
  137. *arg: DataType,
  138. module: Module,
  139. param: ParamItem,
  140. extra_args: Dict[DataKey, Dict[str, Any]],
  141. data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
  142. ) -> Union[DataType, SequenceDataType]:
  143. _data_keys = self.preproc_datakeys(data_keys)
  144. outputs = []
  145. for inp, dcate in zip(arg, _data_keys):
  146. op = self._get_op(dcate)
  147. extra_arg = extra_args[dcate] if dcate in extra_args else {}
  148. outputs.append(op.inverse(inp, module, param=param, extra_args=extra_arg))
  149. if len(outputs) == 1 and isinstance(outputs, (list, tuple)):
  150. return outputs[0]
  151. return outputs
  152. P = ParamSpec("P")
  153. def make_input_only_sequential(module: "K.container.ImageSequentialBase") -> Callable[P, Tensor]:
  154. """Disable all other additional inputs (e.g. ) for ImageSequential."""
  155. def f(*args: P.args, **kwargs: P.kwargs) -> Tensor:
  156. return module(*args, **kwargs)
  157. return f
  158. def get_geometric_only_param(module: "K.container.ImageSequentialBase", param: List[ParamItem]) -> List[ParamItem]:
  159. """Return geometry param."""
  160. named_modules = module.get_forward_sequence(param)
  161. res: List[ParamItem] = []
  162. for (_, mod), p in zip(named_modules, param):
  163. if isinstance(mod, (K.GeometricAugmentationBase2D, K.GeometricAugmentationBase3D)):
  164. res.append(p)
  165. return res
  166. class InputSequentialOps(SequentialOpsInterface[Tensor]):
  167. @classmethod
  168. def transform(
  169. cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
  170. ) -> Tensor:
  171. if extra_args is None:
  172. extra_args = {}
  173. if isinstance(module, (_AugmentationBase, K.MixAugmentationBaseV2)):
  174. input = module(input, params=cls.get_instance_module_param(param), data_keys=[DataKey.INPUT], **extra_args)
  175. elif isinstance(module, (K.container.ImageSequentialBase,)):
  176. input = module.transform_inputs(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
  177. elif isinstance(module, (K.auto.operations.OperationBase,)):
  178. input = module(input, params=cls.get_instance_module_param(param))
  179. else:
  180. if param.data is not None:
  181. raise AssertionError(f"Non-augmentaion operation {param.name} require empty parameters. Got {param}.")
  182. input = module(input)
  183. return input
  184. @classmethod
  185. def inverse(
  186. cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
  187. ) -> Tensor:
  188. if extra_args is None:
  189. extra_args = {}
  190. if isinstance(module, K.GeometricAugmentationBase2D):
  191. input = module.inverse(input, params=cls.get_instance_module_param(param), **extra_args)
  192. elif isinstance(module, (K.GeometricAugmentationBase3D,)):
  193. raise NotImplementedError(
  194. "The support for 3d inverse operations are not yet supported. You are welcome to file a PR in our repo."
  195. )
  196. elif isinstance(module, (K.auto.operations.OperationBase,)):
  197. return InputSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)
  198. elif isinstance(module, K.ImageSequential) and not module.is_intensity_only():
  199. input = module.inverse_inputs(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
  200. elif isinstance(module, K.container.ImageSequentialBase):
  201. input = module.inverse_inputs(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
  202. return input
  203. class ClassSequentialOps(SequentialOpsInterface[Tensor]):
  204. """Apply and inverse transformations for class labels if needed."""
  205. @classmethod
  206. def transform(
  207. cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
  208. ) -> Tensor:
  209. if isinstance(module, K.MixAugmentationBaseV2):
  210. raise NotImplementedError(
  211. "The support for class labels for mix augmentations that change the class label is not yet supported."
  212. )
  213. return input
  214. @classmethod
  215. def inverse(
  216. cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
  217. ) -> Tensor:
  218. return input
  219. class MaskSequentialOps(SequentialOpsInterface[Tensor]):
  220. """Apply and inverse transformations for mask tensors."""
  221. @classmethod
  222. def transform(
  223. cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
  224. ) -> Tensor:
  225. """Apply a transformation with respect to the parameters.
  226. Args:
  227. input: the input tensor.
  228. module: any torch Module but only kornia augmentation modules will count
  229. to apply transformations.
  230. param: the corresponding parameters to the module.
  231. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  232. """
  233. if extra_args is None:
  234. extra_args = {}
  235. if isinstance(module, (K.GeometricAugmentationBase2D,)):
  236. input = module.transform_masks(
  237. input,
  238. params=cls.get_instance_module_param(param),
  239. flags=module.flags,
  240. transform=module.transform_matrix,
  241. **extra_args,
  242. )
  243. elif isinstance(module, (K.GeometricAugmentationBase3D,)):
  244. raise NotImplementedError(
  245. "The support for 3d mask operations are not yet supported. You are welcome to file a PR in our repo."
  246. )
  247. elif isinstance(module, K.RandomTransplantation):
  248. input = module(input, params=cls.get_instance_module_param(param), data_keys=[DataKey.MASK], **extra_args)
  249. elif isinstance(module, (_AugmentationBase)):
  250. input = module.transform_masks(
  251. input, params=cls.get_instance_module_param(param), flags=module.flags, **extra_args
  252. )
  253. elif isinstance(module, K.ImageSequential) and not module.is_intensity_only():
  254. input = module.transform_masks(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
  255. elif isinstance(module, K.container.ImageSequentialBase):
  256. input = module.transform_masks(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
  257. elif isinstance(module, (K.auto.operations.OperationBase,)):
  258. input = MaskSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)
  259. return input
  260. @classmethod
  261. def transform_list(
  262. cls, input: List[Tensor], module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
  263. ) -> List[Tensor]:
  264. """Apply a transformation with respect to the parameters.
  265. Args:
  266. input: list of input tensors.
  267. module: any torch Module but only kornia augmentation modules will count
  268. to apply transformations.
  269. param: the corresponding parameters to the module.
  270. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  271. """
  272. if extra_args is None:
  273. extra_args = {}
  274. if isinstance(module, (K.GeometricAugmentationBase2D,)):
  275. tfm_input = []
  276. params = cls.get_instance_module_param(param)
  277. params_i = copy.deepcopy(params)
  278. for i, inp in enumerate(input):
  279. params_i["batch_prob"] = params["batch_prob"][i]
  280. tfm_inp = module.transform_masks(
  281. inp, params=params_i, flags=module.flags, transform=module.transform_matrix, **extra_args
  282. )
  283. tfm_input.append(tfm_inp)
  284. input = tfm_input
  285. elif isinstance(module, (K.GeometricAugmentationBase3D,)):
  286. raise NotImplementedError(
  287. "The support for 3d mask operations are not yet supported. You are welcome to file a PR in our repo."
  288. )
  289. elif isinstance(module, (_AugmentationBase)):
  290. tfm_input = []
  291. params = cls.get_instance_module_param(param)
  292. params_i = copy.deepcopy(params)
  293. for i, inp in enumerate(input):
  294. params_i["batch_prob"] = params["batch_prob"][i]
  295. tfm_inp = module.transform_masks(inp, params=params_i, flags=module.flags, **extra_args)
  296. tfm_input.append(tfm_inp)
  297. input = tfm_input
  298. elif isinstance(module, K.ImageSequential) and not module.is_intensity_only():
  299. tfm_input = []
  300. seq_params = cls.get_sequential_module_param(param)
  301. for inp in input:
  302. tfm_inp = module.transform_masks(inp, params=seq_params, extra_args=extra_args)
  303. tfm_input.append(tfm_inp)
  304. input = tfm_input
  305. elif isinstance(module, K.container.ImageSequentialBase):
  306. tfm_input = []
  307. seq_params = cls.get_sequential_module_param(param)
  308. for inp in input:
  309. tfm_inp = module.transform_masks(inp, params=seq_params, extra_args=extra_args)
  310. tfm_input.append(tfm_inp)
  311. input = tfm_input
  312. elif isinstance(module, (K.auto.operations.OperationBase,)):
  313. raise NotImplementedError(
  314. "The support for list of masks under auto operations are not yet supported. You are welcome to file a"
  315. " PR in our repo."
  316. )
  317. return input
  318. @classmethod
  319. def inverse(
  320. cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
  321. ) -> Tensor:
  322. """Inverse a transformation with respect to the parameters.
  323. Args:
  324. input: the input tensor.
  325. module: any torch Module but only kornia augmentation modules will count
  326. to apply transformations.
  327. param: the corresponding parameters to the module.
  328. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  329. """
  330. if extra_args is None:
  331. extra_args = {}
  332. if isinstance(module, (K.GeometricAugmentationBase2D,)):
  333. if module.transform_matrix is None:
  334. raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
  335. transform = module.compute_inverse_transformation(module.transform_matrix)
  336. input = module.inverse_masks(
  337. input,
  338. params=cls.get_instance_module_param(param),
  339. flags=module.flags,
  340. transform=transform,
  341. **extra_args,
  342. )
  343. elif isinstance(module, (K.GeometricAugmentationBase3D,)):
  344. raise NotImplementedError(
  345. "The support for 3d mask operations are not yet supported. You are welcome to file a PR in our repo."
  346. )
  347. elif isinstance(module, K.container.ImageSequentialBase):
  348. input = module.inverse_masks(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
  349. elif isinstance(module, (K.auto.operations.OperationBase,)):
  350. input = MaskSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)
  351. return input
  352. class BoxSequentialOps(SequentialOpsInterface[Boxes]):
  353. """Apply and inverse transformations for bounding box tensors.
  354. This is for transform boxes in the format (B, N, 4, 2).
  355. """
  356. @classmethod
  357. def transform(
  358. cls, input: Boxes, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
  359. ) -> Boxes:
  360. """Apply a transformation with respect to the parameters.
  361. Args:
  362. input: the input tensor, (B, N, 4, 2) or (B, 4, 2).
  363. module: any torch Module but only kornia augmentation modules will count
  364. to apply transformations.
  365. param: the corresponding parameters to the module.
  366. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  367. """
  368. if extra_args is None:
  369. extra_args = {}
  370. _input = input.clone()
  371. if isinstance(module, (K.GeometricAugmentationBase2D,)):
  372. _input = module.transform_boxes(
  373. _input,
  374. cls.get_instance_module_param(param),
  375. module.flags,
  376. transform=module.transform_matrix,
  377. **extra_args,
  378. )
  379. elif isinstance(module, (K.GeometricAugmentationBase3D,)):
  380. raise NotImplementedError(
  381. "The support for 3d box operations are not yet supported. You are welcome to file a PR in our repo."
  382. )
  383. elif isinstance(module, K.ImageSequential) and not module.is_intensity_only():
  384. _input = module.transform_boxes(
  385. _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
  386. )
  387. elif isinstance(module, K.container.ImageSequentialBase):
  388. _input = module.transform_boxes(
  389. _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
  390. )
  391. elif isinstance(module, (K.auto.operations.OperationBase,)):
  392. return BoxSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)
  393. return _input
  394. @classmethod
  395. def inverse(
  396. cls, input: Boxes, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
  397. ) -> Boxes:
  398. """Inverse a transformation with respect to the parameters.
  399. Args:
  400. input: the input tensor.
  401. module: any torch Module but only kornia augmentation modules will count
  402. to apply transformations.
  403. param: the corresponding parameters to the module.
  404. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  405. """
  406. if extra_args is None:
  407. extra_args = {}
  408. _input = input.clone()
  409. if isinstance(module, (K.GeometricAugmentationBase2D,)):
  410. if module.transform_matrix is None:
  411. raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
  412. transform = module.compute_inverse_transformation(module.transform_matrix)
  413. _input = module.inverse_boxes(
  414. _input,
  415. param.data, # type: ignore[arg-type]
  416. module.flags,
  417. transform=transform,
  418. **extra_args,
  419. )
  420. elif isinstance(module, (K.GeometricAugmentationBase3D,)):
  421. raise NotImplementedError(
  422. "The support for 3d box operations are not yet supported. You are welcome to file a PR in our repo."
  423. )
  424. elif isinstance(module, K.ImageSequential) and not module.is_intensity_only():
  425. _input = module.inverse_boxes(_input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
  426. elif isinstance(module, K.container.ImageSequentialBase):
  427. _input = module.inverse_boxes(_input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
  428. elif isinstance(module, (K.auto.operations.OperationBase,)):
  429. return BoxSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)
  430. return _input
  431. class KeypointSequentialOps(SequentialOpsInterface[Keypoints]):
  432. """Apply and inverse transformations for keypoints tensors.
  433. This is for transform keypoints in the format (B, N, 2).
  434. """
  435. @classmethod
  436. def transform(
  437. cls, input: Keypoints, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
  438. ) -> Keypoints:
  439. """Apply a transformation with respect to the parameters.
  440. Args:
  441. input: the input tensor, (B, N, 4, 2) or (B, 4, 2).
  442. module: any torch Module but only kornia augmentation modules will count
  443. to apply transformations.
  444. param: the corresponding parameters to the module.
  445. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  446. """
  447. if extra_args is None:
  448. extra_args = {}
  449. _input = input.clone()
  450. if isinstance(module, (K.GeometricAugmentationBase2D,)):
  451. _input = module.transform_keypoints(
  452. _input,
  453. cls.get_instance_module_param(param),
  454. module.flags,
  455. transform=module.transform_matrix,
  456. **extra_args,
  457. )
  458. elif isinstance(module, (K.GeometricAugmentationBase3D,)):
  459. raise NotImplementedError(
  460. "The support for 3d keypoint operations are not yet supported. "
  461. "You are welcome to file a PR in our repo."
  462. )
  463. elif isinstance(module, K.ImageSequential) and not module.is_intensity_only():
  464. _input = module.transform_keypoints(
  465. _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
  466. )
  467. elif isinstance(module, K.container.ImageSequentialBase):
  468. _input = module.transform_keypoints(
  469. _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
  470. )
  471. elif isinstance(module, (K.auto.operations.OperationBase,)):
  472. return KeypointSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)
  473. return _input
  474. @classmethod
  475. def inverse(
  476. cls, input: Keypoints, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
  477. ) -> Keypoints:
  478. """Inverse a transformation with respect to the parameters.
  479. Args:
  480. input: the input tensor.
  481. module: any torch Module but only kornia augmentation modules will count
  482. to apply transformations.
  483. param: the corresponding parameters to the module.
  484. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  485. """
  486. if extra_args is None:
  487. extra_args = {}
  488. _input = input.clone()
  489. if isinstance(module, (K.GeometricAugmentationBase2D,)):
  490. if module.transform_matrix is None:
  491. raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
  492. transform = module.compute_inverse_transformation(module.transform_matrix)
  493. _input = module.inverse_keypoints(
  494. _input, cls.get_instance_module_param(param), module.flags, transform=transform, **extra_args
  495. )
  496. elif isinstance(module, (K.GeometricAugmentationBase3D,)):
  497. raise NotImplementedError(
  498. "The support for 3d keypoint operations are not yet supported. "
  499. "You are welcome to file a PR in our repo."
  500. )
  501. elif isinstance(module, K.ImageSequential) and not module.is_intensity_only():
  502. _input = module.inverse_keypoints(
  503. _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
  504. )
  505. elif isinstance(module, K.container.ImageSequentialBase):
  506. _input = module.inverse_keypoints(
  507. _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
  508. )
  509. elif isinstance(module, (K.auto.operations.OperationBase,)):
  510. return KeypointSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)
  511. return _input