# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

import torch

from kornia.augmentation._2d.base import RigidAffineAugmentationBase2D
from kornia.augmentation._3d.base import AugmentationBase3D, RigidAffineAugmentationBase3D
from kornia.augmentation.base import _AugmentationBase
from kornia.constants import DataKey, Resample
from kornia.core import Module, Tensor
from kornia.geometry.boxes import Boxes, VideoBoxes
from kornia.geometry.keypoints import Keypoints, VideoKeypoints
from kornia.utils import eye_like, is_autocast_enabled

from .base import TransformMatrixMinIn
from .image import ImageSequential
from .ops import AugmentationSequentialOps, DataType
from .params import ParamItem
from .patch import PatchSequential
from .video import VideoSequential

__all__ = ["AugmentationSequential"]

# Groupings of data keys that share the same pre/post-processing path.
_BOXES_OPTIONS = {DataKey.BBOX, DataKey.BBOX_XYXY, DataKey.BBOX_XYWH}
_KEYPOINTS_OPTIONS = {DataKey.KEYPOINTS}
_IMG_OPTIONS = {DataKey.INPUT, DataKey.IMAGE}
_MSK_OPTIONS = {DataKey.MASK}
_CLS_OPTIONS = {DataKey.CLASS, DataKey.LABEL}

# A mask may be a single tensor or a list of tensors (e.g. one per sample).
MaskDataType = Union[Tensor, List[Tensor]]


