| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import warnings
- from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
- import torch
- from kornia.augmentation._2d.base import RigidAffineAugmentationBase2D
- from kornia.augmentation._3d.base import AugmentationBase3D, RigidAffineAugmentationBase3D
- from kornia.augmentation.base import _AugmentationBase
- from kornia.constants import DataKey, Resample
- from kornia.core import Module, Tensor
- from kornia.geometry.boxes import Boxes, VideoBoxes
- from kornia.geometry.keypoints import Keypoints, VideoKeypoints
- from kornia.utils import eye_like, is_autocast_enabled
- from .base import TransformMatrixMinIn
- from .image import ImageSequential
- from .ops import AugmentationSequentialOps, DataType
- from .params import ParamItem
- from .patch import PatchSequential
- from .video import VideoSequential
__all__ = ["AugmentationSequential"]

# Data keys grouped by the pre/post-processing path ``AugmentationSequential`` applies to them.
_BOXES_OPTIONS = {
    DataKey.BBOX,
    DataKey.BBOX_XYXY,
    DataKey.BBOX_XYWH,
}
_KEYPOINTS_OPTIONS = {DataKey.KEYPOINTS}
_IMG_OPTIONS = {DataKey.INPUT, DataKey.IMAGE}
_MSK_OPTIONS = {DataKey.MASK}
_CLS_OPTIONS = {DataKey.CLASS, DataKey.LABEL}

