policy.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
  18. from torch import Size
  19. import kornia.augmentation as K
  20. from kornia.augmentation.auto.operations import OperationBase
  21. from kornia.augmentation.container.base import ImageSequentialBase, TransformMatrixMinIn
  22. from kornia.augmentation.container.ops import InputSequentialOps
  23. from kornia.augmentation.container.params import ParamItem
  24. from kornia.augmentation.utils import _transform_input, override_parameters
  25. from kornia.core import Module, Tensor, as_tensor
  26. from kornia.utils import eye_like
  class PolicySequential(TransformMatrixMinIn, ImageSequentialBase):
      """Policy tuple for applying multiple operations.

      Every operation is validated on construction and must be an ``OperationBase``.

      Args:
          operations: a list of operations to perform.

      Raises:
          ValueError: if any of ``operations`` is not an ``OperationBase`` instance.
      """

      def __init__(self, *operations: OperationBase) -> None:
          # Reject non-``OperationBase`` inputs before any module state is created.
          self.validate_operations(*operations)
          super().__init__(*operations)
          # NOTE(review): presumably consulted by the inherited transform-matrix bookkeeping to decide
          # which children go through ``_update_transform_matrix_for_valid_op`` -- confirm in the mixin.
          self._valid_ops_for_transform_computation: Tuple[Any, ...] = (OperationBase,)

      def _update_transform_matrix_for_valid_op(self, module: Module) -> None:
          """Append ``module.transform_matrix`` to the recorded transformation matrices."""
          self._transform_matrices.append(module.transform_matrix)

      def clear_state(self) -> None:
          """Reset the transform-matrix state, then defer to the base class ``clear_state``."""
          self._reset_transform_matrix_state()
          return super().clear_state()

      def validate_operations(self, *operations: OperationBase) -> None:
          """Check that every item in ``operations`` is an ``OperationBase``.

          Raises:
              ValueError: listing every offending item, if there is at least one.
          """
          # By construction these are exactly the items that are *not* ``OperationBase`` instances,
          # so they are typed as ``Any`` rather than ``OperationBase``.
          invalid_ops: List[Any] = [op for op in operations if not isinstance(op, OperationBase)]
          if invalid_ops:
              raise ValueError(f"All operations must be Kornia Operations. Got {invalid_ops}.")

      def identity_matrix(self, input: Tensor) -> Tensor:
          """Return a 3x3 identity matrix built with ``eye_like`` from ``input``."""
          return eye_like(3, input)

      def get_transformation_matrix(
          self,
          input: Tensor,
          params: Optional[List[ParamItem]] = None,
          recompute: bool = False,
          extra_args: Optional[Dict[str, Any]] = None,
      ) -> Tensor:
          """Compute the transformation matrix according to the provided parameters.

          Starting from a 3x3 identity, the matrix of every geometric operation
          (a ``K.GeometricAugmentationBase2D`` whose param data is a dict) is left-multiplied
          onto the running result in order; all other operations are skipped.

          Args:
              input: the input tensor.
              params: params for the sequence.
              recompute: if to recompute the transformation matrix according to the params.
                  default: False.
              extra_args: Optional dictionary of extra arguments with specific options for different input types.

          Returns:
              The composed transformation matrix.

          Raises:
              NotImplementedError: if ``params`` is None.
              RuntimeError: if ``recompute`` is False and a geometric operation has no stored matrix.
          """
          if params is None:
              raise NotImplementedError("requires params to be provided.")

          named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence(params)

          # NOTE(review): ``_transform_input`` presumably brings ``input`` to a batched layout so the
          # identity broadcasts against each operation's matrix -- confirm.
          res_mat: Tensor = self.identity_matrix(_transform_input(input))
          # ``params`` is known to be non-None past the guard above, so zip over it directly.
          for (_, module), param in zip(named_modules, params):
              module = cast(OperationBase, module)
              if isinstance(module.op, (K.GeometricAugmentationBase2D,)) and isinstance(param.data, dict):
                  ori_shape = input.shape
                  # Standardize shape; ``transform_output_tensor`` below maps back using ``ori_shape``.
                  input = module.op.transform_tensor(input)
                  if recompute:
                      flags = override_parameters(module.op.flags, extra_args, in_place=False)
                      mat = module.op.generate_transformation_matrix(input, param.data, flags)
                  elif module.op._transform_matrix is not None:
                      mat = as_tensor(module.transform_matrix, device=input.device, dtype=input.dtype)
                  else:
                      raise RuntimeError(f"{module}.transform_matrix is None while `recompute=False`.")
                  # Later operations are composed on top of (left of) earlier ones.
                  res_mat = mat @ res_mat
                  input = module.op.transform_output_tensor(input, ori_shape)
                  # With ``keepdim`` set and the shape not restored, drop singleton dims from the result.
                  if module.op.keepdim and ori_shape != input.shape:
                      res_mat = res_mat.squeeze()
          return res_mat

      def is_intensity_only(self) -> bool:
          """Return True when no child operation wraps a ``K.GeometricAugmentationBase2D``."""
          return not any(
              isinstance(cast(OperationBase, module).op, K.GeometricAugmentationBase2D)
              for module in self.children()
          )

      def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iterator[Tuple[str, Module]]:
          """Return ``(name, module)`` pairs to run.

          When ``params`` is given, these are the children selected via ``get_children_by_params``.
          Otherwise all named children are returned.
          """
          if params is not None:
              return super().get_children_by_params(params)
          return self.named_children()

      def forward_parameters(self, batch_shape: Size) -> List[ParamItem]:
          """Generate one ``ParamItem`` per child operation, in forward order.

          Each item pairs the child's name with ``op.forward_parameters(batch_shape)``.
          """
          return [
              ParamItem(name, cast(OperationBase, module).op.forward_parameters(batch_shape))
              for name, module in self.get_forward_sequence()
          ]

      def transform_inputs(
          self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
      ) -> Tensor:
          """Apply the operations named in ``params`` to ``input``, in order.

          After each operation runs, it is passed to ``_update_transform_matrix_by_module``.
          That method is inherited and not defined in this class.
          """
          for param in params:
              module = self.get_submodule(param.name)
              input = InputSequentialOps.transform(input, module=module, param=param, extra_args=extra_args)
              self._update_transform_matrix_by_module(module)
          return input