class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
    r"""AugmentationSequential for handling multiple input types like inputs, masks, keypoints at once.

    .. image:: _static/img/AugmentationSequential.png

    Args:
        *args: a list of kornia augmentation modules.
        data_keys: the input type sequential for applying augmentations. Accepts "input", "image", "mask",
            "bbox", "bbox_xyxy", "bbox_xywh", "keypoints", "class", "label". Set it to None to call the
            module with a dictionary; the data keys are then inferred from the dictionary keys.
        same_on_batch: apply the same transformation across the batch. If None, it will not overwrite the
            function-wise settings.
        keepdim: whether to keep the output shape the same as input (True) or broadcast it to the batch form
            (False). If None, it will not overwrite the function-wise settings.
        random_apply: randomly select a sublist (order agnostic) of args to apply transformation.
            If int, a fixed number of transformations will be selected.
            If (a,), x number of transformations (a <= x <= len(args)) will be selected.
            If (a, b), x number of transformations (a <= x <= b) will be selected.
            If True, the whole list of args will be processed as a sequence in a random order.
            If False, the whole list of args will be processed as a sequence in original order.
        random_apply_weights: selection weights for each module in ``args``, forwarded to
            :class:`ImageSequential`; used together with ``random_apply`` (see the ``OneOf`` example below).
        transformation_matrix_mode: computation mode for the chained transformation matrix, via
            `.transform_matrix` attribute.
            If `silent`, transformation matrix will be computed silently and the non-rigid modules will be
            ignored as identity transformations.
            If `rigid`, transformation matrix will be computed silently and the non-rigid modules will trigger
            errors.
            If `skip`, transformation matrix will be totally ignored.
        extra_args: to control the behaviour for each datakeys. By default, masks are handled by nearest
            interpolation strategies.

    .. note::
        Mix augmentations (e.g. RandomMixUp, RandomCutMix) only work with the "input"/"image" data key.
        It is not clear how to deal with the conversions of masks, bounding boxes and keypoints.

    .. note::
        See the Kornia data augmentation tutorials for a working example.

    Examples:
        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)
        >>> mask = torch.ones(2, 3, 5, 6)
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, 1, -1, -1)
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)
        >>> aug_list = AugmentationSequential(
        ...     kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     kornia.augmentation.RandomAffine(360, p=1.0),
        ...     data_keys=["input", "mask", "bbox", "keypoints"],
        ...     same_on_batch=False,
        ...     random_apply=10,
        ... )
        >>> out = aug_list(input, mask, bbox, points)
        >>> [o.shape for o in out]
        [torch.Size([2, 3, 5, 6]), torch.Size([2, 3, 5, 6]), torch.Size([2, 1, 4, 2]), torch.Size([2, 1, 2])]
        >>> # apply the exact augmentation again.
        >>> out_rep = aug_list(input, mask, bbox, points, params=aug_list._params)
        >>> [(o == o_rep).all() for o, o_rep in zip(out, out_rep)]
        [tensor(True), tensor(True), tensor(True), tensor(True)]
        >>> # inverse the augmentations
        >>> out_inv = aug_list.inverse(*out)
        >>> [o.shape for o in out_inv]
        [torch.Size([2, 3, 5, 6]), torch.Size([2, 3, 5, 6]), torch.Size([2, 1, 4, 2]), torch.Size([2, 1, 2])]

    This example demonstrates the integration of VideoSequential and AugmentationSequential.

        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)[None]
        >>> mask = torch.ones(2, 3, 5, 6)[None]
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, 1, -1, -1)[None]
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)[None]
        >>> aug_list = AugmentationSequential(
        ...     VideoSequential(
        ...         kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...         kornia.augmentation.RandomAffine(360, p=1.0),
        ...     ),
        ...     data_keys=["input", "mask", "bbox", "keypoints"]
        ... )
        >>> out = aug_list(input, mask, bbox, points)
        >>> [o.shape for o in out]  # doctest: +ELLIPSIS
        [torch.Size([1, 2, 3, 5, 6]), torch.Size([1, 2, 3, 5, 6]), ...([1, 2, 1, 4, 2]), torch.Size([1, 2, 1, 2])]

    Perform ``OneOf`` transformation with ``random_apply=1`` and ``random_apply_weights``
    in ``AugmentationSequential``.

        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)[None]
        >>> mask = torch.ones(2, 3, 5, 6)[None]
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, 1, -1, -1)[None]
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)[None]
        >>> aug_list = AugmentationSequential(
        ...     VideoSequential(
        ...         kornia.augmentation.RandomAffine(360, p=1.0),
        ...     ),
        ...     VideoSequential(
        ...         kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     ),
        ...     data_keys=["input", "mask", "bbox", "keypoints"],
        ...     random_apply=1,
        ...     random_apply_weights=[0.5, 0.3]
        ... )
        >>> out = aug_list(input, mask, bbox, points)
        >>> [o.shape for o in out]  # doctest: +ELLIPSIS
        [torch.Size([1, 2, 3, 5, 6]), torch.Size([1, 2, 3, 5, 6]), ...([1, 2, 1, 4, 2]), torch.Size([1, 2, 1, 2])]

    This example shows how to use a list of masks and boxes within AugmentationSequential

        >>> import kornia.augmentation as K
        >>> input = torch.randn(2, 3, 256, 256)
        >>> mask = [torch.ones(1, 3, 256, 256), torch.ones(1, 2, 256, 256)]
        >>> bbox = [
        ...     torch.tensor([[28.0, 53.0, 143.0, 164.0], [254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]]),
        ...     torch.tensor([[254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]])
        ... ]
        >>> bbox = [Boxes.from_tensor(i).data for i in bbox]
        >>> aug_list = K.AugmentationSequential(
        ...     K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     K.RandomHorizontalFlip(p=1.0),
        ...     K.ImageSequential(K.RandomHorizontalFlip(p=1.0)),
        ...     K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0)),
        ...     data_keys=["input", "mask", "bbox"],
        ...     same_on_batch=False,
        ...     random_apply=10,
        ... )
        >>> out = aug_list(input, mask, bbox)

    How to use a dictionary as input with AugmentationSequential? The dictionary keys that start with
    one of the available datakeys will be augmented accordingly. Otherwise, the dictionary item is passed
    without any augmentation.

        >>> import kornia.augmentation as K
        >>> img = torch.randn(1, 3, 256, 256)
        >>> mask = [torch.ones(1, 3, 256, 256), torch.ones(1, 2, 256, 256)]
        >>> bbox = [
        ...     torch.tensor([[28.0, 53.0, 143.0, 164.0], [254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]]),
        ...     torch.tensor([[254.0, 158.0, 364.0, 290.0], [307.0, 204.0, 413.0, 350.0]])
        ... ]
        >>> bbox = [Boxes.from_tensor(i).data for i in bbox]
        >>> aug_dict = K.AugmentationSequential(
        ...     K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     K.RandomHorizontalFlip(p=1.0),
        ...     K.ImageSequential(K.RandomHorizontalFlip(p=1.0)),
        ...     K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0)),
        ...     data_keys=None,
        ...     same_on_batch=False,
        ...     random_apply=10,
        ... )
        >>> data = {'image': img, 'mask': mask[0], 'mask-b': mask[1], 'bbox': bbox[0], 'bbox-other':bbox[1]}
        >>> out = aug_dict(data)
        >>> out.keys()
        dict_keys(['image', 'mask', 'mask-b', 'bbox', 'bbox-other'])

    """

    # dtypes of the most recent image / mask inputs, recorded in `_arguments_preproc` so masks can be
    # cast to the image dtype for processing and back to their own dtype afterwards.
    input_dtype = None
    mask_dtype = None

    def __init__(
        self,
        *args: Union[_AugmentationBase, ImageSequential],
        data_keys: Optional[Union[Sequence[str], Sequence[int], Sequence[DataKey]]] = (DataKey.INPUT,),
        same_on_batch: Optional[bool] = None,
        keepdim: Optional[bool] = None,
        random_apply: Union[int, bool, Tuple[int, int]] = False,
        random_apply_weights: Optional[List[float]] = None,
        transformation_matrix_mode: str = "silent",
        extra_args: Optional[Dict[DataKey, Dict[str, Any]]] = None,
    ) -> None:
        self._transform_matrix: Optional[Tensor]
        self._transform_matrices: List[Optional[Tensor]] = []

        super().__init__(
            *args,
            same_on_batch=same_on_batch,
            keepdim=keepdim,
            random_apply=random_apply,
            random_apply_weights=random_apply_weights,
        )

        self._parse_transformation_matrix_mode(transformation_matrix_mode)

        # Module types whose transform matrices contribute to the chained `transform_matrix`.
        self._valid_ops_for_transform_computation: Tuple[Any, ...] = (
            RigidAffineAugmentationBase2D,
            RigidAffineAugmentationBase3D,
            AugmentationSequential,
        )

        self.data_keys: Optional[List[DataKey]]
        if data_keys is not None:
            self.data_keys = [DataKey.get(inp) for inp in data_keys]
        else:
            # None means "infer the keys from a dictionary input at call time".
            self.data_keys = data_keys

        if self.data_keys:
            if any(in_type not in DataKey for in_type in self.data_keys):
                raise AssertionError(f"`data_keys` must be in {DataKey}. Got {self.data_keys}.")

            if self.data_keys[0] != DataKey.INPUT:
                raise NotImplementedError(f"The first input must be {DataKey.INPUT}.")

        self.transform_op = AugmentationSequentialOps(self.data_keys)

        self.contains_video_sequential: bool = False
        self.contains_3d_augmentation: bool = False
        for arg in args:
            if isinstance(arg, PatchSequential) and not arg.is_intensity_only():
                warnings.warn(
                    "Geometric transformation detected in PatchSequential, which would break bbox, mask.", stacklevel=1
                )
            if isinstance(arg, VideoSequential):
                self.contains_video_sequential = True
            # NOTE: only for images are supported for 3D.
            if isinstance(arg, AugmentationBase3D):
                self.contains_3d_augmentation = True
        self._transform_matrix = None
        self.extra_args = extra_args or {DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}}

    def clear_state(self) -> None:
        """Reset the chained transform-matrix state, then the parent container state."""
        self._reset_transform_matrix_state()
        return super().clear_state()

    def _update_transform_matrix_for_valid_op(self, module: Module) -> None:
        """Record the transform matrix of a module that participates in the chained matrix."""
        self._transform_matrices.append(module.transform_matrix)

    def identity_matrix(self, input: Tensor) -> Tensor:
        """Return identity matrix (4x4 when 3D augmentations are present, 3x3 otherwise)."""
        if self.contains_3d_augmentation:
            return eye_like(4, input)
        else:
            return eye_like(3, input)

    def inverse(  # type: ignore[override]
        self,
        *args: Union[DataType, Dict[str, DataType]],
        params: Optional[List[ParamItem]] = None,
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, List[DataType], Dict[str, DataType]]:
        """Reverse the transformation applied.

        Number of input tensors must align with the number of ``data_keys``. If ``data_keys`` is not set,
        use ``self.data_keys`` by default.

        Args:
            args: the data to invert, one per data key, or a single dictionary (requires ``self.data_keys``
                to be None).
            params: the parameters to invert. Defaults to the parameters of the last forward pass.
            data_keys: optional per-call override of the data keys.

        Returns:
            A single item when only one output is produced, a dictionary when a dictionary was given,
            otherwise a list of outputs.

        Raises:
            ValueError: if ``params`` is None and no forward pass has stored parameters yet.

        """
        original_keys = None
        invalid_data = None
        if len(args) == 1 and isinstance(args[0], dict):
            original_keys, data_keys, args, invalid_data = self._preproc_dict_data(args[0])

        # args here should already be `DataType`
        # NOTE: how to right type to: unpacked args <-> tuple of args to unpack
        # issue with `self._preproc_dict_data` return args type

        self.transform_op.data_keys = self.transform_op.preproc_datakeys(data_keys)

        self._validate_args_datakeys(*args, data_keys=self.transform_op.data_keys)  # type: ignore

        in_args = self._arguments_preproc(*args, data_keys=self.transform_op.data_keys)  # type: ignore

        if params is None:
            if self._params is None:
                raise ValueError(
                    "No parameters available for inversing, please run a forward pass first "
                    "or passing valid params into this function."
                )
            params = self._params

        outputs: List[DataType] = in_args
        # Undo the modules in the reverse order they were applied.
        for param in params[::-1]:
            module = self.get_submodule(param.name)
            outputs = self.transform_op.inverse(  # type: ignore
                *outputs, module=module, param=param, extra_args=self.extra_args
            )
            if not isinstance(outputs, (list, tuple)):
                # Make sure we are unpacking a list whilst post-proc
                outputs = [outputs]

        outputs = self._arguments_postproc(args, outputs, data_keys=self.transform_op.data_keys)  # type: ignore
        # Restore the constructor-level data keys, mirroring `forward`.
        self.transform_op.data_keys = self.data_keys

        if isinstance(original_keys, tuple):
            result = {k: v for v, k in zip(outputs, original_keys)}
            if invalid_data:
                result.update(invalid_data)
            return result

        if len(outputs) == 1 and isinstance(outputs, list):
            return outputs[0]

        return outputs

    def _validate_args_datakeys(self, *args: DataType, data_keys: List[DataKey]) -> None:
        """Ensure there is exactly one input per data key."""
        if len(args) != len(data_keys):
            raise AssertionError(
                f"The number of inputs must align with the number of data_keys. Got {len(args)} and {len(data_keys)}."
            )
        # TODO: validate args batching, and its consistency

    def _arguments_preproc(self, *args: DataType, data_keys: List[DataKey]) -> List[DataType]:
        """Convert raw inputs into the internal representation expected by the transform ops.

        Images pass through (their dtype is recorded), masks are cast to the image dtype, keypoints and
        boxes are wrapped into :class:`Keypoints` / :class:`Boxes`, class labels pass through.
        """
        inp: List[DataType] = []
        for arg, dcate in zip(args, data_keys):
            if DataKey.get(dcate) in _IMG_OPTIONS:
                arg = cast(Tensor, arg)
                self.input_dtype = arg.dtype
                inp.append(arg)
            elif DataKey.get(dcate) in _MSK_OPTIONS:
                # Check the mask argument itself (not the accumulator) to decide between list and tensor.
                if isinstance(arg, list):
                    arg = cast(List[Tensor], arg)
                    if arg:
                        self.mask_dtype = arg[0].dtype
                else:
                    arg = cast(Tensor, arg)
                    self.mask_dtype = arg.dtype
                inp.append(self._preproc_mask(arg))
            elif DataKey.get(dcate) in _KEYPOINTS_OPTIONS:
                inp.append(self._preproc_keypoints(arg, dcate))
            elif DataKey.get(dcate) in _BOXES_OPTIONS:
                inp.append(self._preproc_boxes(arg, dcate))
            elif DataKey.get(dcate) in _CLS_OPTIONS:
                inp.append(arg)
            else:
                raise NotImplementedError(f"input type of {dcate} is not implemented.")
        return inp

    def _arguments_postproc(
        self, in_args: List[DataType], out_args: List[DataType], data_keys: List[DataKey]
    ) -> List[DataType]:
        """Convert internal outputs back to the representation of the corresponding original inputs.

        Under autocast, keypoint and box outputs are cast back to the dtype of their original input.
        """
        out: List[DataType] = []
        for in_arg, out_arg, dcate in zip(in_args, out_args, data_keys):
            if DataKey.get(dcate) in _IMG_OPTIONS:
                # It is tensor type already.
                out.append(out_arg)
                # TODO: may add the float to integer (for masks), etc.

            elif DataKey.get(dcate) in _MSK_OPTIONS:
                _out_m = self._postproc_mask(cast(MaskDataType, out_arg))
                out.append(_out_m)

            elif DataKey.get(dcate) in _KEYPOINTS_OPTIONS:
                _out_k = self._postproc_keypoint(in_arg, cast(Keypoints, out_arg), dcate)
                if is_autocast_enabled() and isinstance(in_arg, (Tensor, Keypoints)):
                    if isinstance(_out_k, list):
                        _out_k = [i.type(in_arg.dtype) for i in _out_k]
                    else:
                        _out_k = _out_k.type(in_arg.dtype)
                out.append(_out_k)

            elif DataKey.get(dcate) in _BOXES_OPTIONS:
                _out_b = self._postproc_boxes(in_arg, cast(Boxes, out_arg), dcate)
                if is_autocast_enabled() and isinstance(in_arg, (Tensor, Boxes)):
                    if isinstance(_out_b, list):
                        _out_b = [i.type(in_arg.dtype) for i in _out_b]
                    else:
                        _out_b = _out_b.type(in_arg.dtype)
                out.append(_out_b)

            elif DataKey.get(dcate) in _CLS_OPTIONS:
                out.append(out_arg)

            else:
                raise NotImplementedError(f"input type of {dcate} is not implemented.")

        return out

    def forward(  # type: ignore[override]
        self,
        *args: Union[DataType, Dict[str, DataType]],
        params: Optional[List[ParamItem]] = None,
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, List[DataType], Dict[str, DataType]]:
        """Compute multiple tensors simultaneously according to ``self.data_keys``.

        Args:
            args: the data to augment, one per data key, or a single dictionary (requires ``self.data_keys``
                to be None).
            params: parameters to replay. When None they are sampled from the ``INPUT`` shape.
            data_keys: optional per-call override of the data keys.

        Returns:
            A single item when only one output is produced, a dictionary when a dictionary was given,
            otherwise a list of outputs.

        Raises:
            ValueError: if ``params`` is None and ``INPUT`` is missing or is not a tensor.

        """
        self.clear_state()

        # Unpack/handle dictionary args
        original_keys = None
        invalid_data = None
        if len(args) == 1 and isinstance(args[0], dict):
            original_keys, data_keys, args, invalid_data = self._preproc_dict_data(args[0])

        self.transform_op.data_keys = self.transform_op.preproc_datakeys(data_keys)

        self._validate_args_datakeys(*args, data_keys=self.transform_op.data_keys)  # type: ignore

        in_args = self._arguments_preproc(*args, data_keys=self.transform_op.data_keys)  # type: ignore

        if params is None:
            # image data must exist if params is not provided.
            if DataKey.INPUT in self.transform_op.data_keys:
                inp = in_args[self.transform_op.data_keys.index(DataKey.INPUT)]
                if not isinstance(inp, (Tensor,)):
                    raise ValueError(f"`INPUT` should be a tensor but `{type(inp)}` received.")
                # A video input shall be BCDHW while an image input shall be BCHW
                if self.contains_video_sequential or self.contains_3d_augmentation:
                    _, out_shape = self.autofill_dim(inp, dim_range=(3, 5))
                else:
                    _, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
                params = self.forward_parameters(out_shape)
            else:
                raise ValueError("`params` must be provided whilst INPUT is not in data_keys.")

        outputs: Union[Tensor, List[DataType]] = in_args
        for param in params:
            module = self.get_submodule(param.name)
            outputs = self.transform_op.transform(  # type: ignore
                *outputs, module=module, param=param, extra_args=self.extra_args
            )
            if not isinstance(outputs, (list, tuple)):
                # Make sure we are unpacking a list whilst post-proc
                outputs = [outputs]
            self._update_transform_matrix_by_module(module)

        outputs = self._arguments_postproc(args, outputs, data_keys=self.transform_op.data_keys)  # type: ignore
        # Restore it back
        self.transform_op.data_keys = self.data_keys

        self._params = params

        if isinstance(original_keys, tuple):
            result = {k: v for v, k in zip(outputs, original_keys)}
            if invalid_data:
                result.update(invalid_data)
            return result

        if len(outputs) == 1 and isinstance(outputs, list):
            return outputs[0]

        return outputs

    def __call__(
        self,
        *inputs: Any,
        input_names_to_handle: Optional[List[Any]] = None,
        output_type: str = "tensor",
        **kwargs: Any,
    ) -> Any:
        """Overwrite the __call__ function to handle various inputs.

        Args:
            inputs: Inputs to operate on.
            input_names_to_handle: List of input names to convert, if None, handle all inputs.
            output_type: Desired output type ('tensor', 'numpy', or 'pil').
            kwargs: Additional arguments.

        Returns:
            The result of the forward pass; when features are enabled it is produced through
            ``convert_input_output`` with the requested ``output_type``.

        """
        # Wrap the forward method with the decorator
        if not self._disable_features:
            # TODO: Some more behaviour for AugmentationSequential needs to be revisited later
            # e.g. We convert only images, etc.
            decorated_forward = self.convert_input_output(
                input_names_to_handle=input_names_to_handle, output_type=output_type
            )(super(ImageSequential, self).__call__)
            _output_image = decorated_forward(*inputs, **kwargs)

            in_data_keys: Optional[List[DataKey]]
            if len(inputs) == 1 and isinstance(inputs[0], dict):
                original_keys, in_data_keys, inputs, _invalid_data = self._preproc_dict_data(inputs[0])
            else:
                in_data_keys = kwargs.get("data_keys", self.data_keys)
            data_keys = self.transform_op.preproc_datakeys(in_data_keys)

            if len(data_keys) > 1 and DataKey.INPUT in data_keys:
                # NOTE: we may update it later for more supports of drawing boxes, etc.
                idx = data_keys.index(DataKey.INPUT)
                if output_type == "tensor":
                    # NOTE(review): `self._output_image` is bound to `_output_image` just above, so the
                    # indexed re-assignments in this branch write an element onto itself.
                    self._output_image = _output_image
                    if isinstance(_output_image, dict):
                        self._output_image[original_keys[idx]] = _output_image[original_keys[idx]]
                    else:
                        self._output_image[idx] = _output_image[idx]
                elif isinstance(_output_image, dict):
                    self._output_image[original_keys[idx]] = _output_image[original_keys[idx]]
                else:
                    self._output_image[idx] = _output_image[idx]
            else:
                self._output_image = _output_image

        else:
            _output_image = super(ImageSequential, self).__call__(*inputs, **kwargs)
        return _output_image

    def _preproc_dict_data(
        self, data: Dict[str, DataType]
    ) -> Tuple[Tuple[str, ...], List[DataKey], Tuple[DataType, ...], Optional[Dict[str, Any]]]:
        """Split a dictionary input into augmentable data and pass-through data.

        The caller's dictionary is left untouched.

        Returns:
            A tuple ``(keys, data_keys, values, invalid_data)``. ``keys`` are the augmentable dictionary
            keys in insertion order, ``data_keys`` their inferred :class:`DataKey` values, ``values`` the
            matching data, and ``invalid_data`` the un-augmentable items (or None if there are none).

        Raises:
            ValueError: if ``self.data_keys`` is set, since dictionary inputs infer their own data keys.

        """
        if self.data_keys is not None:
            raise ValueError("If you are using a dictionary as input, the data_keys should be None.")

        keys = tuple(data.keys())
        data_keys, invalid_keys = self._read_datakeys_from_dict(keys)
        invalid_data = None
        if invalid_keys:
            # Read (do not pop) so the caller's dictionary is not mutated.
            invalid_data = {i: data[i] for i in invalid_keys}
            invalid_set = set(invalid_keys)
            keys = tuple(k for k in keys if k not in invalid_set)
        data_unpacked = tuple(data[k] for k in keys)

        return keys, data_keys, data_unpacked, invalid_data

    def _read_datakeys_from_dict(self, keys: Sequence[str]) -> Tuple[List[DataKey], Optional[List[str]]]:
        """Infer a :class:`DataKey` for each dictionary key from its (case-insensitive) prefix.

        Returns:
            The inferred data keys for the matching dictionary keys, and the list of dictionary keys that
            match no data key.

        """
        # Try longer member names first so that e.g. "bbox_xyxy-2" resolves to BBOX_XYXY
        # instead of its shorter prefix BBOX.
        candidates = sorted(DataKey, key=lambda d: len(d.name), reverse=True)

        def retrieve_key(key: str) -> DataKey:
            """Try to retrieve the datakey value by matching `*`."""
            upper_key = key.upper()
            # Alias cases, like INPUT, will not be get by the enum iterator.
            if upper_key.startswith("INPUT"):
                return DataKey.INPUT
            for dk in candidates:
                if upper_key.startswith(dk.name):
                    return DataKey.get(dk.name)
            allowed_dk = " | ".join(f"`{d.name}`" for d in DataKey)
            raise ValueError(
                f"Your input data dictionary keys should start with some of datakey values: {allowed_dk}. Got `{key}`"
            )

        valid_data_keys = []
        invalid_keys = []
        for k in keys:
            try:
                valid_data_keys.append(DataKey.get(retrieve_key(k)))
            except ValueError:
                invalid_keys.append(k)

        return valid_data_keys, invalid_keys

    def _preproc_mask(self, arg: MaskDataType) -> MaskDataType:
        """Cast mask(s) to the recorded image dtype (float when no image dtype is known)."""
        if isinstance(arg, list):
            new_arg = []
            for a in arg:
                a_new = a.to(self.input_dtype) if self.input_dtype else a.to(torch.float)
                new_arg.append(a_new)
            return new_arg
        else:
            arg = arg.to(self.input_dtype) if self.input_dtype else arg.to(torch.float)
        return arg

    def _postproc_mask(self, arg: MaskDataType) -> MaskDataType:
        """Cast mask(s) back to the recorded mask dtype (float when no mask dtype is known)."""
        if isinstance(arg, list):
            new_arg = []
            for a in arg:
                a_new = a.to(self.mask_dtype) if self.mask_dtype else a.to(torch.float)
                new_arg.append(a_new)
            return new_arg
        else:
            arg = arg.to(self.mask_dtype) if self.mask_dtype else arg.to(torch.float)
        return arg

    def _preproc_boxes(self, arg: DataType, dcate: DataKey) -> Boxes:
        """Wrap a box tensor into :class:`Boxes` / :class:`VideoBoxes`; existing `Boxes` pass through."""
        if DataKey.get(dcate) in [DataKey.BBOX]:
            mode = "vertices_plus"
        elif DataKey.get(dcate) in [DataKey.BBOX_XYXY]:
            mode = "xyxy_plus"
        elif DataKey.get(dcate) in [DataKey.BBOX_XYWH]:
            mode = "xywh"
        else:
            raise ValueError(f"Unsupported mode `{DataKey.get(dcate).name}`.")
        if isinstance(arg, (Boxes,)):
            return arg
        elif self.contains_video_sequential:
            # NOTE(review): `mode` is not forwarded here; VideoBoxes.from_tensor is called without it.
            arg = cast(Tensor, arg)
            return VideoBoxes.from_tensor(arg)
        elif self.contains_3d_augmentation:
            raise NotImplementedError("3D box handlers are not yet supported.")
        else:
            arg = cast(Tensor, arg)
            return Boxes.from_tensor(arg, mode=mode)

    def _postproc_boxes(self, in_arg: DataType, out_arg: Boxes, dcate: DataKey) -> Union[Tensor, List[Tensor], Boxes]:
        """Return `Boxes` unchanged if the input was `Boxes`, else convert back to a tensor in the key's format."""
        if DataKey.get(dcate) in [DataKey.BBOX]:
            mode = "vertices_plus"
        elif DataKey.get(dcate) in [DataKey.BBOX_XYXY]:
            mode = "xyxy_plus"
        elif DataKey.get(dcate) in [DataKey.BBOX_XYWH]:
            mode = "xywh"
        else:
            raise ValueError(f"Unsupported mode `{DataKey.get(dcate).name}`.")

        # TODO: handle 3d scenarios
        if isinstance(in_arg, (Boxes,)):
            return out_arg
        else:
            return out_arg.to_tensor(mode=mode)

    def _preproc_keypoints(self, arg: DataType, dcate: DataKey) -> Keypoints:
        """Wrap keypoint tensors into :class:`Keypoints` / :class:`VideoKeypoints`.

        Already-wrapped `Keypoints` pass through unchanged (as `_preproc_boxes` does for `Boxes`).
        Non-floating inputs are temporarily converted to float to build the container, and the result
        is cast back to the original dtype.
        """
        if isinstance(arg, (Keypoints,)):
            return arg
        dtype = None
        if self.contains_video_sequential:
            arg = cast(Union[Tensor, List[Tensor]], arg)
            if isinstance(arg, list):
                if not torch.is_floating_point(arg[0]):
                    dtype = arg[0].dtype
                    arg = [a.float() for a in arg]
            elif not torch.is_floating_point(arg):
                dtype = arg.dtype
                arg = arg.float()
            video_result = VideoKeypoints.from_tensor(arg)
            return video_result.type(dtype) if dtype else video_result
        elif self.contains_3d_augmentation:
            raise NotImplementedError("3D keypoint handlers are not yet supported.")
        else:
            arg = cast(Tensor, arg)
            if not torch.is_floating_point(arg):
                dtype = arg.dtype
                arg = arg.float()
            # TODO: Add List[Tensor] in the future.
            result = Keypoints.from_tensor(arg)
            return result.type(dtype) if dtype else result

    def _postproc_keypoint(
        self, in_arg: DataType, out_arg: Keypoints, dcate: DataKey
    ) -> Union[Tensor, List[Tensor], Keypoints]:
        """Return `Keypoints` unchanged if the input was `Keypoints`, otherwise convert back to a tensor."""
        if isinstance(in_arg, (Keypoints,)):
            return out_arg
        else:
            return out_arg.to_tensor()