# LICENSE HEADER MANAGED BY add-license-header # # Copyright 2018 Kornia Team # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from enum import Enum from typing import Any, Callable, Dict, Optional, Tuple, Union import torch from torch.distributions import Bernoulli, Distribution, RelaxedBernoulli from kornia.augmentation.random_generator import RandomGeneratorBase from kornia.augmentation.utils import ( _adapted_rsampling, _adapted_sampling, _transform_output_shape, override_parameters, ) from kornia.core import ImageModule as Module from kornia.core import Tensor, tensor, zeros from kornia.geometry.boxes import Boxes from kornia.geometry.keypoints import Keypoints from kornia.utils import is_autocast_enabled TensorWithTransformMat = Union[Tensor, Tuple[Tensor, Tensor]] # Trick mypy into not applying contravariance rules to inputs by defining # forward as a value, rather than a function. See also # https://github.com/python/mypy/issues/8795 # Based on the trick that torch.nn.Module does for the forward method def _apply_transform_unimplemented(self: Module, *input: Any) -> Tensor: r"""Define the computation performed at every call. Should be overridden by all subclasses. """ raise NotImplementedError(f'Module [{type(self).__name__}] is missing the required "apply_tranform" function') class _BasicAugmentationBase(Module): r"""_BasicAugmentationBase base class for customized augmentation implementations. Plain augmentation base class without the functionality of transformation matrix calculations. By default, the random computations will be happened on CPU with ``torch.get_default_dtype()``. To change this behaviour, please use ``set_rng_device_and_dtype``. For automatically generating the corresponding ``__repr__`` with full customized parameters, you may need to implement ``_param_generator`` by inheriting ``RandomGeneratorBase`` for generating random parameters and put all static parameters inside ``self.flags``. You may take the advantage of ``PlainUniformGenerator`` to generate simple uniform parameters with less boilerplate code. Args: p: probability for applying an augmentation. This param controls the augmentation probabilities element-wise. p_batch: probability for applying an augmentation to a batch. This param controls the augmentation probabilities batch-wise. same_on_batch: apply the same transformation across the batch. keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch form ``False``. """ # TODO: Hard to support. Many codes are not ONNX-friendly that contains lots of if-else blocks, etc. # Please contribute if anyone interested. ONNX_EXPORTABLE = False def __init__( self, p: float = 0.5, p_batch: float = 1.0, same_on_batch: bool = False, keepdim: bool = False, ) -> None: super().__init__() self.p = p self.p_batch = p_batch self.same_on_batch = same_on_batch self.keepdim = keepdim self._params: Dict[str, Tensor] = {} self._p_gen: Distribution self._p_batch_gen: Distribution if p != 0.0 or p != 1.0: self._p_gen = Bernoulli(self.p) if p_batch != 0.0 or p_batch != 1.0: self._p_batch_gen = Bernoulli(self.p_batch) self._param_generator: Optional[RandomGeneratorBase] = None self.flags: Dict[str, Any] = {} self.set_rng_device_and_dtype(torch.device("cpu"), torch.get_default_dtype()) apply_transform: Callable[..., Tensor] = _apply_transform_unimplemented def to(self, *args: Any, **kwargs: Any) -> "_BasicAugmentationBase": r"""Set the device and dtype for the random number generator.""" device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs) self.set_rng_device_and_dtype(device, dtype) return super().to(*args, **kwargs) def __repr__(self) -> str: txt = f"p={self.p}, p_batch={self.p_batch}, same_on_batch={self.same_on_batch}" if isinstance(self._param_generator, RandomGeneratorBase): txt = f"{self._param_generator!s}, {txt}" for k, v in self.flags.items(): if isinstance(v, Enum): txt += f", {k}={v.name.lower()}" else: txt += f", {k}={v}" return f"{self.__class__.__name__}({txt})" def __unpack_input__(self, input: Tensor) -> Tensor: return input def transform_tensor( self, input: Tensor, *, shape: Optional[Tensor] = None, match_channel: bool = True, ) -> Tensor: """Standardize input tensors.""" raise NotImplementedError def validate_tensor(self, input: Tensor) -> None: """Check if the input tensor is formatted as expected.""" raise NotImplementedError def transform_output_tensor(self, output: Tensor, output_shape: Tuple[int, ...]) -> Tensor: """Standardize output tensors.""" return _transform_output_shape(output, output_shape) if self.keepdim else output def generate_parameters(self, batch_shape: Tuple[int, ...]) -> Dict[str, Tensor]: if self._param_generator is not None: return self._param_generator(batch_shape, self.same_on_batch) return {} def set_rng_device_and_dtype(self, device: torch.device, dtype: torch.dtype) -> None: """Change the random generation device and dtype. Note: The generated random numbers are not reproducible across different devices and dtypes. """ self.device = device self.dtype = dtype if self._param_generator is not None: self._param_generator.set_rng_device_and_dtype(device, dtype) def __batch_prob_generator__( self, batch_shape: Tuple[int, ...], p: float, p_batch: float, same_on_batch: bool, ) -> Tensor: batch_prob: Tensor if p_batch == 1: batch_prob = zeros(1) + 1 elif p_batch == 0: batch_prob = zeros(1) elif isinstance(self._p_batch_gen, (RelaxedBernoulli,)): # NOTE: there is no simple way to know if the sampler has `rsample` or not batch_prob = _adapted_rsampling((1,), self._p_batch_gen, same_on_batch) else: batch_prob = _adapted_sampling((1,), self._p_batch_gen, same_on_batch) if batch_prob.sum() == 1: elem_prob: Tensor if p == 1: elem_prob = zeros(batch_shape[0]) + 1 elif p == 0: elem_prob = zeros(batch_shape[0]) elif isinstance(self._p_gen, (RelaxedBernoulli,)): elem_prob = _adapted_rsampling((batch_shape[0],), self._p_gen, same_on_batch) else: elem_prob = _adapted_sampling((batch_shape[0],), self._p_gen, same_on_batch) batch_prob = batch_prob * elem_prob else: batch_prob = batch_prob.repeat(batch_shape[0]) if len(batch_prob.shape) == 2: return batch_prob[..., 0] return batch_prob def _process_kwargs_to_params_and_flags( self, params: Optional[Dict[str, Tensor]] = None, flags: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: # NOTE: determine how to save self._params save_kwargs = kwargs["save_kwargs"] if "save_kwargs" in kwargs else False params = self._params if params is None else params flags = self.flags if flags is None else flags if save_kwargs: params = override_parameters(params, kwargs, in_place=True) self._params = params else: self._params = params params = override_parameters(params, kwargs, in_place=False) flags = override_parameters(flags, kwargs, in_place=False) return params, flags def forward_parameters(self, batch_shape: Tuple[int, ...]) -> Dict[str, Tensor]: batch_prob = self.__batch_prob_generator__(batch_shape, self.p, self.p_batch, self.same_on_batch) to_apply = batch_prob > 0.5 _params = self.generate_parameters(torch.Size((int(to_apply.sum().item()), *batch_shape[1:]))) if _params is None: _params = {} _params["batch_prob"] = batch_prob # Added another input_size parameter for geometric transformations # This might be needed for correctly inversing. input_size = tensor(batch_shape, dtype=torch.long) _params.update({"forward_input_shape": input_size}) return _params def apply_func(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor: return self.apply_transform(input, params, flags) def forward(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None, **kwargs: Any) -> Tensor: """Perform forward operations. Args: input: the input tensor. params: the corresponding parameters for an operation. If None, a new parameter suite will be generated. **kwargs: key-value pairs to override the parameters and flags. Note: By default, all the overwriting parameters in kwargs will not be recorded as in ``self._params``. If you wish it to be recorded, you may pass ``save_kwargs=True`` additionally. """ in_tensor = self.__unpack_input__(input) input_shape = in_tensor.shape in_tensor = self.transform_tensor(in_tensor) batch_shape = in_tensor.shape if params is None: params = self.forward_parameters(batch_shape) if "batch_prob" not in params: params["batch_prob"] = tensor([True] * batch_shape[0]) params, flags = self._process_kwargs_to_params_and_flags(params, self.flags, **kwargs) output = self.apply_func(in_tensor, params, flags) return self.transform_output_tensor(output, input_shape) if self.keepdim else output class _AugmentationBase(_BasicAugmentationBase): r"""_AugmentationBase base class for customized augmentation implementations. Advanced augmentation base class with the functionality of transformation matrix calculations. Args: p: probability for applying an augmentation. This param controls the augmentation probabilities element-wise for a batch. p_batch: probability for applying an augmentation to a batch. This param controls the augmentation probabilities batch-wise. same_on_batch: apply the same transformation across the batch. keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch form ``False``. """ def apply_transform( self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, ) -> Tensor: # apply transform for the input image tensor raise NotImplementedError def apply_non_transform( self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, ) -> Tensor: # apply additional transform for the images that are skipped from transformation # where batch_prob == False. return input def transform_inputs( self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, **kwargs: Any, ) -> Tensor: params, flags = self._process_kwargs_to_params_and_flags( self._params if params is None else params, flags, **kwargs ) batch_prob = params["batch_prob"] to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions. ori_shape = input.shape in_tensor = self.transform_tensor(input) self.validate_tensor(in_tensor) if to_apply.all(): output = self.apply_transform(in_tensor, params, flags, transform=transform) elif not to_apply.any(): output = self.apply_non_transform(in_tensor, params, flags, transform=transform) else: # If any tensor needs to be transformed. output = self.apply_non_transform(in_tensor, params, flags, transform=transform) applied = self.apply_transform( in_tensor[to_apply], params, flags, transform=transform if transform is None else transform[to_apply], ) if is_autocast_enabled(): output = output.type(input.dtype) applied = applied.type(input.dtype) output = output.index_put((to_apply,), applied) output = _transform_output_shape(output, ori_shape) if self.keepdim else output if is_autocast_enabled(): output = output.type(input.dtype) return output def transform_masks( self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, **kwargs: Any, ) -> Tensor: params, flags = self._process_kwargs_to_params_and_flags( self._params if params is None else params, flags, **kwargs ) batch_prob = params["batch_prob"] to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions. ori_shape = input.shape shape = params["forward_input_shape"] in_tensor = self.transform_tensor(input, shape=shape, match_channel=False) self.validate_tensor(in_tensor) if to_apply.all(): output = self.apply_transform_mask(in_tensor, params, flags, transform=transform) elif not to_apply.any(): output = self.apply_non_transform_mask(in_tensor, params, flags, transform=transform) else: # If any tensor needs to be transformed. output = self.apply_non_transform_mask(in_tensor, params, flags, transform=transform) applied = self.apply_transform_mask( in_tensor[to_apply], params, flags, transform=transform if transform is None else transform[to_apply], ) output = output.index_put((to_apply,), applied) output = _transform_output_shape(output, ori_shape, reference_shape=shape) if self.keepdim else output return output def transform_boxes( self, input: Boxes, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, **kwargs: Any, ) -> Boxes: if not isinstance(input, Boxes): raise RuntimeError(f"Only `Boxes` is supported. Got {type(input)}.") params, flags = self._process_kwargs_to_params_and_flags( self._params if params is None else params, flags, **kwargs ) batch_prob = params["batch_prob"] to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions. output: Boxes if to_apply.bool().all(): output = self.apply_transform_box(input, params, flags, transform=transform) elif not to_apply.any(): output = self.apply_non_transform_box(input, params, flags, transform=transform) else: # If any tensor needs to be transformed. output = self.apply_non_transform_box(input, params, flags, transform=transform) applied = self.apply_transform_box( input[to_apply], params, flags, transform=transform if transform is None else transform[to_apply], ) if is_autocast_enabled(): output = output.type(input.dtype) applied = applied.type(input.dtype) output = output.index_put((to_apply,), applied) return output def transform_keypoints( self, input: Keypoints, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, **kwargs: Any, ) -> Keypoints: if not isinstance(input, Keypoints): raise RuntimeError(f"Only `Keypoints` is supported. Got {type(input)}.") params, flags = self._process_kwargs_to_params_and_flags( self._params if params is None else params, flags, **kwargs ) batch_prob = params["batch_prob"] to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions. if to_apply.all(): output = self.apply_transform_keypoint(input, params, flags, transform=transform) elif not to_apply.any(): output = self.apply_non_transform_keypoint(input, params, flags, transform=transform) else: # If any tensor needs to be transformed. output = self.apply_non_transform_keypoint(input, params, flags, transform=transform) applied = self.apply_transform_keypoint( input[to_apply], params, flags, transform=transform if transform is None else transform[to_apply], ) if is_autocast_enabled(): output = output.type(input.dtype) applied = applied.type(input.dtype) output = output.index_put((to_apply,), applied) return output def transform_classes( self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, **kwargs: Any, ) -> Tensor: params, flags = self._process_kwargs_to_params_and_flags( self._params if params is None else params, flags, **kwargs ) batch_prob = params["batch_prob"] to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions. if to_apply.all(): output = self.apply_transform_class(input, params, flags, transform=transform) elif not to_apply.any(): output = self.apply_non_transform_class(input, params, flags, transform=transform) else: # If any tensor needs to be transformed. output = self.apply_non_transform_class(input, params, flags, transform=transform) applied = self.apply_transform_class( input[to_apply], params, flags, transform=transform if transform is None else transform[to_apply], ) output = output.index_put((to_apply,), applied) return output def apply_non_transform_mask( self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, ) -> Tensor: """Process masks corresponding to the inputs that are no transformation applied.""" raise NotImplementedError def apply_transform_mask( self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, ) -> Tensor: """Process masks corresponding to the inputs that are transformed.""" raise NotImplementedError def apply_non_transform_box( self, input: Boxes, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, ) -> Boxes: """Process boxes corresponding to the inputs that are no transformation applied.""" return input def apply_transform_box( self, input: Boxes, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, ) -> Boxes: """Process boxes corresponding to the inputs that are transformed.""" raise NotImplementedError def apply_non_transform_keypoint( self, input: Keypoints, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, ) -> Keypoints: """Process keypoints corresponding to the inputs that are no transformation applied.""" return input def apply_transform_keypoint( self, input: Keypoints, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, ) -> Keypoints: """Process keypoints corresponding to the inputs that are transformed.""" raise NotImplementedError def apply_non_transform_class( self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, ) -> Tensor: """Process class tags corresponding to the inputs that are no transformation applied.""" return input def apply_transform_class( self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None, ) -> Tensor: """Process class tags corresponding to the inputs that are transformed.""" raise NotImplementedError def apply_func( self, in_tensor: Tensor, params: Dict[str, Tensor], flags: Optional[Dict[str, Any]] = None, ) -> Tensor: if flags is None: flags = self.flags output = self.transform_inputs(in_tensor, params, flags) return output