# A mask is either one tensor or a list of tensors.
MaskDataType = Union[Tensor, List[Tensor]]
class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
    r"""AugmentationSequential for handling multiple input types like inputs, masks, keypoints at once.

    .. image:: _static/img/AugmentationSequential.png

    Args:
        *args: a list of kornia augmentation modules.
        data_keys: the input type sequential for applying augmentations. Accepts "input", "image", "mask",
            "bbox", "bbox_xyxy", "bbox_xywh", "keypoints", "class", "label". The first key must be "input".
            Set it to ``None`` to pass a dictionary as input instead (see the last example).
        same_on_batch: apply the same transformation across the batch. If None, it will not overwrite the function-wise
            settings.
        keepdim: whether to keep the output shape the same as input (True) or broadcast it to the batch form (False).
            If None, it will not overwrite the function-wise settings.
        random_apply: randomly select a sublist (order agnostic) of args to apply transformation.
            If int, a fixed number of transformations will be selected.
            If (a,), x number of transformations (a <= x <= len(args)) will be selected.
            If (a, b), x number of transformations (a <= x <= b) will be selected.
            If True, the whole list of args will be processed as a sequence in a random order.
            If False, the whole list of args will be processed as a sequence in original order.
        random_apply_weights: weights used when modules are randomly selected through ``random_apply``.
            Passed through to ``ImageSequential``.
        transformation_matrix_mode: computation mode for the chained transformation matrix, via `.transform_matrix`
            attribute.
            If `silent`, transformation matrix will be computed silently and the non-rigid
            modules will be ignored as identity transformations.
            If `rigid`, transformation matrix will be computed silently and the non-rigid
            modules will trigger errors.
            If `skip`, transformation matrix will be totally ignored.
        extra_args: extra arguments to control the behaviour for each data key. By default, masks are resampled
            with nearest-neighbour interpolation (``Resample.NEAREST``).

    .. note::
        Mix augmentations (e.g. RandomMixUp, RandomCutMix) only work with the "input"/"image" data key.
        It is not clear how to deal with the conversions of masks, bounding boxes and keypoints.

    .. note::
        See a working example `here <https://kornia.github.io/tutorials/nbs/data_augmentation_sequential.html>`__.

    Examples:
        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)
        >>> mask = torch.ones(2, 3, 5, 6)
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, 1, -1, -1)
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)
        >>> aug_list = AugmentationSequential(
        ...     kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     kornia.augmentation.RandomAffine(360, p=1.0),
        ...     data_keys=["input", "mask", "bbox", "keypoints"],
        ...     same_on_batch=False,
        ...     random_apply=10,
        ... )
        >>> out = aug_list(input, mask, bbox, points)
        >>> [o.shape for o in out]
        [torch.Size([2, 3, 5, 6]), torch.Size([2, 3, 5, 6]), torch.Size([2, 1, 4, 2]), torch.Size([2, 1, 2])]
        >>> # apply the exact augmentation again.
        >>> out_rep = aug_list(input, mask, bbox, points, params=aug_list._params)
        >>> [(o == o_rep).all() for o, o_rep in zip(out, out_rep)]
        [tensor(True), tensor(True), tensor(True), tensor(True)]
        >>> # inverse the augmentations
        >>> out_inv = aug_list.inverse(*out)
        >>> [o.shape for o in out_inv]
        [torch.Size([2, 3, 5, 6]), torch.Size([2, 3, 5, 6]), torch.Size([2, 1, 4, 2]), torch.Size([2, 1, 2])]

    This example demonstrates the integration of VideoSequential and AugmentationSequential.

        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)[None]
        >>> mask = torch.ones(2, 3, 5, 6)[None]
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, 1, -1, -1)[None]
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)[None]
        >>> aug_list = AugmentationSequential(
        ...     VideoSequential(
        ...         kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...         kornia.augmentation.RandomAffine(360, p=1.0),
        ...     ),
        ...     data_keys=["input", "mask", "bbox", "keypoints"]
        ... )
        >>> out = aug_list(input, mask, bbox, points)
        >>> [o.shape for o in out]  # doctest: +ELLIPSIS
        [torch.Size([1, 2, 3, 5, 6]), torch.Size([1, 2, 3, 5, 6]), ...([1, 2, 1, 4, 2]), torch.Size([1, 2, 1, 2])]

    Perform ``OneOf`` transformation with ``random_apply=1`` and ``random_apply_weights``
    in ``AugmentationSequential``.

        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)[None]
        >>> mask = torch.ones(2, 3, 5, 6)[None]
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, 1, -1, -1)[None]
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)[None]
        >>> aug_list = AugmentationSequential(
        ...     VideoSequential(
        ...         kornia.augmentation.RandomAffine(360, p=1.0),
        ...     ),
        ...     VideoSequential(
        ...         kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     ),
        ...     data_keys=["input", "mask", "bbox", "keypoints"],
        ...     random_apply=1,
        ...     random_apply_weights=[0.5, 0.3]
        ... )
        >>> out = aug_list(input, mask, bbox, points)
        >>> [o.shape for o in out]  # doctest: +ELLIPSIS
        [torch.Size([1, 2, 3, 5, 6]), torch.Size([1, 2, 3, 5, 6]), ...([1, 2, 1, 4, 2]), torch.Size([1, 2, 1, 2])]

    This example shows how to use a list of masks and boxes within AugmentationSequential.

        >>> import kornia.augmentation as K
        >>> input = torch.randn(2, 3, 256, 256)
        >>> mask = [torch.ones(1, 3, 256, 256), torch.ones(1, 2, 256, 256)]
        >>> bbox = [
        ...     torch.tensor([[28.0, 53.0, 143.0, 164.0], [254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]]),
        ...     torch.tensor([[254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]])
        ... ]
        >>> bbox = [Boxes.from_tensor(i).data for i in bbox]
        >>> aug_list = K.AugmentationSequential(
        ...     K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     K.RandomHorizontalFlip(p=1.0),
        ...     K.ImageSequential(K.RandomHorizontalFlip(p=1.0)),
        ...     K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0)),
        ...     data_keys=["input", "mask", "bbox"],
        ...     same_on_batch=False,
        ...     random_apply=10,
        ... )
        >>> out = aug_list(input, mask, bbox)

    How to use a dictionary as input with AugmentationSequential? The dictionary keys that start with
    one of the available datakeys will be augmented accordingly. Otherwise, the dictionary item is passed
    without any augmentation.

        >>> import kornia.augmentation as K
        >>> img = torch.randn(1, 3, 256, 256)
        >>> mask = [torch.ones(1, 3, 256, 256), torch.ones(1, 2, 256, 256)]
        >>> bbox = [
        ...     torch.tensor([[28.0, 53.0, 143.0, 164.0], [254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]]),
        ...     torch.tensor([[254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]])
        ... ]
        >>> bbox = [Boxes.from_tensor(i).data for i in bbox]
        >>> aug_dict = K.AugmentationSequential(
        ...     K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     K.RandomHorizontalFlip(p=1.0),
        ...     K.ImageSequential(K.RandomHorizontalFlip(p=1.0)),
        ...     K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0)),
        ...     data_keys=None,
        ...     same_on_batch=False,
        ...     random_apply=10,
        ... )
        >>> data = {'image': img, 'mask': mask[0], 'mask-b': mask[1], 'bbox': bbox[0], 'bbox-other':bbox[1]}
        >>> out = aug_dict(data)
        >>> out.keys()
        dict_keys(['image', 'mask', 'mask-b', 'bbox', 'bbox-other'])
    """

    # dtype of the most recent image input; ``_preproc_mask`` casts masks to it.
    input_dtype: Optional[torch.dtype] = None
    # dtype of the most recent mask input; ``_postproc_mask`` casts masks back to it.
    mask_dtype: Optional[torch.dtype] = None

    def __init__(
        self,
        *args: Union[_AugmentationBase, ImageSequential],
        data_keys: Optional[Union[Sequence[str], Sequence[int], Sequence[DataKey]]] = (DataKey.INPUT,),
        same_on_batch: Optional[bool] = None,
        keepdim: Optional[bool] = None,
        random_apply: Union[int, bool, Tuple[int, int]] = False,
        random_apply_weights: Optional[List[float]] = None,
        transformation_matrix_mode: str = "silent",
        extra_args: Optional[Dict[DataKey, Dict[str, Any]]] = None,
    ) -> None:
        self._transform_matrix: Optional[Tensor]
        self._transform_matrices: List[Optional[Tensor]] = []

        super().__init__(
            *args,
            same_on_batch=same_on_batch,
            keepdim=keepdim,
            random_apply=random_apply,
            random_apply_weights=random_apply_weights,
        )

        self._parse_transformation_matrix_mode(transformation_matrix_mode)

        # Module types considered valid when computing the chained transformation matrix.
        self._valid_ops_for_transform_computation: Tuple[Any, ...] = (
            RigidAffineAugmentationBase2D,
            RigidAffineAugmentationBase3D,
            AugmentationSequential,
        )

        self.data_keys: Optional[List[DataKey]]
        if data_keys is not None:
            self.data_keys = [DataKey.get(inp) for inp in data_keys]
        else:
            # ``None`` enables dictionary inputs; their keys provide the data keys at call time.
            self.data_keys = data_keys

        if self.data_keys:
            if any(in_type not in DataKey for in_type in self.data_keys):
                raise AssertionError(f"`data_keys` must be in {DataKey}. Got {self.data_keys}.")
            if self.data_keys[0] != DataKey.INPUT:
                raise NotImplementedError(f"The first input must be {DataKey.INPUT}.")

        self.transform_op = AugmentationSequentialOps(self.data_keys)

        self.contains_video_sequential: bool = False
        self.contains_3d_augmentation: bool = False
        for arg in args:
            if isinstance(arg, PatchSequential) and not arg.is_intensity_only():
                warnings.warn(
                    "Geometric transformation detected in PatchSeqeuntial, which would break bbox, mask.", stacklevel=1
                )
            if isinstance(arg, VideoSequential):
                self.contains_video_sequential = True
            # NOTE: only images are supported for 3D augmentations.
            if isinstance(arg, AugmentationBase3D):
                self.contains_3d_augmentation = True

        self._transform_matrix = None
        # Masks use nearest-neighbour resampling unless the caller provides its own ``extra_args``.
        self.extra_args = extra_args or {DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}}

    def clear_state(self) -> None:
        """Reset the transformation-matrix state, then the parent state."""
        self._reset_transform_matrix_state()
        return super().clear_state()

    def _update_transform_matrix_for_valid_op(self, module: Module) -> None:
        """Record the transformation matrix of ``module`` for later chaining."""
        self._transform_matrices.append(module.transform_matrix)

    def identity_matrix(self, input: Tensor) -> Tensor:
        """Return an identity matrix built with ``eye_like`` from ``input``.

        The matrix is 4x4 when the pipeline contains 3D augmentations and 3x3 otherwise.
        """
        return eye_like(4 if self.contains_3d_augmentation else 3, input)

    def inverse(  # type: ignore[override]
        self,
        *args: Union[DataType, Dict[str, DataType]],
        params: Optional[List[ParamItem]] = None,
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, List[DataType], Dict[str, DataType]]:
        """Reverse the transformation applied.

        Number of input tensors must align with the number of ``data_keys``. If ``data_keys`` is not set, use
        ``self.data_keys`` by default.

        Args:
            *args: one input per data key, or a single dictionary (only with ``data_keys=None`` at construction).
            params: parameters of the augmentation to revert. Defaults to those of the last forward pass.
            data_keys: overrides the data keys for this call only.

        Returns:
            A single output for a single input, a dictionary for a dictionary input, otherwise a list.

        Raises:
            ValueError: if no parameters are given and no forward pass has stored any.
        """
        original_keys: Optional[Tuple[str, ...]] = None
        invalid_data: Optional[Dict[str, Any]] = None
        if len(args) == 1 and isinstance(args[0], dict):
            original_keys, data_keys, args, invalid_data = self._preproc_dict_data(args[0])

        # args here should already be `DataType`
        # NOTE: how to right type to: unpacked args <-> tuple of args to unpack
        # issue with `self._preproc_dict_data` return args type
        self.transform_op.data_keys = self.transform_op.preproc_datakeys(data_keys)

        self._validate_args_datakeys(*args, data_keys=self.transform_op.data_keys)  # type: ignore

        in_args = self._arguments_preproc(*args, data_keys=self.transform_op.data_keys)  # type: ignore

        if params is None:
            if self._params is None:
                raise ValueError(
                    "No parameters available for inversing, please run a forward pass first "
                    "or passing valid params into this function."
                )
            params = self._params

        outputs: List[DataType] = in_args
        # Undo the modules in reverse order of application.
        for param in params[::-1]:
            module = self.get_submodule(param.name)
            outputs = self.transform_op.inverse(  # type: ignore
                *outputs, module=module, param=param, extra_args=self.extra_args
            )
            if not isinstance(outputs, (list, tuple)):
                # Make sure we are unpacking a list whilst post-proc
                outputs = [outputs]

        outputs = self._arguments_postproc(args, outputs, data_keys=self.transform_op.data_keys)  # type: ignore
        # Restore the default data keys, as ``forward`` does, so per-call keys do not leak into later calls.
        self.transform_op.data_keys = self.data_keys

        if isinstance(original_keys, tuple):
            result = {k: v for v, k in zip(outputs, original_keys)}
            if invalid_data:
                result.update(invalid_data)
            return result

        if len(outputs) == 1 and isinstance(outputs, list):
            return outputs[0]

        return outputs

    def _validate_args_datakeys(self, *args: DataType, data_keys: List[DataKey]) -> None:
        """Raise ``AssertionError`` unless there is exactly one input per data key."""
        if len(args) != len(data_keys):
            raise AssertionError(
                f"The number of inputs must align with the number of data_keys. Got {len(args)} and {len(data_keys)}."
            )
        # TODO: validate args batching, and its consistency

    def _arguments_preproc(self, *args: DataType, data_keys: List[DataKey]) -> List[DataType]:
        """Convert raw inputs into the containers used by the transform ops.

        Images are passed through (their dtype is recorded), masks are cast via ``_preproc_mask``, keypoints and
        boxes are wrapped via ``_preproc_keypoints`` / ``_preproc_boxes``, and class labels are passed through.

        Raises:
            NotImplementedError: if a data key has no handler.
        """
        inp: List[DataType] = []
        for arg, dcate in zip(args, data_keys):
            key = DataKey.get(dcate)
            if key in _IMG_OPTIONS:
                arg = cast(Tensor, arg)
                self.input_dtype = arg.dtype
                inp.append(arg)
            elif key in _MSK_OPTIONS:
                # Remember the caller's mask dtype so that ``_postproc_mask`` can restore it.
                if isinstance(arg, list):
                    masks = cast(List[Tensor], arg)
                    if masks:
                        self.mask_dtype = masks[0].dtype
                else:
                    self.mask_dtype = cast(Tensor, arg).dtype
                inp.append(self._preproc_mask(cast(MaskDataType, arg)))
            elif key in _KEYPOINTS_OPTIONS:
                inp.append(self._preproc_keypoints(arg, dcate))
            elif key in _BOXES_OPTIONS:
                inp.append(self._preproc_boxes(arg, dcate))
            elif key in _CLS_OPTIONS:
                inp.append(arg)
            else:
                raise NotImplementedError(f"input type of {dcate} is not implemented.")
        return inp

    def _arguments_postproc(
        self, in_args: List[DataType], out_args: List[DataType], data_keys: List[DataKey]
    ) -> List[DataType]:
        """Convert transformed containers back to the caller's input formats.

        Raises:
            NotImplementedError: if a data key has no handler.
        """
        out: List[DataType] = []
        for in_arg, out_arg, dcate in zip(in_args, out_args, data_keys):
            key = DataKey.get(dcate)
            if key in _IMG_OPTIONS:
                # It is tensor type already.
                out.append(out_arg)
                # TODO: may add the float to integer (for masks), etc.
            elif key in _MSK_OPTIONS:
                out.append(self._postproc_mask(cast(MaskDataType, out_arg)))
            elif key in _KEYPOINTS_OPTIONS:
                _out_k = self._postproc_keypoint(in_arg, cast(Keypoints, out_arg), dcate)
                # Under autocast, cast the result back to the dtype of the caller's input.
                if is_autocast_enabled() and isinstance(in_arg, (Tensor, Keypoints)):
                    if isinstance(_out_k, list):
                        _out_k = [i.type(in_arg.dtype) for i in _out_k]
                    else:
                        _out_k = _out_k.type(in_arg.dtype)
                out.append(_out_k)
            elif key in _BOXES_OPTIONS:
                _out_b = self._postproc_boxes(in_arg, cast(Boxes, out_arg), dcate)
                # Under autocast, cast the result back to the dtype of the caller's input.
                if is_autocast_enabled() and isinstance(in_arg, (Tensor, Boxes)):
                    if isinstance(_out_b, list):
                        _out_b = [i.type(in_arg.dtype) for i in _out_b]
                    else:
                        _out_b = _out_b.type(in_arg.dtype)
                out.append(_out_b)
            elif key in _CLS_OPTIONS:
                out.append(out_arg)
            else:
                raise NotImplementedError(f"input type of {dcate} is not implemented.")
        return out

    def forward(  # type: ignore[override]
        self,
        *args: Union[DataType, Dict[str, DataType]],
        params: Optional[List[ParamItem]] = None,
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, List[DataType], Dict[str, DataType]]:
        """Compute multiple tensors simultaneously according to ``self.data_keys``.

        Args:
            *args: one input per data key, or a single dictionary whose keys start with data key names
                (only with ``data_keys=None`` at construction).
            params: parameters to replay an earlier augmentation exactly. When omitted, new parameters are
                generated from the shape of the ``INPUT`` data.
            data_keys: overrides the data keys for this call only.

        Returns:
            A single output for a single input, a dictionary for a dictionary input, otherwise a list.

        Raises:
            ValueError: if ``params`` is omitted and there is no ``INPUT`` tensor to generate them from.
        """
        self.clear_state()

        # Unpack/handle dictionary args
        original_keys: Optional[Tuple[str, ...]] = None
        invalid_data: Optional[Dict[str, Any]] = None
        if len(args) == 1 and isinstance(args[0], dict):
            original_keys, data_keys, args, invalid_data = self._preproc_dict_data(args[0])

        self.transform_op.data_keys = self.transform_op.preproc_datakeys(data_keys)

        self._validate_args_datakeys(*args, data_keys=self.transform_op.data_keys)  # type: ignore

        in_args = self._arguments_preproc(*args, data_keys=self.transform_op.data_keys)  # type: ignore

        if params is None:
            # image data must exist if params is not provided.
            if DataKey.INPUT in self.transform_op.data_keys:
                inp = in_args[self.transform_op.data_keys.index(DataKey.INPUT)]
                if not isinstance(inp, (Tensor,)):
                    raise ValueError(f"`INPUT` should be a tensor but `{type(inp)}` received.")
                # A video input shall be BCDHW while an image input shall be BCHW
                if self.contains_video_sequential or self.contains_3d_augmentation:
                    _, out_shape = self.autofill_dim(inp, dim_range=(3, 5))
                else:
                    _, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
                params = self.forward_parameters(out_shape)
            else:
                raise ValueError("`params` must be provided whilst INPUT is not in data_keys.")

        outputs: Union[Tensor, List[DataType]] = in_args
        for param in params:
            module = self.get_submodule(param.name)
            outputs = self.transform_op.transform(  # type: ignore
                *outputs, module=module, param=param, extra_args=self.extra_args
            )
            if not isinstance(outputs, (list, tuple)):
                # Make sure we are unpacking a list whilst post-proc
                outputs = [outputs]
            self._update_transform_matrix_by_module(module)

        outputs = self._arguments_postproc(args, outputs, data_keys=self.transform_op.data_keys)  # type: ignore
        # Restore it back
        self.transform_op.data_keys = self.data_keys

        self._params = params

        if isinstance(original_keys, tuple):
            result = {k: v for v, k in zip(outputs, original_keys)}
            if invalid_data:
                result.update(invalid_data)
            return result

        if len(outputs) == 1 and isinstance(outputs, list):
            return outputs[0]

        return outputs

    def __call__(
        self,
        *inputs: Any,
        input_names_to_handle: Optional[List[Any]] = None,
        output_type: str = "tensor",
        **kwargs: Any,
    ) -> Any:
        """Overwrite the __call__ function to handle various inputs.

        Args:
            inputs: Inputs to operate on.
            input_names_to_handle: List of input names to convert, if None, handle all inputs.
            output_type: Desired output type ('tensor', 'numpy', or 'pil').
            kwargs: Additional arguments forwarded to ``forward`` (e.g. ``params``, ``data_keys``).

        Returns:
            The result of ``forward``, converted through ``convert_input_output`` unless features are disabled.
        """
        if self._disable_features:
            return super(ImageSequential, self).__call__(*inputs, **kwargs)

        # TODO: Some more behaviour for AugmentationSequential needs to be revisited later
        # e.g. We convert only images, etc.
        decorated_forward = self.convert_input_output(
            input_names_to_handle=input_names_to_handle, output_type=output_type
        )(super(ImageSequential, self).__call__)
        _output_image = decorated_forward(*inputs, **kwargs)

        # Locate the image slot of the output so that it can be cached on ``self._output_image``.
        in_data_keys: Optional[List[DataKey]]
        if len(inputs) == 1 and isinstance(inputs[0], dict):
            original_keys, in_data_keys, _, _ = self._preproc_dict_data(inputs[0])
        else:
            in_data_keys = kwargs.get("data_keys", self.data_keys)
        data_keys = self.transform_op.preproc_datakeys(in_data_keys)

        if len(data_keys) > 1 and DataKey.INPUT in data_keys:
            # NOTE: we may update it later for more supports of drawing boxes, etc.
            idx = data_keys.index(DataKey.INPUT)
            if output_type == "tensor":
                self._output_image = _output_image
            # NOTE(review): the branches below update a ``self._output_image`` cached by an earlier call.
            elif isinstance(_output_image, dict):
                self._output_image[original_keys[idx]] = _output_image[original_keys[idx]]
            else:
                self._output_image[idx] = _output_image[idx]
        else:
            self._output_image = _output_image

        return _output_image

    def _preproc_dict_data(
        self, data: Dict[str, DataType]
    ) -> Tuple[Tuple[str, ...], List[DataKey], Tuple[DataType, ...], Optional[Dict[str, Any]]]:
        """Split a dictionary input into keys, data keys, values and pass-through items.

        The caller's dictionary is not modified.

        Returns:
            The dictionary keys that match a data key, their data keys, their values (in the same order), and a
            dictionary of the remaining items (``None`` when every key matched).

        Raises:
            ValueError: if ``data_keys`` was set at construction.
        """
        if self.data_keys is not None:
            raise ValueError("If you are using a dictionary as input, the data_keys should be None.")

        keys = tuple(data.keys())
        data_keys, invalid_keys = self._read_datakeys_from_dict(keys)
        invalid_data = {k: data[k] for k in invalid_keys} if invalid_keys else None
        if invalid_keys:
            invalid_set = set(invalid_keys)
            keys = tuple(k for k in keys if k not in invalid_set)
        data_unpacked = tuple(data[k] for k in keys)

        return keys, data_keys, data_unpacked, invalid_data

    def _read_datakeys_from_dict(self, keys: Sequence[str]) -> Tuple[List[DataKey], Optional[List[str]]]:
        """Map dictionary keys to data keys by case-insensitive prefix.

        Every ``DataKey`` name is a candidate, enum aliases such as ``INPUT`` included, and the longest
        matching name wins: ``"bbox_xyxy-a"`` maps to ``BBOX_XYXY``, not ``BBOX``.

        Returns:
            The data keys of the matching dictionary keys (in order) and the dictionary keys that matched none.
        """
        # ``__members__`` includes enum aliases, which plain iteration over ``DataKey`` skips.
        # A stable sort keeps declaration order among names of equal length.
        candidates = sorted(DataKey.__members__.items(), key=lambda item: len(item[0]), reverse=True)

        valid_data_keys: List[DataKey] = []
        invalid_keys: List[str] = []
        for k in keys:
            upper = k.upper()
            match = next((member for name, member in candidates if upper.startswith(name)), None)
            if match is None:
                invalid_keys.append(k)
            else:
                valid_data_keys.append(match)
        return valid_data_keys, invalid_keys

    def _preproc_mask(self, arg: MaskDataType) -> MaskDataType:
        """Cast mask(s) to the dtype of the last image input, or to float when none was recorded."""
        dtype = self.input_dtype or torch.float
        if isinstance(arg, list):
            return [a.to(dtype) for a in arg]
        return arg.to(dtype)

    def _postproc_mask(self, arg: MaskDataType) -> MaskDataType:
        """Cast mask(s) back to the recorded mask dtype, or to float when none was recorded."""
        dtype = self.mask_dtype or torch.float
        if isinstance(arg, list):
            return [a.to(dtype) for a in arg]
        return arg.to(dtype)

    @staticmethod
    def _boxes_mode(dcate: DataKey) -> str:
        """Return the ``Boxes`` tensor mode for a box data key.

        Raises:
            ValueError: if ``dcate`` is not a box data key.
        """
        key = DataKey.get(dcate)
        if key == DataKey.BBOX:
            return "vertices_plus"
        if key == DataKey.BBOX_XYXY:
            return "xyxy_plus"
        if key == DataKey.BBOX_XYWH:
            return "xywh"
        raise ValueError(f"Unsupported mode `{key.name}`.")

    def _preproc_boxes(self, arg: DataType, dcate: DataKey) -> Boxes:
        """Wrap box data into ``Boxes`` (``VideoBoxes`` for video pipelines); ``Boxes`` pass through.

        Raises:
            ValueError: for a non-box data key.
            NotImplementedError: for tensor boxes in 3D pipelines.
        """
        mode = self._boxes_mode(dcate)

        if isinstance(arg, (Boxes,)):
            return arg
        if self.contains_video_sequential:
            # NOTE(review): ``mode`` is not passed here, while ``_postproc_boxes`` converts back with it;
            # confirm non-vertex box formats round-trip correctly for videos.
            return VideoBoxes.from_tensor(cast(Tensor, arg))
        if self.contains_3d_augmentation:
            raise NotImplementedError("3D box handlers are not yet supported.")
        return Boxes.from_tensor(cast(Tensor, arg), mode=mode)

    def _postproc_boxes(self, in_arg: DataType, out_arg: Boxes, dcate: DataKey) -> Union[Tensor, List[Tensor], Boxes]:
        """Return ``out_arg`` as ``Boxes`` if the input was ``Boxes``, else as a tensor in the data key's mode."""
        mode = self._boxes_mode(dcate)

        # TODO: handle 3d scenarios
        if isinstance(in_arg, (Boxes,)):
            return out_arg
        return out_arg.to_tensor(mode=mode)

    def _preproc_keypoints(self, arg: DataType, dcate: DataKey) -> Keypoints:
        """Wrap keypoint data into ``Keypoints`` (``VideoKeypoints`` for video pipelines).

        ``Keypoints`` inputs are returned as-is, mirroring ``_preproc_boxes``. Non floating-point tensors are
        converted to float for wrapping, and the wrapped result is cast back to the original dtype.

        Raises:
            NotImplementedError: for tensor keypoints in 3D pipelines without a video sequential.
        """
        if isinstance(arg, (Keypoints,)):
            return arg

        dtype = None
        if self.contains_video_sequential:
            arg = cast(Union[Tensor, List[Tensor]], arg)
            if isinstance(arg, list):
                if not torch.is_floating_point(arg[0]):
                    dtype = arg[0].dtype
                    arg = [a.float() for a in arg]
            elif not torch.is_floating_point(arg):
                dtype = arg.dtype
                arg = arg.float()
            video_result = VideoKeypoints.from_tensor(arg)
            return video_result.type(dtype) if dtype else video_result

        if self.contains_3d_augmentation:
            raise NotImplementedError("3D keypoint handlers are not yet supported.")

        arg = cast(Tensor, arg)
        if not torch.is_floating_point(arg):
            dtype = arg.dtype
            arg = arg.float()
        # TODO: Add List[Tensor] in the future.
        result = Keypoints.from_tensor(arg)
        return result.type(dtype) if dtype else result

    def _postproc_keypoint(
        self, in_arg: DataType, out_arg: Keypoints, dcate: DataKey
    ) -> Union[Tensor, List[Tensor], Keypoints]:
        """Return ``out_arg`` as ``Keypoints`` if the input was ``Keypoints``, else via ``to_tensor()``."""
        if isinstance(in_arg, (Keypoints,)):
            return out_arg
        return out_arg.to_tensor()
|