base.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, Dict, Optional
  18. from torch import float16, float32, float64
  19. from kornia.augmentation.base import _AugmentationBase
  20. from kornia.augmentation.utils import _transform_input, _transform_input_by_shape, _validate_input_dtype
  21. from kornia.core import Tensor
  22. from kornia.geometry.boxes import Boxes
  23. from kornia.geometry.keypoints import Keypoints
  24. from kornia.utils import eye_like, is_autocast_enabled
  25. class AugmentationBase2D(_AugmentationBase):
  26. r"""AugmentationBase2D base class for customized augmentation implementations.
  27. AugmentationBase2D aims at offering a generic base class for a greater level of customization.
  28. If the subclass contains routined matrix-based transformations, `RigidAffineAugmentationBase2D`
  29. might be a better fit.
  30. Args:
  31. p: probability for applying an augmentation. This param controls the augmentation probabilities
  32. element-wise for a batch.
  33. p_batch: probability for applying an augmentation to a batch. This param controls the augmentation
  34. probabilities batch-wise.
  35. same_on_batch: apply the same transformation across the batch.
  36. keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch
  37. form ``False``.
  38. """
  39. def validate_tensor(self, input: Tensor) -> None:
  40. """Check if the input tensor is formatted as expected."""
  41. _validate_input_dtype(input, accepted_dtypes=[float16, float32, float64])
  42. if len(input.shape) != 4:
  43. raise RuntimeError(f"Expect (B, C, H, W). Got {input.shape}.")
  44. def transform_tensor(self, input: Tensor, *, shape: Optional[Tensor] = None, match_channel: bool = True) -> Tensor:
  45. """Convert any incoming (H, W), (C, H, W) and (B, C, H, W) into (B, C, H, W)."""
  46. _validate_input_dtype(input, accepted_dtypes=[float16, float32, float64])
  47. if shape is None:
  48. return _transform_input(input)
  49. else:
  50. return _transform_input_by_shape(input, reference_shape=shape, match_channel=match_channel)
  51. class RigidAffineAugmentationBase2D(AugmentationBase2D):
  52. r"""AugmentationBase2D base class for rigid/affine augmentation implementations.
  53. RigidAffineAugmentationBase2D enables routined transformation with given transformation matrices
  54. for different data types like masks, boxes, and keypoints.
  55. Args:
  56. p: probability for applying an augmentation. This param controls the augmentation probabilities
  57. element-wise for a batch.
  58. p_batch: probability for applying an augmentation to a batch. This param controls the augmentation
  59. probabilities batch-wise.
  60. same_on_batch: apply the same transformation across the batch.
  61. keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch
  62. form ``False``.
  63. """
  64. _transform_matrix: Optional[Tensor]
  65. @property
  66. def transform_matrix(self) -> Optional[Tensor]:
  67. return self._transform_matrix
  68. def identity_matrix(self, input: Tensor) -> Tensor:
  69. """Return 3x3 identity matrix."""
  70. return eye_like(3, input)
  71. def compute_transformation(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
  72. raise NotImplementedError
  73. def generate_transformation_matrix(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
  74. """Generate transformation matrices with the given input and param settings."""
  75. batch_prob = params["batch_prob"]
  76. to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions.
  77. in_tensor = self.transform_tensor(input)
  78. if not to_apply.any():
  79. trans_matrix = self.identity_matrix(in_tensor)
  80. elif to_apply.all():
  81. trans_matrix = self.compute_transformation(in_tensor, params=params, flags=flags)
  82. else:
  83. trans_matrix_A = self.identity_matrix(in_tensor)
  84. trans_matrix_B = self.compute_transformation(in_tensor[to_apply], params=params, flags=flags)
  85. if is_autocast_enabled():
  86. trans_matrix_A = trans_matrix_A.type(input.dtype)
  87. trans_matrix_B = trans_matrix_B.type(input.dtype)
  88. trans_matrix = trans_matrix_A.index_put((to_apply,), trans_matrix_B)
  89. return trans_matrix
  90. def inverse_inputs(
  91. self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
  92. ) -> Tensor:
  93. raise NotImplementedError
  94. def inverse_masks(
  95. self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
  96. ) -> Tensor:
  97. raise NotImplementedError
  98. def inverse_boxes(
  99. self, input: Boxes, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
  100. ) -> Boxes:
  101. raise NotImplementedError
  102. def inverse_keypoints(
  103. self, input: Keypoints, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
  104. ) -> Keypoints:
  105. raise NotImplementedError
  106. def inverse_classes(
  107. self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
  108. ) -> Tensor:
  109. raise NotImplementedError
  110. def apply_func(
  111. self, in_tensor: Tensor, params: Dict[str, Tensor], flags: Optional[Dict[str, Any]] = None
  112. ) -> Tensor:
  113. if flags is None:
  114. flags = self.flags
  115. trans_matrix = self.generate_transformation_matrix(in_tensor, params, flags)
  116. output = self.transform_inputs(in_tensor, params, flags, trans_matrix)
  117. self._transform_matrix = trans_matrix
  118. return output