| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import copy
- from abc import ABCMeta, abstractmethod
- from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union
- from typing_extensions import ParamSpec
- import kornia.augmentation as K
- from kornia.augmentation.base import _AugmentationBase
- from kornia.constants import DataKey
- from kornia.core import Module, Tensor
- from kornia.geometry.boxes import Boxes
- from kornia.geometry.keypoints import Keypoints
- from .params import ParamItem
T = TypeVar("T")

# One piece of data handled by the sequential ops: a tensor, a list of tensors (e.g. the list of
# masks accepted by ``MaskSequentialOps.transform_list``), boxes or keypoints.
DataType = Union[Tensor, List[Tensor], Boxes, Keypoints]

# NOTE: a union of *homogeneous* lists. This is narrower than ``List[DataType]``, which would also
# accept a single list mixing tensors, boxes and keypoints.
SequenceDataType = Union[List[Tensor], List[List[Tensor]], List[Boxes], List[Keypoints]]
class SequentialOpsInterface(Generic[T], metaclass=ABCMeta):
    """Abstract interface for applying and inversing transformations on one kind of data ``T``.

    Concrete subclasses implement :meth:`transform` and :meth:`inverse`. The two ``get_*_module_param``
    helpers unpack and type-check the payload stored in a :class:`ParamItem`.
    """

    @classmethod
    def get_instance_module_param(cls, param: ParamItem) -> Dict[str, Tensor]:
        """Return the parameter dictionary recorded for a single (non-container) module.

        Args:
            param: the parameters recorded for the module.

        Returns:
            ``param.data``, which must be a ``dict``.

        Raises:
            TypeError: if ``param`` is not a ``ParamItem`` or ``param.data`` is not a ``dict``.
        """
        if not isinstance(param, ParamItem) or not isinstance(param.data, dict):
            raise TypeError(f"Expected param (ParamItem.data) be a dictionary. Gotcha {param}.")
        return param.data

    @classmethod
    def get_sequential_module_param(cls, param: ParamItem) -> List[ParamItem]:
        """Return the list of parameters recorded for a sequential (container) module.

        Args:
            param: the parameters recorded for the container.

        Returns:
            ``param.data``, which must be a ``list``.

        Raises:
            TypeError: if ``param`` is not a ``ParamItem`` or ``param.data`` is not a ``list``.
        """
        if not isinstance(param, ParamItem) or not isinstance(param.data, list):
            raise TypeError(f"Expected param (ParamItem.data) be a list. Gotcha {param}.")
        return param.data

    @classmethod
    @abstractmethod
    def transform(cls, input: T, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None) -> T:
        """Transform ``input`` with ``module`` using the recorded parameters.

        Args:
            input: the data to transform.
            module: any torch Module; only kornia augmentation modules actually transform the data.
            param: the parameters recorded for ``module``.
            extra_args: optional extra keyword arguments specific to the data type.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def inverse(cls, input: T, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None) -> T:
        """Undo the transformation ``module`` applied to ``input``, using the recorded parameters.

        Args:
            input: the data to invert.
            module: any torch Module; only kornia augmentation modules actually invert the data.
            param: the parameters recorded for ``module``.
            extra_args: optional extra keyword arguments specific to the data type.
        """
        raise NotImplementedError
class AugmentationSequentialOps:
    """Route each input of a sequential container to the operation class matching its data key.

    Args:
        data_keys: default data keys used when ``transform``/``inverse`` are called without explicit
            keys. ``DataKey`` members, strings and ints are accepted and normalized with ``DataKey.get``.
    """

    def __init__(self, data_keys: Optional[Union[List[DataKey], List[str], List[int]]]) -> None:
        # Normalize with ``DataKey.get`` like the ``data_keys`` setter and ``preproc_datakeys`` do, so keys
        # given as strings/ints here are not handed unconverted to ``_get_op`` (which compares against
        # ``DataKey`` members and reads ``.name``). Unlike the setter, ``[]`` is kept as ``[]``, not ``None``.
        self._data_keys: Optional[List[DataKey]] = (
            None if data_keys is None else [DataKey.get(key) for key in data_keys]
        )

    @property
    def data_keys(self) -> Optional[List[DataKey]]:
        """Default data keys, or ``None`` when none are configured."""
        return self._data_keys

    @data_keys.setter
    def data_keys(self, data_keys: Optional[Union[List[DataKey], List[str], List[int]]]) -> None:
        # An empty list is treated like ``None``: no default keys.
        if data_keys:
            self._data_keys = [DataKey.get(inp) for inp in data_keys]
        else:
            self._data_keys = None

    def preproc_datakeys(self, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None) -> List[DataKey]:
        """Resolve the data keys for one call.

        Args:
            data_keys: explicit keys for this call. When ``None``, the default keys of this object are used.

        Returns:
            A list of ``DataKey`` members.

        Raises:
            ValueError: if no keys are given and no default keys are configured.
        """
        if data_keys is not None:
            return [DataKey.get(inp) for inp in data_keys]
        if isinstance(self.data_keys, list):
            return self.data_keys
        raise ValueError("Sequential ops needs data keys to be able to process.")

    def _get_op(self, data_key: DataKey) -> Type[SequentialOpsInterface[Any]]:
        """Return the operation class handling ``data_key``.

        Raises:
            RuntimeError: if no operation class exists for ``data_key``.
        """
        if data_key == DataKey.INPUT:
            return InputSequentialOps
        if data_key == DataKey.MASK:
            return MaskSequentialOps
        if data_key in {DataKey.BBOX, DataKey.BBOX_XYWH, DataKey.BBOX_XYXY}:
            return BoxSequentialOps
        if data_key == DataKey.KEYPOINTS:
            return KeypointSequentialOps
        if data_key == DataKey.CLASS:
            return ClassSequentialOps
        raise RuntimeError(f"Operation for `{data_key.name}` is not found.")

    @staticmethod
    def _unwrap_single(outputs: List[Any]) -> Union[DataType, SequenceDataType]:
        """Return the only element of ``outputs`` when there is exactly one, else the list itself."""
        return outputs[0] if len(outputs) == 1 else outputs

    def transform(
        self,
        *arg: DataType,
        module: Module,
        param: ParamItem,
        extra_args: Dict[DataKey, Dict[str, Any]],
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, SequenceDataType]:
        """Apply ``module`` to every input using the operation class of its data key.

        Inputs and keys are paired with ``zip``, so surplus items on either side are ignored.

        Args:
            *arg: the inputs, paired positionally with the resolved data keys.
            module: the module to apply.
            param: the parameters recorded for ``module``.
            extra_args: per-data-key extra keyword arguments forwarded to the operation class.
            data_keys: keys for this call. Defaults to the keys configured on this object.

        Returns:
            The single output when exactly one is produced, otherwise the list of outputs.
        """
        _data_keys = self.preproc_datakeys(data_keys)

        if isinstance(module, K.RandomTransplantation):
            # For transforms which require the full input to calculate the parameters (e.g. RandomTransplantation)
            param = ParamItem(
                name=param.name,
                data=module.params_from_input(
                    *arg,  # type: ignore[arg-type]
                    data_keys=_data_keys,
                    params=param.data,  # type: ignore[arg-type]
                    extra_args=extra_args,
                ),
            )

        outputs = []
        for inp, dcate in zip(arg, _data_keys):
            op = self._get_op(dcate)
            extra_arg = extra_args.get(dcate, {})
            if dcate == DataKey.MASK and isinstance(inp, list):
                # A list of masks is transformed element by element.
                outputs.append(MaskSequentialOps.transform_list(inp, module, param=param, extra_args=extra_arg))
            else:
                outputs.append(op.transform(inp, module, param=param, extra_args=extra_arg))
        return self._unwrap_single(outputs)

    def inverse(
        self,
        *arg: DataType,
        module: Module,
        param: ParamItem,
        extra_args: Dict[DataKey, Dict[str, Any]],
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, SequenceDataType]:
        """Invert ``module`` on every input using the operation class of its data key.

        Inputs and keys are paired with ``zip``, so surplus items on either side are ignored.

        Args:
            *arg: the inputs, paired positionally with the resolved data keys.
            module: the module to invert.
            param: the parameters recorded for ``module``.
            extra_args: per-data-key extra keyword arguments forwarded to the operation class.
            data_keys: keys for this call. Defaults to the keys configured on this object.

        Returns:
            The single output when exactly one is produced, otherwise the list of outputs.
        """
        _data_keys = self.preproc_datakeys(data_keys)
        outputs = []
        for inp, dcate in zip(arg, _data_keys):
            op = self._get_op(dcate)
            extra_arg = extra_args.get(dcate, {})
            outputs.append(op.inverse(inp, module, param=param, extra_args=extra_arg))
        return self._unwrap_single(outputs)
- P = ParamSpec("P")
- def make_input_only_sequential(module: "K.container.ImageSequentialBase") -> Callable[P, Tensor]:
- """Disable all other additional inputs (e.g. ) for ImageSequential."""
- def f(*args: P.args, **kwargs: P.kwargs) -> Tensor:
- return module(*args, **kwargs)
- return f
def get_geometric_only_param(module: "K.container.ImageSequentialBase", param: List[ParamItem]) -> List[ParamItem]:
    """Keep only the parameters that belong to geometric augmentations.

    Each entry of ``param`` is paired positionally with the ``(name, module)`` items produced by
    ``module.get_forward_sequence(param)``. An entry is kept when its module is a 2D or 3D geometric
    augmentation. The original order is preserved.

    Args:
        module: the sequential container the parameters were recorded for.
        param: the per-child parameters of ``module``.

    Returns:
        The subset of ``param`` belonging to geometric augmentations.
    """
    forward_sequence = module.get_forward_sequence(param)
    geometric_types = (K.GeometricAugmentationBase2D, K.GeometricAugmentationBase3D)
    return [item for (_, child), item in zip(forward_sequence, param) if isinstance(child, geometric_types)]
class InputSequentialOps(SequentialOpsInterface[Tensor]):
    """Apply and inverse transformations for image (``DataKey.INPUT``) tensors."""

    @classmethod
    def transform(
        cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Apply ``module`` to an image tensor.

        How the module is applied depends on its type:

        - Augmentation and mix modules are called with ``data_keys=[DataKey.INPUT]``.
        - Sequential containers use ``transform_inputs``.
        - Auto-augment operations are called with their recorded params.
        - Any other module is called directly and must have been recorded with ``param.data is None``.

        Args:
            input: the image tensor.
            module: the module to apply.
            param: the parameters recorded for ``module``.
            extra_args: optional extra keyword arguments forwarded to augmentation modules and containers.

        Raises:
            AssertionError: if a non-augmentation module was recorded with non-``None`` parameters.
        """
        if extra_args is None:
            extra_args = {}
        if isinstance(module, (_AugmentationBase, K.MixAugmentationBaseV2)):
            input = module(input, params=cls.get_instance_module_param(param), data_keys=[DataKey.INPUT], **extra_args)
        elif isinstance(module, (K.container.ImageSequentialBase,)):
            input = module.transform_inputs(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
        elif isinstance(module, (K.auto.operations.OperationBase,)):
            input = module(input, params=cls.get_instance_module_param(param))
        elif param.data is not None:
            raise AssertionError(f"Non-augmentaion operation {param.name} require empty parameters. Got {param}.")
        else:
            input = module(input)
        return input

    @classmethod
    def inverse(
        cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Invert ``module`` on an image tensor.

        Only these modules are inverted:

        - 2D geometric augmentations.
        - Auto-augment operations, through their wrapped ``op``.
        - Sequential containers.

        Any other module returns ``input`` unchanged.

        Args:
            input: the image tensor.
            module: the module to invert.
            param: the parameters recorded for ``module``.
            extra_args: optional extra keyword arguments forwarded to the inverse call.

        Raises:
            NotImplementedError: for 3D geometric augmentations.
        """
        if extra_args is None:
            extra_args = {}
        if isinstance(module, K.GeometricAugmentationBase2D):
            input = module.inverse(input, params=cls.get_instance_module_param(param), **extra_args)
        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d inverse operations are not yet supported. You are welcome to file a PR in our repo."
            )
        elif isinstance(module, (K.auto.operations.OperationBase,)):
            return InputSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)
        elif isinstance(module, K.container.ImageSequentialBase):
            # Previously split into an ``ImageSequential and not is_intensity_only()`` branch followed by
            # this one. Both made the identical call, so every container is inverted the same way.
            input = module.inverse_inputs(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
        return input
class ClassSequentialOps(SequentialOpsInterface[Tensor]):
    """Apply and inverse transformations for class labels if needed.

    Labels pass through every module unchanged. The exception is mix augmentations, which
    ``transform`` rejects.
    """

    @classmethod
    def transform(
        cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Return the class labels untouched, refusing mix augmentations.

        Raises:
            NotImplementedError: if ``module`` is a ``K.MixAugmentationBaseV2``.
        """
        if not isinstance(module, K.MixAugmentationBaseV2):
            return input
        raise NotImplementedError(
            "The support for class labels for mix augmentations that change the class label is not yet supported."
        )

    @classmethod
    def inverse(
        cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Return the class labels untouched; no module alters them on the inverse path."""
        return input
class MaskSequentialOps(SequentialOpsInterface[Tensor]):
    """Apply and inverse transformations for mask tensors."""

    @classmethod
    def transform(
        cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Apply a transformation with respect to the parameters.

        Modules that are neither augmentations, sequential containers nor auto-augment operations
        leave the mask unchanged.

        Args:
            input: the input tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Raises:
            NotImplementedError: for 3D geometric augmentations.
        """
        if extra_args is None:
            extra_args = {}
        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            input = module.transform_masks(
                input,
                params=cls.get_instance_module_param(param),
                flags=module.flags,
                transform=module.transform_matrix,
                **extra_args,
            )
        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d mask operations are not yet supported. You are welcome to file a PR in our repo."
            )
        elif isinstance(module, K.RandomTransplantation):
            # Routed through the module's forward with ``data_keys=[DataKey.MASK]`` instead of ``transform_masks``.
            # It is checked before the generic ``_AugmentationBase`` branch below.
            input = module(input, params=cls.get_instance_module_param(param), data_keys=[DataKey.MASK], **extra_args)
        elif isinstance(module, _AugmentationBase):
            input = module.transform_masks(
                input, params=cls.get_instance_module_param(param), flags=module.flags, **extra_args
            )
        elif isinstance(module, K.container.ImageSequentialBase):
            # Previously split into an ``ImageSequential and not is_intensity_only()`` branch followed by
            # this one. Both made the identical call.
            input = module.transform_masks(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
        elif isinstance(module, (K.auto.operations.OperationBase,)):
            input = MaskSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)
        return input

    @classmethod
    def _transform_masks_per_sample(
        cls,
        input: List[Tensor],
        module: Any,  # an ``_AugmentationBase``; typed ``Any`` because ``transform_matrix`` is geometric-only
        param: ParamItem,
        extra_args: Dict[str, Any],
        pass_transform: bool,
    ) -> List[Tensor]:
        """Run ``module.transform_masks`` on each mask, giving mask ``i`` the ``batch_prob`` entry ``i``."""
        params = cls.get_instance_module_param(param)
        # One deep copy is reused for every mask; only ``batch_prob`` is overwritten per iteration.
        params_i = copy.deepcopy(params)
        outputs: List[Tensor] = []
        for i, mask in enumerate(input):
            params_i["batch_prob"] = params["batch_prob"][i]
            call_kwargs: Dict[str, Any] = {"params": params_i, "flags": module.flags}
            if pass_transform:
                call_kwargs["transform"] = module.transform_matrix
            outputs.append(module.transform_masks(mask, **call_kwargs, **extra_args))
        return outputs

    @classmethod
    def transform_list(
        cls, input: List[Tensor], module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> List[Tensor]:
        """Apply a transformation with respect to the parameters to each mask of a list.

        How the masks are handled depends on the module type:

        - Single augmentation modules: mask ``i`` is transformed with the ``batch_prob`` entry at index ``i``.
          All other recorded parameters are shared.
        - Sequential containers: the same parameter list is reused for every mask.
        - Any other module: the list is returned unchanged.

        Args:
            input: list of input tensors.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Raises:
            NotImplementedError: for 3D geometric augmentations and auto-augment operations.
        """
        if extra_args is None:
            extra_args = {}
        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            input = cls._transform_masks_per_sample(input, module, param, extra_args, pass_transform=True)
        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d mask operations are not yet supported. You are welcome to file a PR in our repo."
            )
        elif isinstance(module, _AugmentationBase):
            input = cls._transform_masks_per_sample(input, module, param, extra_args, pass_transform=False)
        elif isinstance(module, K.container.ImageSequentialBase):
            # Previously split into an ``ImageSequential and not is_intensity_only()`` branch followed by
            # this one. Both ran the identical loop.
            seq_params = cls.get_sequential_module_param(param)
            input = [module.transform_masks(mask, params=seq_params, extra_args=extra_args) for mask in input]
        elif isinstance(module, (K.auto.operations.OperationBase,)):
            raise NotImplementedError(
                "The support for list of masks under auto operations are not yet supported. You are welcome to file a"
                " PR in our repo."
            )
        return input

    @classmethod
    def inverse(
        cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Inverse a transformation with respect to the parameters.

        Args:
            input: the input tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Raises:
            ValueError: if a 2D geometric module has no stored transformation matrix.
            NotImplementedError: for 3D geometric augmentations.
        """
        if extra_args is None:
            extra_args = {}
        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            if module.transform_matrix is None:
                raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
            transform = module.compute_inverse_transformation(module.transform_matrix)
            input = module.inverse_masks(
                input,
                params=cls.get_instance_module_param(param),
                flags=module.flags,
                transform=transform,
                **extra_args,
            )
        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d mask operations are not yet supported. You are welcome to file a PR in our repo."
            )
        elif isinstance(module, K.container.ImageSequentialBase):
            input = module.inverse_masks(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
        elif isinstance(module, (K.auto.operations.OperationBase,)):
            input = MaskSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)
        return input
class BoxSequentialOps(SequentialOpsInterface[Boxes]):
    """Apply and inverse transformations for bounding box tensors.

    This is for transform boxes in the format (B, N, 4, 2).
    """

    @classmethod
    def transform(
        cls, input: Boxes, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Boxes:
        """Apply a transformation with respect to the parameters.

        Args:
            input: the input tensor, (B, N, 4, 2) or (B, 4, 2).
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The transformed boxes. The input is cloned first, so modules that do not touch boxes return a copy.

        Raises:
            NotImplementedError: for 3D geometric augmentations.
        """
        if extra_args is None:
            extra_args = {}
        _input = input.clone()
        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            _input = module.transform_boxes(
                _input,
                cls.get_instance_module_param(param),
                module.flags,
                transform=module.transform_matrix,
                **extra_args,
            )
        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d box operations are not yet supported. You are welcome to file a PR in our repo."
            )
        elif isinstance(module, K.container.ImageSequentialBase):
            # Previously split into an ``ImageSequential and not is_intensity_only()`` branch followed by
            # this one. Both made the identical call.
            _input = module.transform_boxes(
                _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
            )
        elif isinstance(module, (K.auto.operations.OperationBase,)):
            return BoxSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)
        return _input

    @classmethod
    def inverse(
        cls, input: Boxes, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Boxes:
        """Inverse a transformation with respect to the parameters.

        Args:
            input: the input tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The inverted boxes. The input is cloned first, so modules that do not touch boxes return a copy.

        Raises:
            ValueError: if a 2D geometric module has no stored transformation matrix.
            TypeError: if the parameters of a 2D geometric module are not a dictionary.
            NotImplementedError: for 3D geometric augmentations.
        """
        if extra_args is None:
            extra_args = {}
        _input = input.clone()
        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            if module.transform_matrix is None:
                raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
            transform = module.compute_inverse_transformation(module.transform_matrix)
            # Validate the params the same way ``transform`` does, instead of passing raw ``param.data``.
            _input = module.inverse_boxes(
                _input,
                cls.get_instance_module_param(param),
                module.flags,
                transform=transform,
                **extra_args,
            )
        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d box operations are not yet supported. You are welcome to file a PR in our repo."
            )
        elif isinstance(module, K.container.ImageSequentialBase):
            # Previously split into an ``ImageSequential and not is_intensity_only()`` branch followed by
            # this one. Both made the identical call.
            _input = module.inverse_boxes(_input, params=cls.get_sequential_module_param(param), extra_args=extra_args)
        elif isinstance(module, (K.auto.operations.OperationBase,)):
            return BoxSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)
        return _input
class KeypointSequentialOps(SequentialOpsInterface[Keypoints]):
    """Apply and inverse transformations for keypoints tensors.

    This is for transform keypoints in the format (B, N, 2).
    """

    @classmethod
    def transform(
        cls, input: Keypoints, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Keypoints:
        """Apply a transformation with respect to the parameters.

        Args:
            input: the input keypoints, (B, N, 2).
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The transformed keypoints. The input is cloned first, so modules that do not touch keypoints
            return a copy.

        Raises:
            NotImplementedError: for 3D geometric augmentations.
        """
        if extra_args is None:
            extra_args = {}
        _input = input.clone()
        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            _input = module.transform_keypoints(
                _input,
                cls.get_instance_module_param(param),
                module.flags,
                transform=module.transform_matrix,
                **extra_args,
            )
        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d keypoint operations are not yet supported. "
                "You are welcome to file a PR in our repo."
            )
        elif isinstance(module, K.container.ImageSequentialBase):
            # Previously split into an ``ImageSequential and not is_intensity_only()`` branch followed by
            # this one. Both made the identical call.
            _input = module.transform_keypoints(
                _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
            )
        elif isinstance(module, (K.auto.operations.OperationBase,)):
            return KeypointSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)
        return _input

    @classmethod
    def inverse(
        cls, input: Keypoints, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Keypoints:
        """Inverse a transformation with respect to the parameters.

        Args:
            input: the input keypoints, (B, N, 2).
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The inverted keypoints. The input is cloned first, so modules that do not touch keypoints
            return a copy.

        Raises:
            ValueError: if a 2D geometric module has no stored transformation matrix.
            NotImplementedError: for 3D geometric augmentations.
        """
        if extra_args is None:
            extra_args = {}
        _input = input.clone()
        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            if module.transform_matrix is None:
                raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
            transform = module.compute_inverse_transformation(module.transform_matrix)
            _input = module.inverse_keypoints(
                _input, cls.get_instance_module_param(param), module.flags, transform=transform, **extra_args
            )
        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d keypoint operations are not yet supported. "
                "You are welcome to file a PR in our repo."
            )
        elif isinstance(module, K.container.ImageSequentialBase):
            # Previously split into an ``ImageSequential and not is_intensity_only()`` branch followed by
            # this one. Both made the identical call.
            _input = module.inverse_keypoints(
                _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
            )
        elif isinstance(module, (K.auto.operations.OperationBase,)):
            return KeypointSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)
        return _input
|