| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
- from torch import Size
- import kornia.augmentation as K
- from kornia.augmentation.auto.operations import OperationBase
- from kornia.augmentation.container.base import ImageSequentialBase, TransformMatrixMinIn
- from kornia.augmentation.container.ops import InputSequentialOps
- from kornia.augmentation.container.params import ParamItem
- from kornia.augmentation.utils import _transform_input, override_parameters
- from kornia.core import Module, Tensor, as_tensor
- from kornia.utils import eye_like
- class PolicySequential(TransformMatrixMinIn, ImageSequentialBase):
- """Policy tuple for applying multiple operations.
- Args:
- operations: a list of operations to perform.
- """
- def __init__(self, *operations: OperationBase) -> None:
- self.validate_operations(*operations)
- super().__init__(*operations)
- self._valid_ops_for_transform_computation: Tuple[Any, ...] = (OperationBase,)
- def _update_transform_matrix_for_valid_op(self, module: Module) -> None:
- self._transform_matrices.append(module.transform_matrix)
- def clear_state(self) -> None:
- self._reset_transform_matrix_state()
- return super().clear_state()
- def validate_operations(self, *operations: OperationBase) -> None:
- invalid_ops: List[OperationBase] = []
- for op in operations:
- if not isinstance(op, OperationBase):
- invalid_ops.append(op)
- if len(invalid_ops) != 0:
- raise ValueError(f"All operations must be Kornia Operations. Got {invalid_ops}.")
- def identity_matrix(self, input: Tensor) -> Tensor:
- """Return identity matrix."""
- return eye_like(3, input)
- def get_transformation_matrix(
- self,
- input: Tensor,
- params: Optional[List[ParamItem]] = None,
- recompute: bool = False,
- extra_args: Optional[Dict[str, Any]] = None,
- ) -> Tensor:
- """Compute the transformation matrix according to the provided parameters.
- Args:
- input: the input tensor.
- params: params for the sequence.
- recompute: if to recompute the transformation matrix according to the params.
- default: False.
- extra_args: Optional dictionary of extra arguments with specific options for different input types.
- """
- if params is None:
- raise NotImplementedError("requires params to be provided.")
- named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence(params)
- # Define as 1 for broadcasting
- res_mat: Tensor = self.identity_matrix(_transform_input(input))
- for (_, module), param in zip(named_modules, params if params is not None else []):
- module = cast(OperationBase, module)
- if isinstance(module.op, (K.GeometricAugmentationBase2D,)) and isinstance(param.data, dict):
- ori_shape = input.shape
- input = module.op.transform_tensor(input)
- # Standardize shape
- if recompute:
- flags = override_parameters(module.op.flags, extra_args, in_place=False)
- mat = module.op.generate_transformation_matrix(input, param.data, flags)
- elif module.op._transform_matrix is not None:
- mat = as_tensor(module.transform_matrix, device=input.device, dtype=input.dtype)
- else:
- raise RuntimeError(f"{module}.transform_matrix is None while `recompute=False`.")
- res_mat = mat @ res_mat
- input = module.op.transform_output_tensor(input, ori_shape)
- if module.op.keepdim and ori_shape != input.shape:
- res_mat = res_mat.squeeze()
- return res_mat
- def is_intensity_only(self) -> bool:
- for module in self.children():
- module = cast(OperationBase, module)
- if isinstance(module.op, (K.GeometricAugmentationBase2D,)):
- return False
- return True
- def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iterator[Tuple[str, Module]]:
- if params is not None:
- return super().get_children_by_params(params)
- return self.named_children()
- def forward_parameters(self, batch_shape: Size) -> List[ParamItem]:
- named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence()
- params: List[ParamItem] = []
- mod_param: Union[Dict[str, Tensor], List[ParamItem]]
- for name, module in named_modules:
- module = cast(OperationBase, module)
- mod_param = module.op.forward_parameters(batch_shape)
- param = ParamItem(name, mod_param)
- params.append(param)
- return params
- def transform_inputs(
- self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
- ) -> Tensor:
- for param in params:
- module = self.get_submodule(param.name)
- input = InputSequentialOps.transform(input, module=module, param=param, extra_args=extra_args)
- self._update_transform_matrix_by_module(module)
- return input
|