base.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  17. from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
  18. import torch
  19. from kornia.augmentation.auto.operations.base import OperationBase
  20. from kornia.augmentation.auto.operations.policy import PolicySequential
  21. from kornia.augmentation.container.base import ImageSequentialBase, TransformMatrixMinIn
  22. from kornia.augmentation.container.ops import InputSequentialOps
  23. from kornia.augmentation.container.params import ParamItem
  24. from kornia.core import Module, Tensor
  25. from kornia.utils import eye_like
  26. NUMBER = Union[float, int]
  27. OP_CONFIG = Tuple[str, NUMBER, Optional[NUMBER]]
  28. SUBPOLICY_CONFIG = List[OP_CONFIG]
  29. class PolicyAugmentBase(ImageSequentialBase, TransformMatrixMinIn):
  30. """Policy-based image augmentation."""
  31. def __init__(self, policy: List[SUBPOLICY_CONFIG], transformation_matrix_mode: str = "silence") -> None:
  32. policies = self.compose_policy(policy)
  33. super().__init__(*policies)
  34. self._parse_transformation_matrix_mode(transformation_matrix_mode)
  35. self._valid_ops_for_transform_computation: Tuple[Any, ...] = (PolicySequential,)
  36. def _update_transform_matrix_for_valid_op(self, module: PolicySequential) -> None: # type: ignore
  37. self._transform_matrices.append(module.transform_matrix)
  38. def clear_state(self) -> None:
  39. self._reset_transform_matrix_state()
  40. return super().clear_state()
  41. def compose_policy(self, policy: List[SUBPOLICY_CONFIG]) -> List[PolicySequential]:
  42. """Compose policy by the provided policy config."""
  43. return [self.compose_subpolicy_sequential(subpolicy) for subpolicy in policy]
  44. def compose_subpolicy_sequential(self, subpolicy: SUBPOLICY_CONFIG) -> PolicySequential:
  45. raise NotImplementedError
  46. def identity_matrix(self, input: Tensor) -> Tensor:
  47. """Return identity matrix."""
  48. return eye_like(3, input)
  49. def get_transformation_matrix(
  50. self,
  51. input: Tensor,
  52. params: Optional[List[ParamItem]] = None,
  53. recompute: bool = False,
  54. extra_args: Optional[Dict[str, Any]] = None,
  55. ) -> Optional[Tensor]:
  56. """Compute the transformation matrix according to the provided parameters.
  57. Args:
  58. input: the input tensor.
  59. params: params for the sequence.
  60. recompute: if to recompute the transformation matrix according to the params.
  61. default: False.
  62. extra_args: Optional dictionary of extra arguments with specific options for different input types.
  63. """
  64. if params is None:
  65. raise NotImplementedError("requires params to be provided.")
  66. named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence(params)
  67. # Define as 1 for broadcasting
  68. res_mat: Optional[Tensor] = None
  69. for (_, module), param in zip(named_modules, params if params is not None else []):
  70. module = cast(PolicySequential, module)
  71. mat = module.get_transformation_matrix(
  72. input, params=cast(Optional[List[ParamItem]], param.data), recompute=recompute, extra_args=extra_args
  73. )
  74. res_mat = mat if res_mat is None else mat @ res_mat
  75. return res_mat
  76. def is_intensity_only(self, params: Optional[List[ParamItem]] = None) -> bool:
  77. named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence(params)
  78. for _, module in named_modules:
  79. module = cast(PolicySequential, module)
  80. if not module.is_intensity_only():
  81. return False
  82. return True
  83. def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
  84. named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence()
  85. params: List[ParamItem] = []
  86. mod_param: Union[Dict[str, Tensor], List[ParamItem]]
  87. for name, module in named_modules:
  88. module = cast(OperationBase, module)
  89. mod_param = module.forward_parameters(batch_shape)
  90. param = ParamItem(name, mod_param)
  91. params.append(param)
  92. return params
  93. def transform_inputs(
  94. self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
  95. ) -> Tensor:
  96. for param in params:
  97. module = self.get_submodule(param.name)
  98. input = InputSequentialOps.transform(input, module=module, param=param, extra_args=extra_args)
  99. return input
  100. def forward(
  101. self, input: Tensor, params: Optional[List[ParamItem]] = None, extra_args: Optional[Dict[str, Any]] = None
  102. ) -> Tensor:
  103. self.clear_state()
  104. if params is None:
  105. inp = input
  106. _, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
  107. params = self.forward_parameters(out_shape)
  108. for param in params:
  109. module = self.get_submodule(param.name)
  110. input = InputSequentialOps.transform(input, module=module, param=param, extra_args=extra_args)
  111. self._update_transform_matrix_by_module(module)
  112. self._params = params
  113. return input