# base.py
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from enum import Enum
  18. from typing import Any, Callable, Dict, Optional, Tuple, Union
  19. import torch
  20. from torch.distributions import Bernoulli, Distribution, RelaxedBernoulli
  21. from kornia.augmentation.random_generator import RandomGeneratorBase
  22. from kornia.augmentation.utils import (
  23. _adapted_rsampling,
  24. _adapted_sampling,
  25. _transform_output_shape,
  26. override_parameters,
  27. )
  28. from kornia.core import ImageModule as Module
  29. from kornia.core import Tensor, tensor, zeros
  30. from kornia.geometry.boxes import Boxes
  31. from kornia.geometry.keypoints import Keypoints
  32. from kornia.utils import is_autocast_enabled
# Type alias: either a single tensor or a pair of tensors.
# NOTE(review): going by the name, the pair is presumably (output, transformation matrix) -- confirm.
TensorWithTransformMat = Union[Tensor, Tuple[Tensor, Tensor]]
# Trick mypy into not applying contravariance rules to inputs by defining
# ``apply_transform`` as a value, rather than a function. See also
# https://github.com/python/mypy/issues/8795
# Based on the trick that torch.nn.Module uses for its ``forward`` method.
  38. def _apply_transform_unimplemented(self: Module, *input: Any) -> Tensor:
  39. r"""Define the computation performed at every call.
  40. Should be overridden by all subclasses.
  41. """
  42. raise NotImplementedError(f'Module [{type(self).__name__}] is missing the required "apply_tranform" function')
  43. class _BasicAugmentationBase(Module):
  44. r"""_BasicAugmentationBase base class for customized augmentation implementations.
  45. Plain augmentation base class without the functionality of transformation matrix calculations.
  46. By default, the random computations will be happened on CPU with ``torch.get_default_dtype()``.
  47. To change this behaviour, please use ``set_rng_device_and_dtype``.
  48. For automatically generating the corresponding ``__repr__`` with full customized parameters, you may need to
  49. implement ``_param_generator`` by inheriting ``RandomGeneratorBase`` for generating random parameters and
  50. put all static parameters inside ``self.flags``. You may take the advantage of ``PlainUniformGenerator`` to
  51. generate simple uniform parameters with less boilerplate code.
  52. Args:
  53. p: probability for applying an augmentation. This param controls the augmentation probabilities element-wise.
  54. p_batch: probability for applying an augmentation to a batch. This param controls the augmentation
  55. probabilities batch-wise.
  56. same_on_batch: apply the same transformation across the batch.
  57. keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to
  58. the batch form ``False``.
  59. """
  60. # TODO: Hard to support. Many codes are not ONNX-friendly that contains lots of if-else blocks, etc.
  61. # Please contribute if anyone interested.
  62. ONNX_EXPORTABLE = False
  63. def __init__(
  64. self,
  65. p: float = 0.5,
  66. p_batch: float = 1.0,
  67. same_on_batch: bool = False,
  68. keepdim: bool = False,
  69. ) -> None:
  70. super().__init__()
  71. self.p = p
  72. self.p_batch = p_batch
  73. self.same_on_batch = same_on_batch
  74. self.keepdim = keepdim
  75. self._params: Dict[str, Tensor] = {}
  76. self._p_gen: Distribution
  77. self._p_batch_gen: Distribution
  78. if p != 0.0 or p != 1.0:
  79. self._p_gen = Bernoulli(self.p)
  80. if p_batch != 0.0 or p_batch != 1.0:
  81. self._p_batch_gen = Bernoulli(self.p_batch)
  82. self._param_generator: Optional[RandomGeneratorBase] = None
  83. self.flags: Dict[str, Any] = {}
  84. self.set_rng_device_and_dtype(torch.device("cpu"), torch.get_default_dtype())
  85. apply_transform: Callable[..., Tensor] = _apply_transform_unimplemented
  86. def to(self, *args: Any, **kwargs: Any) -> "_BasicAugmentationBase":
  87. r"""Set the device and dtype for the random number generator."""
  88. device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
  89. self.set_rng_device_and_dtype(device, dtype)
  90. return super().to(*args, **kwargs)
  91. def __repr__(self) -> str:
  92. txt = f"p={self.p}, p_batch={self.p_batch}, same_on_batch={self.same_on_batch}"
  93. if isinstance(self._param_generator, RandomGeneratorBase):
  94. txt = f"{self._param_generator!s}, {txt}"
  95. for k, v in self.flags.items():
  96. if isinstance(v, Enum):
  97. txt += f", {k}={v.name.lower()}"
  98. else:
  99. txt += f", {k}={v}"
  100. return f"{self.__class__.__name__}({txt})"
  101. def __unpack_input__(self, input: Tensor) -> Tensor:
  102. return input
  103. def transform_tensor(
  104. self,
  105. input: Tensor,
  106. *,
  107. shape: Optional[Tensor] = None,
  108. match_channel: bool = True,
  109. ) -> Tensor:
  110. """Standardize input tensors."""
  111. raise NotImplementedError
  112. def validate_tensor(self, input: Tensor) -> None:
  113. """Check if the input tensor is formatted as expected."""
  114. raise NotImplementedError
  115. def transform_output_tensor(self, output: Tensor, output_shape: Tuple[int, ...]) -> Tensor:
  116. """Standardize output tensors."""
  117. return _transform_output_shape(output, output_shape) if self.keepdim else output
  118. def generate_parameters(self, batch_shape: Tuple[int, ...]) -> Dict[str, Tensor]:
  119. if self._param_generator is not None:
  120. return self._param_generator(batch_shape, self.same_on_batch)
  121. return {}
  122. def set_rng_device_and_dtype(self, device: torch.device, dtype: torch.dtype) -> None:
  123. """Change the random generation device and dtype.
  124. Note:
  125. The generated random numbers are not reproducible across different devices and dtypes.
  126. """
  127. self.device = device
  128. self.dtype = dtype
  129. if self._param_generator is not None:
  130. self._param_generator.set_rng_device_and_dtype(device, dtype)
  131. def __batch_prob_generator__(
  132. self,
  133. batch_shape: Tuple[int, ...],
  134. p: float,
  135. p_batch: float,
  136. same_on_batch: bool,
  137. ) -> Tensor:
  138. batch_prob: Tensor
  139. if p_batch == 1:
  140. batch_prob = zeros(1) + 1
  141. elif p_batch == 0:
  142. batch_prob = zeros(1)
  143. elif isinstance(self._p_batch_gen, (RelaxedBernoulli,)):
  144. # NOTE: there is no simple way to know if the sampler has `rsample` or not
  145. batch_prob = _adapted_rsampling((1,), self._p_batch_gen, same_on_batch)
  146. else:
  147. batch_prob = _adapted_sampling((1,), self._p_batch_gen, same_on_batch)
  148. if batch_prob.sum() == 1:
  149. elem_prob: Tensor
  150. if p == 1:
  151. elem_prob = zeros(batch_shape[0]) + 1
  152. elif p == 0:
  153. elem_prob = zeros(batch_shape[0])
  154. elif isinstance(self._p_gen, (RelaxedBernoulli,)):
  155. elem_prob = _adapted_rsampling((batch_shape[0],), self._p_gen, same_on_batch)
  156. else:
  157. elem_prob = _adapted_sampling((batch_shape[0],), self._p_gen, same_on_batch)
  158. batch_prob = batch_prob * elem_prob
  159. else:
  160. batch_prob = batch_prob.repeat(batch_shape[0])
  161. if len(batch_prob.shape) == 2:
  162. return batch_prob[..., 0]
  163. return batch_prob
  164. def _process_kwargs_to_params_and_flags(
  165. self,
  166. params: Optional[Dict[str, Tensor]] = None,
  167. flags: Optional[Dict[str, Any]] = None,
  168. **kwargs: Any,
  169. ) -> Tuple[Dict[str, Tensor], Dict[str, Any]]:
  170. # NOTE: determine how to save self._params
  171. save_kwargs = kwargs["save_kwargs"] if "save_kwargs" in kwargs else False
  172. params = self._params if params is None else params
  173. flags = self.flags if flags is None else flags
  174. if save_kwargs:
  175. params = override_parameters(params, kwargs, in_place=True)
  176. self._params = params
  177. else:
  178. self._params = params
  179. params = override_parameters(params, kwargs, in_place=False)
  180. flags = override_parameters(flags, kwargs, in_place=False)
  181. return params, flags
  182. def forward_parameters(self, batch_shape: Tuple[int, ...]) -> Dict[str, Tensor]:
  183. batch_prob = self.__batch_prob_generator__(batch_shape, self.p, self.p_batch, self.same_on_batch)
  184. to_apply = batch_prob > 0.5
  185. _params = self.generate_parameters(torch.Size((int(to_apply.sum().item()), *batch_shape[1:])))
  186. if _params is None:
  187. _params = {}
  188. _params["batch_prob"] = batch_prob
  189. # Added another input_size parameter for geometric transformations
  190. # This might be needed for correctly inversing.
  191. input_size = tensor(batch_shape, dtype=torch.long)
  192. _params.update({"forward_input_shape": input_size})
  193. return _params
  194. def apply_func(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
  195. return self.apply_transform(input, params, flags)
  196. def forward(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None, **kwargs: Any) -> Tensor:
  197. """Perform forward operations.
  198. Args:
  199. input: the input tensor.
  200. params: the corresponding parameters for an operation.
  201. If None, a new parameter suite will be generated.
  202. **kwargs: key-value pairs to override the parameters and flags.
  203. Note:
  204. By default, all the overwriting parameters in kwargs will not be recorded
  205. as in ``self._params``. If you wish it to be recorded, you may pass
  206. ``save_kwargs=True`` additionally.
  207. """
  208. in_tensor = self.__unpack_input__(input)
  209. input_shape = in_tensor.shape
  210. in_tensor = self.transform_tensor(in_tensor)
  211. batch_shape = in_tensor.shape
  212. if params is None:
  213. params = self.forward_parameters(batch_shape)
  214. if "batch_prob" not in params:
  215. params["batch_prob"] = tensor([True] * batch_shape[0])
  216. params, flags = self._process_kwargs_to_params_and_flags(params, self.flags, **kwargs)
  217. output = self.apply_func(in_tensor, params, flags)
  218. return self.transform_output_tensor(output, input_shape) if self.keepdim else output
  219. class _AugmentationBase(_BasicAugmentationBase):
  220. r"""_AugmentationBase base class for customized augmentation implementations.
  221. Advanced augmentation base class with the functionality of transformation matrix calculations.
  222. Args:
  223. p: probability for applying an augmentation. This param controls the augmentation probabilities
  224. element-wise for a batch.
  225. p_batch: probability for applying an augmentation to a batch. This param controls the augmentation
  226. probabilities batch-wise.
  227. same_on_batch: apply the same transformation across the batch.
  228. keepdim: whether to keep the output shape the same as input ``True`` or broadcast it
  229. to the batch form ``False``.
  230. """
  231. def apply_transform(
  232. self,
  233. input: Tensor,
  234. params: Dict[str, Tensor],
  235. flags: Dict[str, Any],
  236. transform: Optional[Tensor] = None,
  237. ) -> Tensor:
  238. # apply transform for the input image tensor
  239. raise NotImplementedError
  240. def apply_non_transform(
  241. self,
  242. input: Tensor,
  243. params: Dict[str, Tensor],
  244. flags: Dict[str, Any],
  245. transform: Optional[Tensor] = None,
  246. ) -> Tensor:
  247. # apply additional transform for the images that are skipped from transformation
  248. # where batch_prob == False.
  249. return input
  250. def transform_inputs(
  251. self,
  252. input: Tensor,
  253. params: Dict[str, Tensor],
  254. flags: Dict[str, Any],
  255. transform: Optional[Tensor] = None,
  256. **kwargs: Any,
  257. ) -> Tensor:
  258. params, flags = self._process_kwargs_to_params_and_flags(
  259. self._params if params is None else params, flags, **kwargs
  260. )
  261. batch_prob = params["batch_prob"]
  262. to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions.
  263. ori_shape = input.shape
  264. in_tensor = self.transform_tensor(input)
  265. self.validate_tensor(in_tensor)
  266. if to_apply.all():
  267. output = self.apply_transform(in_tensor, params, flags, transform=transform)
  268. elif not to_apply.any():
  269. output = self.apply_non_transform(in_tensor, params, flags, transform=transform)
  270. else: # If any tensor needs to be transformed.
  271. output = self.apply_non_transform(in_tensor, params, flags, transform=transform)
  272. applied = self.apply_transform(
  273. in_tensor[to_apply],
  274. params,
  275. flags,
  276. transform=transform if transform is None else transform[to_apply],
  277. )
  278. if is_autocast_enabled():
  279. output = output.type(input.dtype)
  280. applied = applied.type(input.dtype)
  281. output = output.index_put((to_apply,), applied)
  282. output = _transform_output_shape(output, ori_shape) if self.keepdim else output
  283. if is_autocast_enabled():
  284. output = output.type(input.dtype)
  285. return output
  286. def transform_masks(
  287. self,
  288. input: Tensor,
  289. params: Dict[str, Tensor],
  290. flags: Dict[str, Any],
  291. transform: Optional[Tensor] = None,
  292. **kwargs: Any,
  293. ) -> Tensor:
  294. params, flags = self._process_kwargs_to_params_and_flags(
  295. self._params if params is None else params, flags, **kwargs
  296. )
  297. batch_prob = params["batch_prob"]
  298. to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions.
  299. ori_shape = input.shape
  300. shape = params["forward_input_shape"]
  301. in_tensor = self.transform_tensor(input, shape=shape, match_channel=False)
  302. self.validate_tensor(in_tensor)
  303. if to_apply.all():
  304. output = self.apply_transform_mask(in_tensor, params, flags, transform=transform)
  305. elif not to_apply.any():
  306. output = self.apply_non_transform_mask(in_tensor, params, flags, transform=transform)
  307. else: # If any tensor needs to be transformed.
  308. output = self.apply_non_transform_mask(in_tensor, params, flags, transform=transform)
  309. applied = self.apply_transform_mask(
  310. in_tensor[to_apply],
  311. params,
  312. flags,
  313. transform=transform if transform is None else transform[to_apply],
  314. )
  315. output = output.index_put((to_apply,), applied)
  316. output = _transform_output_shape(output, ori_shape, reference_shape=shape) if self.keepdim else output
  317. return output
  318. def transform_boxes(
  319. self,
  320. input: Boxes,
  321. params: Dict[str, Tensor],
  322. flags: Dict[str, Any],
  323. transform: Optional[Tensor] = None,
  324. **kwargs: Any,
  325. ) -> Boxes:
  326. if not isinstance(input, Boxes):
  327. raise RuntimeError(f"Only `Boxes` is supported. Got {type(input)}.")
  328. params, flags = self._process_kwargs_to_params_and_flags(
  329. self._params if params is None else params, flags, **kwargs
  330. )
  331. batch_prob = params["batch_prob"]
  332. to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions.
  333. output: Boxes
  334. if to_apply.bool().all():
  335. output = self.apply_transform_box(input, params, flags, transform=transform)
  336. elif not to_apply.any():
  337. output = self.apply_non_transform_box(input, params, flags, transform=transform)
  338. else: # If any tensor needs to be transformed.
  339. output = self.apply_non_transform_box(input, params, flags, transform=transform)
  340. applied = self.apply_transform_box(
  341. input[to_apply],
  342. params,
  343. flags,
  344. transform=transform if transform is None else transform[to_apply],
  345. )
  346. if is_autocast_enabled():
  347. output = output.type(input.dtype)
  348. applied = applied.type(input.dtype)
  349. output = output.index_put((to_apply,), applied)
  350. return output
  351. def transform_keypoints(
  352. self,
  353. input: Keypoints,
  354. params: Dict[str, Tensor],
  355. flags: Dict[str, Any],
  356. transform: Optional[Tensor] = None,
  357. **kwargs: Any,
  358. ) -> Keypoints:
  359. if not isinstance(input, Keypoints):
  360. raise RuntimeError(f"Only `Keypoints` is supported. Got {type(input)}.")
  361. params, flags = self._process_kwargs_to_params_and_flags(
  362. self._params if params is None else params, flags, **kwargs
  363. )
  364. batch_prob = params["batch_prob"]
  365. to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions.
  366. if to_apply.all():
  367. output = self.apply_transform_keypoint(input, params, flags, transform=transform)
  368. elif not to_apply.any():
  369. output = self.apply_non_transform_keypoint(input, params, flags, transform=transform)
  370. else: # If any tensor needs to be transformed.
  371. output = self.apply_non_transform_keypoint(input, params, flags, transform=transform)
  372. applied = self.apply_transform_keypoint(
  373. input[to_apply],
  374. params,
  375. flags,
  376. transform=transform if transform is None else transform[to_apply],
  377. )
  378. if is_autocast_enabled():
  379. output = output.type(input.dtype)
  380. applied = applied.type(input.dtype)
  381. output = output.index_put((to_apply,), applied)
  382. return output
  383. def transform_classes(
  384. self,
  385. input: Tensor,
  386. params: Dict[str, Tensor],
  387. flags: Dict[str, Any],
  388. transform: Optional[Tensor] = None,
  389. **kwargs: Any,
  390. ) -> Tensor:
  391. params, flags = self._process_kwargs_to_params_and_flags(
  392. self._params if params is None else params, flags, **kwargs
  393. )
  394. batch_prob = params["batch_prob"]
  395. to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions.
  396. if to_apply.all():
  397. output = self.apply_transform_class(input, params, flags, transform=transform)
  398. elif not to_apply.any():
  399. output = self.apply_non_transform_class(input, params, flags, transform=transform)
  400. else: # If any tensor needs to be transformed.
  401. output = self.apply_non_transform_class(input, params, flags, transform=transform)
  402. applied = self.apply_transform_class(
  403. input[to_apply],
  404. params,
  405. flags,
  406. transform=transform if transform is None else transform[to_apply],
  407. )
  408. output = output.index_put((to_apply,), applied)
  409. return output
  410. def apply_non_transform_mask(
  411. self,
  412. input: Tensor,
  413. params: Dict[str, Tensor],
  414. flags: Dict[str, Any],
  415. transform: Optional[Tensor] = None,
  416. ) -> Tensor:
  417. """Process masks corresponding to the inputs that are no transformation applied."""
  418. raise NotImplementedError
  419. def apply_transform_mask(
  420. self,
  421. input: Tensor,
  422. params: Dict[str, Tensor],
  423. flags: Dict[str, Any],
  424. transform: Optional[Tensor] = None,
  425. ) -> Tensor:
  426. """Process masks corresponding to the inputs that are transformed."""
  427. raise NotImplementedError
  428. def apply_non_transform_box(
  429. self,
  430. input: Boxes,
  431. params: Dict[str, Tensor],
  432. flags: Dict[str, Any],
  433. transform: Optional[Tensor] = None,
  434. ) -> Boxes:
  435. """Process boxes corresponding to the inputs that are no transformation applied."""
  436. return input
  437. def apply_transform_box(
  438. self,
  439. input: Boxes,
  440. params: Dict[str, Tensor],
  441. flags: Dict[str, Any],
  442. transform: Optional[Tensor] = None,
  443. ) -> Boxes:
  444. """Process boxes corresponding to the inputs that are transformed."""
  445. raise NotImplementedError
  446. def apply_non_transform_keypoint(
  447. self,
  448. input: Keypoints,
  449. params: Dict[str, Tensor],
  450. flags: Dict[str, Any],
  451. transform: Optional[Tensor] = None,
  452. ) -> Keypoints:
  453. """Process keypoints corresponding to the inputs that are no transformation applied."""
  454. return input
  455. def apply_transform_keypoint(
  456. self,
  457. input: Keypoints,
  458. params: Dict[str, Tensor],
  459. flags: Dict[str, Any],
  460. transform: Optional[Tensor] = None,
  461. ) -> Keypoints:
  462. """Process keypoints corresponding to the inputs that are transformed."""
  463. raise NotImplementedError
  464. def apply_non_transform_class(
  465. self,
  466. input: Tensor,
  467. params: Dict[str, Tensor],
  468. flags: Dict[str, Any],
  469. transform: Optional[Tensor] = None,
  470. ) -> Tensor:
  471. """Process class tags corresponding to the inputs that are no transformation applied."""
  472. return input
  473. def apply_transform_class(
  474. self,
  475. input: Tensor,
  476. params: Dict[str, Tensor],
  477. flags: Dict[str, Any],
  478. transform: Optional[Tensor] = None,
  479. ) -> Tensor:
  480. """Process class tags corresponding to the inputs that are transformed."""
  481. raise NotImplementedError
  482. def apply_func(
  483. self,
  484. in_tensor: Tensor,
  485. params: Dict[str, Tensor],
  486. flags: Optional[Dict[str, Any]] = None,
  487. ) -> Tensor:
  488. if flags is None:
  489. flags = self.flags
  490. output = self.transform_inputs(in_tensor, params, flags)
  491. return output