base.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, Dict, Optional
  18. from torch import float16, float32, float64
  19. import kornia
  20. from kornia.augmentation.base import _AugmentationBase
  21. from kornia.augmentation.utils import _transform_input3d, _transform_input3d_by_shape, _validate_input_dtype
  22. from kornia.core import Tensor
  23. from kornia.geometry.boxes import Boxes3D
  24. from kornia.geometry.keypoints import Keypoints3D
  class AugmentationBase3D(_AugmentationBase):
      r"""Base class for customized 3D (volumetric) augmentation implementations.

      Tensors handled by subclasses are normalized to the ``(B, C, D, H, W)`` layout and must carry a
      ``float16``, ``float32`` or ``float64`` dtype.

      Args:
          p: probability for applying an augmentation. This param controls the augmentation probabilities
            element-wise for a batch.
          p_batch: probability for applying an augmentation to a batch. This param controls the augmentation
            probabilities batch-wise.
          same_on_batch: apply the same transformation across the batch.
      """

      def validate_tensor(self, input: Tensor) -> None:
          """Ensure ``input`` has a supported floating-point dtype and is exactly 5-D ``(B, C, D, H, W)``.

          The dtype check is delegated to ``_validate_input_dtype``.

          Raises:
              RuntimeError: if ``input`` does not have exactly five dimensions.
          """
          _validate_input_dtype(input, accepted_dtypes=[float16, float32, float64])
          rank = input.ndim
          if rank != 5:
              raise RuntimeError(f"Expect (B, C, D, H, W). Got {input.shape}.")

      def transform_tensor(self, input: Tensor, *, shape: Optional[Tensor] = None, match_channel: bool = True) -> Tensor:
          """Promote ``(D, H, W)``, ``(C, D, H, W)`` or ``(B, C, D, H, W)`` input to ``(B, C, D, H, W)``.

          Args:
              input: tensor to convert; its dtype is validated before any conversion.
              shape: optional reference shape. When given, conversion is delegated to
                ``_transform_input3d_by_shape``; otherwise ``_transform_input3d`` is used.
              match_channel: forwarded to ``_transform_input3d_by_shape``; unused when ``shape`` is ``None``.

          Returns:
              The converted tensor.
          """
          _validate_input_dtype(input, accepted_dtypes=[float16, float32, float64])
          if shape is not None:
              return _transform_input3d_by_shape(input, reference_shape=shape, match_channel=match_channel)
          return _transform_input3d(input)

      def identity_matrix(self, input: Tensor) -> Tensor:
          """Return 4x4 identity matrices created by ``kornia.eye_like`` from ``input``.

          NOTE(review): ``kornia.eye_like`` presumably yields one matrix per batch element, i.e. ``(B, 4, 4)``,
          matching ``input``'s dtype/device. Confirm against its documentation.
          """
          matrix_size = 4  # 4x4 = homogeneous transform size for 3D
          return kornia.eye_like(matrix_size, input)
  class RigidAffineAugmentationBase3D(AugmentationBase3D):
      r"""AugmentationBase3D base class for rigid/affine augmentation implementations.

      RigidAffineAugmentationBase3D enables routined transformation with given transformation matrices
      for different data types like masks, boxes, and keypoints. Batch elements that are not augmented
      receive a 4x4 identity matrix.

      Args:
          p: probability for applying an augmentation. This param controls the augmentation probabilities
            element-wise for a batch.
          p_batch: probability for applying an augmentation to a batch. This param controls the augmentation
            probabilities batch-wise.
          same_on_batch: apply the same transformation across the batch.
          keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch
            form ``False``.
      """

      # Matrix used by the most recent ``apply_func`` call; only annotated here, assigned in ``apply_func``.
      _transform_matrix: Optional[Tensor]

      @property
      def transform_matrix(self) -> Optional[Tensor]:
          """Transformation matrix recorded by the most recent :meth:`apply_func` call.

          Returns ``None`` when no matrix has been recorded on this instance yet. A plain attribute read
          would raise ``AttributeError`` in that case, because ``_transform_matrix`` is only annotated at
          class level.
          """
          return getattr(self, "_transform_matrix", None)

      def compute_transformation(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
          """Compute transformation matrices for the samples in ``input``.

          Raises:
              NotImplementedError: unless overridden by a subclass.
          """
          raise NotImplementedError

      def generate_transformation_matrix(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
          """Generate transformation matrices with the given input and param settings.

          Args:
              input: the input tensor; it is first passed through :meth:`transform_tensor`.
              params: sampled parameters. ``params["batch_prob"]`` selects which batch elements are
                augmented, and the dict is forwarded to :meth:`compute_transformation`.
              flags: static flags forwarded to :meth:`compute_transformation`.

          Returns:
              Matrices for the whole batch. Elements that are not augmented get the identity matrix, and
              the others get the output of :meth:`compute_transformation`.
          """
          batch_prob = params["batch_prob"]
          to_apply = batch_prob > 0.5  # NOTE: in case of Relaxed Distributions.

          in_tensor = self.transform_tensor(input)
          if not to_apply.any():
              # Nothing is augmented: every element keeps the identity transform.
              return self.identity_matrix(in_tensor)
          if to_apply.all():
              return self.compute_transformation(in_tensor, params=params, flags=flags)

          # Mixed batch: start from identities and overwrite only the augmented rows.
          trans_matrix = self.identity_matrix(in_tensor)
          applied = self.compute_transformation(in_tensor[to_apply], params=params, flags=flags)
          # ``index_put`` requires matching source/destination dtypes. Under mixed precision (e.g. autocast)
          # ``compute_transformation`` can return a different dtype than the identity built from the input.
          return trans_matrix.index_put((to_apply,), applied.to(dtype=trans_matrix.dtype))

      def inverse_inputs(
          self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
      ) -> Tensor:
          """Invert the transformation on inputs. Raises ``NotImplementedError`` unless overridden."""
          raise NotImplementedError

      def inverse_masks(
          self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
      ) -> Tensor:
          """Invert the transformation on masks. Raises ``NotImplementedError`` unless overridden."""
          raise NotImplementedError

      def inverse_boxes(
          self, input: Boxes3D, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
      ) -> Boxes3D:
          """Invert the transformation on 3D boxes. Raises ``NotImplementedError`` unless overridden."""
          raise NotImplementedError

      def inverse_keypoints(
          self, input: Keypoints3D, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
      ) -> Keypoints3D:
          """Invert the transformation on 3D keypoints. Raises ``NotImplementedError`` unless overridden."""
          raise NotImplementedError

      def inverse_classes(
          self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
      ) -> Tensor:
          """Invert the transformation on class labels. Raises ``NotImplementedError`` unless overridden."""
          raise NotImplementedError

      def apply_func(
          self, in_tensor: Tensor, params: Dict[str, Tensor], flags: Optional[Dict[str, Any]] = None
      ) -> Tensor:
          """Transform ``in_tensor`` and record the transformation matrix that was used.

          Args:
              in_tensor: the input tensor.
              params: sampled parameters, passed to :meth:`generate_transformation_matrix` and ``transform_inputs``.
              flags: static flags; ``self.flags`` is used when ``None``.

          Returns:
              The transformed input. The matrix used is afterwards available via :attr:`transform_matrix`.
          """
          if flags is None:
              flags = self.flags

          trans_matrix = self.generate_transformation_matrix(in_tensor, params, flags)
          output = self.transform_inputs(in_tensor, params, flags, trans_matrix)
          # Stored only after ``transform_inputs`` returns, so a failing call keeps the previous matrix.
          self._transform_matrix = trans_matrix
          return output