| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import Any, Dict, Optional
- from torch import float16, float32, float64
- import kornia
- from kornia.augmentation.base import _AugmentationBase
- from kornia.augmentation.utils import _transform_input3d, _transform_input3d_by_shape, _validate_input_dtype
- from kornia.core import Tensor
- from kornia.geometry.boxes import Boxes3D
- from kornia.geometry.keypoints import Keypoints3D
class AugmentationBase3D(_AugmentationBase):
    r"""Base class for custom augmentations that operate on volumetric (3D) inputs.

    Args:
        p: probability for applying an augmentation. This param controls the augmentation probabilities
          element-wise for a batch.
        p_batch: probability for applying an augmentation to a batch. This param controls the augmentation
          probabilities batch-wise.
        same_on_batch: apply the same transformation across the batch.

    """

    def validate_tensor(self, input: Tensor) -> None:
        """Verify that the input has a floating dtype and exactly five dimensions.

        Raises:
            RuntimeError: if the tensor is not 5-dimensional, i.e. not (B, C, D, H, W).

        """
        floating_dtypes = [float16, float32, float64]
        _validate_input_dtype(input, accepted_dtypes=floating_dtypes)
        if input.dim() != 5:
            raise RuntimeError(f"Expect (B, C, D, H, W). Got {input.shape}.")

    def transform_tensor(self, input: Tensor, *, shape: Optional[Tensor] = None, match_channel: bool = True) -> Tensor:
        """Bring a (D, H, W), (C, D, H, W) or (B, C, D, H, W) tensor into (B, C, D, H, W) form.

        When ``shape`` is provided, the conversion is delegated to ``_transform_input3d_by_shape``
        with ``shape`` as the reference shape and ``match_channel`` forwarded as-is; otherwise
        ``_transform_input3d`` is used.
        """
        floating_dtypes = [float16, float32, float64]
        _validate_input_dtype(input, accepted_dtypes=floating_dtypes)

        if shape is not None:
            return _transform_input3d_by_shape(input, reference_shape=shape, match_channel=match_channel)
        return _transform_input3d(input)

    def identity_matrix(self, input: Tensor) -> Tensor:
        """Build a 4x4 identity matrix via ``kornia.eye_like`` using ``input`` as the reference tensor."""
        matrix_size = 4
        return kornia.eye_like(matrix_size, input)
class RigidAffineAugmentationBase3D(AugmentationBase3D):
    r"""AugmentationBase3D base class for rigid/affine augmentation implementations.

    RigidAffineAugmentationBase3D enables routined transformation with given transformation matrices
    for different data types like masks, boxes, and keypoints.

    Args:
        p: probability for applying an augmentation. This param controls the augmentation probabilities
          element-wise for a batch.
        p_batch: probability for applying an augmentation to a batch. This param controls the augmentation
          probabilities batch-wise.
        same_on_batch: apply the same transformation across the batch.
        keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch
          form ``False``.

    """

    # Matrix stored by the most recent ``apply_func`` call. Defaults to ``None`` at class level so
    # that ``transform_matrix`` is readable (as ``None``) before any augmentation has been applied.
    _transform_matrix: Optional[Tensor] = None

    @property
    def transform_matrix(self) -> Optional[Tensor]:
        """Return the transformation matrix stored by the last ``apply_func`` call, if any."""
        return self._transform_matrix

    def compute_transformation(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
        """Compute the transformation matrices for ``input``.

        Not implemented in this base class.

        Raises:
            NotImplementedError: always.

        """
        raise NotImplementedError

    def generate_transformation_matrix(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
        """Generate transformation matrices with the given input and param settings.

        Elements whose ``params["batch_prob"]`` exceeds 0.5 receive the matrix returned by
        :meth:`compute_transformation`; every other element receives the 4x4 identity.

        Args:
            input: tensor to be augmented; it is normalised with :meth:`transform_tensor` first.
            params: parameter dictionary; must contain the key ``"batch_prob"``.
            flags: static flags forwarded to :meth:`compute_transformation`.

        Returns:
            The transformation matrices for the batch.

        """
        batch_prob = params["batch_prob"]
        to_apply = batch_prob > 0.5  # NOTE: in case of Relaxed Distributions.

        in_tensor = self.transform_tensor(input)
        if not to_apply.any():
            trans_matrix = self.identity_matrix(in_tensor)
        elif to_apply.all():
            trans_matrix = self.compute_transformation(in_tensor, params=params, flags=flags)
        else:
            trans_matrix = self.identity_matrix(in_tensor)
            applied_matrix = self.compute_transformation(in_tensor[to_apply], params=params, flags=flags)
            # ``index_put`` needs source and destination dtypes to agree. The computed matrices may
            # come back in a different precision than the identity (e.g. under autocast), so align
            # them first; ``.to`` is a no-op when the dtypes already match.
            applied_matrix = applied_matrix.to(dtype=trans_matrix.dtype)
            trans_matrix = trans_matrix.index_put((to_apply,), applied_matrix)

        return trans_matrix

    def inverse_inputs(
        self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
    ) -> Tensor:
        """Inverse-transform input tensors. Not implemented in this base class."""
        raise NotImplementedError

    def inverse_masks(
        self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
    ) -> Tensor:
        """Inverse-transform masks. Not implemented in this base class."""
        raise NotImplementedError

    def inverse_boxes(
        self, input: Boxes3D, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
    ) -> Boxes3D:
        """Inverse-transform 3D boxes. Not implemented in this base class."""
        raise NotImplementedError

    def inverse_keypoints(
        self, input: Keypoints3D, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
    ) -> Keypoints3D:
        """Inverse-transform 3D keypoints. Not implemented in this base class."""
        raise NotImplementedError

    def inverse_classes(
        self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
    ) -> Tensor:
        """Inverse-transform class labels. Not implemented in this base class."""
        raise NotImplementedError

    def apply_func(
        self, in_tensor: Tensor, params: Dict[str, Tensor], flags: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Apply the augmentation to ``in_tensor`` and remember the matrix that was used.

        Args:
            in_tensor: tensor to be augmented.
            params: parameter dictionary; must contain the key ``"batch_prob"``.
            flags: static flags; falls back to ``self.flags`` when ``None``.

        Returns:
            The result of ``self.transform_inputs`` for the generated transformation matrix.

        """
        if flags is None:
            flags = self.flags

        trans_matrix = self.generate_transformation_matrix(in_tensor, params, flags)
        output = self.transform_inputs(in_tensor, params, flags, trans_matrix)
        # Stored only after the transform succeeded, so a failure keeps the previous matrix.
        self._transform_matrix = trans_matrix

        return output
|