| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import List, Optional, Tuple, Union, cast
- import torch
- from torch import Size
- from kornia.core import Tensor
- from kornia.geometry import transform_points
# Names exported by ``from <this module> import *``.
# NOTE(review): ``VideoKeypoints`` is defined in this file but is not listed here --
# confirm whether that omission is intentional.
__all__ = ["Keypoints", "Keypoints3D"]
def _merge_keypoint_list(keypoints: List[Tensor]) -> Tuple[Tensor, List[int]]:
    """Merge a list of keypoint tensors into a single tensor (not implemented yet).

    The constructors in this module unpack the result into two values: the merged
    tensor and a list that is stored on the instance as ``_N``. The return
    annotation therefore describes a 2-tuple rather than a bare tensor.

    Args:
        keypoints: list of keypoint tensors to merge.

    Returns:
        A ``(merged_tensor, per_item_list)`` pair once implemented.

    Raises:
        NotImplementedError: always, since merging is not supported yet.
    """
    raise NotImplementedError("Merging a list of keypoint tensors is not implemented yet.")
class Keypoints:
    """2D Keypoints containing Nx2 or BxNx2 points.

    Args:
        keypoints: Raw tensor or a list of Tensors with the Nx2 coordinates
        raise_if_not_floating_point: will raise if the Tensor isn't float
    """

    def __init__(self, keypoints: Union[Tensor, List[Tensor]], raise_if_not_floating_point: bool = True) -> None:
        """Validate ``keypoints`` and store them.

        Raises:
            TypeError: if the (possibly merged) input is not a Tensor.
            ValueError: if the tensor is not floating point while
                ``raise_if_not_floating_point`` is set, or if its shape is neither
                ``(N, 2)`` nor ``(B, N, 2)``.
        """
        self._N: Optional[List[int]] = None

        if isinstance(keypoints, list):
            keypoints, self._N = _merge_keypoint_list(keypoints)

        if not isinstance(keypoints, Tensor):
            raise TypeError(f"Input keypoints is not a Tensor. Got: {type(keypoints)}.")

        if not keypoints.is_floating_point():
            if raise_if_not_floating_point:
                raise ValueError(f"Coordinates must be in floating point. Got {keypoints.dtype}")
            keypoints = keypoints.float()

        # An empty 1-D tensor (e.g. ``torch.tensor([])``) carries no points; normalise it to
        # ``(0, 2)``. Use reshape, so we don't end up creating a new tensor that does not
        # depend on the inputs (and consequently confuses jit).
        # The guard is restricted to 1-D input so that an empty batched tensor such as
        # ``(B, 0, 2)`` keeps its batch dimension, and so that a 0-d (scalar) tensor falls
        # through to the shape error below instead of failing inside ``reshape``.
        if keypoints.ndim == 1 and keypoints.numel() == 0:
            keypoints = keypoints.reshape((-1, 2))

        if not (2 <= keypoints.ndim <= 3 and keypoints.shape[-1:] == (2,)):
            raise ValueError(f"Keypoints shape must be (N, 2) or (B, N, 2). Got {keypoints.shape}.")

        self._is_batched = keypoints.ndim == 3
        self._data = keypoints

    def __getitem__(self, key: Union[slice, int, Tensor]) -> "Keypoints":
        """Index the underlying tensor and wrap the result in a new instance of the same class.

        The indexed tensor goes through the constructor again, so it must still have a
        valid ``(N, 2)`` or ``(B, N, 2)`` shape.
        """
        new_obj = type(self)(self._data[key], False)
        return new_obj

    def __setitem__(self, key: Union[slice, int, Tensor], value: "Keypoints") -> "Keypoints":
        """Write the data of ``value`` into the positions selected by ``key`` (in place)."""
        self._data[key] = value._data
        return self

    @property
    def shape(self) -> Union[Tuple[int, ...], Size]:
        """Returns the shape of the underlying tensor."""
        return self.data.shape

    @property
    def data(self) -> Tensor:
        """Returns the underlying tensor (not a copy)."""
        return self._data

    @property
    def device(self) -> torch.device:
        """Returns keypoints device."""
        return self._data.device

    @property
    def dtype(self) -> torch.dtype:
        """Returns keypoints dtype."""
        return self._data.dtype

    def index_put(
        self,
        indices: Union[Tuple[Tensor, ...], List[Tensor]],
        values: Union[Tensor, "Keypoints"],
        inplace: bool = False,
    ) -> "Keypoints":
        """Put ``values`` into the keypoints at ``indices`` via ``Tensor.index_put_``.

        Args:
            indices: index tensors forwarded to ``Tensor.index_put_``.
            values: raw tensor or another :class:`Keypoints` whose data is written.
            inplace: modify and return ``self`` when True; otherwise work on a clone
                and leave ``self`` untouched.

        Returns:
            ``self`` when ``inplace`` is True, else the modified clone.
        """
        # ``clone()`` already copies the data, so a single copy is enough for the
        # out-of-place path.
        target = self if inplace else self.clone()
        source = values.data if isinstance(values, Keypoints) else values
        target._data.index_put_(indices, source)
        return target

    def pad(self, padding_size: Tensor) -> "Keypoints":
        """Shift the keypoints in place to account for padding added around the image.

        Args:
            padding_size: (B, 4). Column 0 is added to x and column 2 is added to y.
                NOTE(review): presumably laid out as (left, right, top, bottom) -- confirm.

        Returns:
            ``self``, with its data modified in place.

        Raises:
            RuntimeError: if ``padding_size`` is not a 2-D tensor with 4 columns.
        """
        if not (len(padding_size.shape) == 2 and padding_size.size(1) == 4):
            raise RuntimeError(f"Expected padding_size as (B, 4). Got {padding_size.shape}.")
        self._data[..., 0] += padding_size[..., :1]  # left padding
        self._data[..., 1] += padding_size[..., 2:3]  # top padding
        return self

    def unpad(self, padding_size: Tensor) -> "Keypoints":
        """Undo :meth:`pad`: shift the keypoints back in place by the given padding.

        Args:
            padding_size: (B, 4). Column 0 is subtracted from x and column 2 from y.
                NOTE(review): presumably laid out as (left, right, top, bottom) -- confirm.

        Returns:
            ``self``, with its data modified in place.

        Raises:
            RuntimeError: if ``padding_size`` is not a 2-D tensor with 4 columns.
        """
        if not (len(padding_size.shape) == 2 and padding_size.size(1) == 4):
            raise RuntimeError(f"Expected padding_size as (B, 4). Got {padding_size.shape}.")
        self._data[..., 0] -= padding_size[..., :1]  # left padding
        self._data[..., 1] -= padding_size[..., 2:3]  # top padding
        return self

    def transform_keypoints(self, M: Tensor, inplace: bool = False) -> "Keypoints":
        r"""Apply a transformation matrix to the 2D keypoints.

        Args:
            M: The transformation matrix to be applied, shape of :math:`(3, 3)` or :math:`(B, 3, 3)`.
            inplace: do transform in-place and return self.

        Returns:
            The transformed keypoints.

        Raises:
            ValueError: if ``M`` does not have shape ``(3, 3)`` or ``(B, 3, 3)``.
        """
        if not 2 <= M.ndim <= 3 or M.shape[-2:] != (3, 3):
            raise ValueError(f"The transformation matrix shape must be (3, 3) or (B, 3, 3). Got {M.shape}.")

        transformed = transform_points(M, self._data)
        if inplace:
            # "In place" here means the instance is re-pointed at the transformed tensor;
            # the previous tensor object itself is not written to.
            self._data = transformed
            return self

        return Keypoints(transformed, False)

    def transform_keypoints_(self, M: Tensor) -> "Keypoints":
        """Inplace version of :func:`Keypoints.transform_keypoints`."""
        return self.transform_keypoints(M, inplace=True)

    @classmethod
    def from_tensor(cls, keypoints: Tensor) -> "Keypoints":
        """Build an instance from a raw keypoints tensor (same as calling the constructor)."""
        return cls(keypoints)

    def to_tensor(self, as_padded_sequence: bool = False) -> Union[Tensor, List[Tensor]]:
        r"""Cast :class:`Keypoints` to a tensor.

        Args:
            as_padded_sequence: whether to keep the pads for a list of keypoints. This parameter is only valid
                if the keypoints are from a keypoint list. Not implemented yet.

        Returns:
            The stored keypoints tensor, :math:`(N, 2)` or :math:`(B, N, 2)`. This is the
            underlying tensor itself, not a copy.

        Raises:
            NotImplementedError: if ``as_padded_sequence`` is True.
        """
        if as_padded_sequence:
            raise NotImplementedError
        return self._data

    def clone(self) -> "Keypoints":
        """Return a new :class:`Keypoints` holding a copy of the data."""
        return Keypoints(self._data.clone(), False)

    def type(self, dtype: torch.dtype) -> "Keypoints":
        """Convert the stored data to ``dtype`` and return ``self``."""
        self._data = self._data.type(dtype)
        return self
class VideoKeypoints(Keypoints):
    """2D keypoints of a video clip, stored with batch and time flattened together.

    Instances are meant to be built with :meth:`from_tensor` from a ``BxTxNx2`` tensor.
    The data is kept as ``(B * T, N, 2)`` and ``temporal_channel_size`` remembers ``T``
    so that :meth:`to_tensor` can restore the original layout.
    """

    # Size of the temporal dimension (dim 1 of the ``BxTxNx2`` input). It is assigned
    # after construction by ``from_tensor``, ``transform_keypoints`` and ``clone``; the
    # constructor itself does not set it.
    temporal_channel_size: int

    @classmethod
    def from_tensor(cls, boxes: Union[Tensor, List[Tensor]], validate_boxes: bool = True) -> "VideoKeypoints":
        """Create video keypoints from a ``BxTxNx2`` tensor.

        Args:
            boxes: keypoints tensor of shape ``BxTxNx2``. Lists are not supported yet.
            validate_boxes: currently unused.

        Returns:
            A :class:`VideoKeypoints` holding the data flattened to ``(B * T, N, 2)``.

        Raises:
            ValueError: if ``boxes`` is a list or is not a 4-D tensor whose last dimension is 2.
        """
        if isinstance(boxes, list) or boxes.dim() != 4 or boxes.shape[-1] != 2:
            raise ValueError("Input keypoints type is not yet supported. Please input a `BxTxNx2` tensor directly.")

        batch_size, temporal_channel_size = boxes.size(0), boxes.size(1)
        # ``reshape`` instead of ``view`` so that non-contiguous inputs (e.g. permuted
        # tensors) are accepted too; it still returns a view whenever one is possible.
        flattened = boxes.reshape(batch_size * temporal_channel_size, -1, boxes.size(3))
        # Due to some torch.jit.script bug (at least <= 1.9), you need to pass all arguments to __init__ when
        # constructing the class from inside of a method.
        out = cls(flattened, True)
        out.temporal_channel_size = temporal_channel_size
        return out

    def to_tensor(self) -> Tensor:  # type: ignore[override]
        """Return the keypoints with the time dimension restored, as ``(B, T, N, 2)``."""
        out = super().to_tensor(as_padded_sequence=False)
        out = cast(Tensor, out)
        return out.view(-1, self.temporal_channel_size, *out.shape[1:])

    def transform_keypoints(self, M: Tensor, inplace: bool = False) -> "VideoKeypoints":
        """Apply a transformation matrix to the keypoints, keeping ``temporal_channel_size``.

        Args:
            M: The transformation matrix to be applied, shape of ``(3, 3)`` or ``(B, 3, 3)``
                (validated by the parent implementation).
            inplace: do transform in-place and return self.

        Returns:
            ``self`` when ``inplace`` is True, otherwise a new :class:`VideoKeypoints`.
        """
        out = super().transform_keypoints(M, inplace=inplace)
        if inplace:
            return self
        # The parent returns a plain ``Keypoints``; re-wrap it and carry the temporal size over.
        video_out = VideoKeypoints(out.data, False)
        video_out.temporal_channel_size = self.temporal_channel_size
        return video_out

    def clone(self) -> "VideoKeypoints":
        """Return a new :class:`VideoKeypoints` with copied data and the same temporal size."""
        out = VideoKeypoints(self._data.clone(), False)
        out.temporal_channel_size = self.temporal_channel_size
        return out
class Keypoints3D:
    """3D Keypoints containing Nx3 or BxNx3 points.

    Args:
        keypoints: Raw tensor or a list of Tensors with the Nx3 coordinates
        raise_if_not_floating_point: will raise if the Tensor isn't float
    """

    def __init__(self, keypoints: Union[Tensor, List[Tensor]], raise_if_not_floating_point: bool = True) -> None:
        """Validate ``keypoints`` and store them.

        Raises:
            TypeError: if the (possibly merged) input is not a Tensor.
            ValueError: if the tensor is not floating point while
                ``raise_if_not_floating_point`` is set, or if its shape is neither
                ``(N, 3)`` nor ``(B, N, 3)``.
        """
        self._N: Optional[List[int]] = None

        if isinstance(keypoints, list):
            keypoints, self._N = _merge_keypoint_list(keypoints)

        if not isinstance(keypoints, Tensor):
            raise TypeError(f"Input keypoints is not a Tensor. Got: {type(keypoints)}.")

        if not keypoints.is_floating_point():
            if raise_if_not_floating_point:
                raise ValueError(f"Coordinates must be in floating point. Got {keypoints.dtype}")
            keypoints = keypoints.float()

        # An empty 1-D tensor (e.g. ``torch.tensor([])``) carries no points; normalise it to
        # ``(0, 3)``. Use reshape, so we don't end up creating a new tensor that does not
        # depend on the inputs (and consequently confuses jit).
        # The guard is restricted to 1-D input so that an empty batched tensor such as
        # ``(B, 0, 3)`` keeps its batch dimension, and so that a 0-d (scalar) tensor falls
        # through to the shape error below instead of failing inside ``reshape``.
        if keypoints.ndim == 1 and keypoints.numel() == 0:
            keypoints = keypoints.reshape((-1, 3))

        if not (2 <= keypoints.ndim <= 3 and keypoints.shape[-1:] == (3,)):
            raise ValueError(f"Keypoints shape must be (N, 3) or (B, N, 3). Got {keypoints.shape}.")

        self._is_batched = keypoints.ndim == 3
        self._data = keypoints

    def __getitem__(self, key: Union[slice, int, Tensor]) -> "Keypoints3D":
        """Index the underlying tensor and wrap the result in a new instance of the same class.

        The indexed tensor goes through the constructor again, so it must still have a
        valid ``(N, 3)`` or ``(B, N, 3)`` shape.
        """
        new_obj = type(self)(self._data[key], False)
        return new_obj

    def __setitem__(self, key: Union[slice, int, Tensor], value: "Keypoints3D") -> "Keypoints3D":
        """Write the data of ``value`` into the positions selected by ``key`` (in place)."""
        self._data[key] = value._data
        return self

    @property
    def shape(self) -> Size:
        """Returns the shape of the underlying tensor."""
        return self.data.shape

    @property
    def data(self) -> Tensor:
        """Returns the underlying tensor (not a copy)."""
        return self._data

    def pad(self, padding_size: Tensor) -> "Keypoints3D":
        """Pad the 3D keypoints (not implemented yet).

        Args:
            padding_size: (B, 6)

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def unpad(self, padding_size: Tensor) -> "Keypoints3D":
        """Unpad the 3D keypoints (not implemented yet).

        Args:
            padding_size: (B, 6)

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def transform_keypoints(self, M: Tensor, inplace: bool = False) -> "Keypoints3D":
        r"""Apply a transformation matrix to the 3D keypoints (not implemented yet).

        Args:
            M: The transformation matrix to be applied.
                NOTE(review): the expected shape was previously documented as
                :math:`(3, 3)` or :math:`(B, 3, 3)`, apparently copied from the 2D class;
                a 3D homogeneous transform would presumably be :math:`(4, 4)` -- confirm
                when implementing.
            inplace: do transform in-place and return self.

        Returns:
            The transformed keypoints.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def transform_keypoints_(self, M: Tensor) -> "Keypoints3D":
        """Inplace version of :func:`Keypoints3D.transform_keypoints`."""
        return self.transform_keypoints(M, inplace=True)

    @classmethod
    def from_tensor(cls, keypoints: Tensor) -> "Keypoints3D":
        """Build an instance from a raw keypoints tensor (same as calling the constructor)."""
        return cls(keypoints)

    def to_tensor(self, as_padded_sequence: bool = False) -> Union[Tensor, List[Tensor]]:
        r"""Cast :class:`Keypoints3D` to a tensor.

        Args:
            as_padded_sequence: whether to keep the pads for a list of keypoints. This parameter is only valid
                if the keypoints are from a keypoint list. Not implemented yet.

        Returns:
            The stored keypoints tensor, :math:`(N, 3)` or :math:`(B, N, 3)`. This is the
            underlying tensor itself, not a copy.

        Raises:
            NotImplementedError: if ``as_padded_sequence`` is True.
        """
        if as_padded_sequence:
            raise NotImplementedError
        return self._data

    def clone(self) -> "Keypoints3D":
        """Return a new :class:`Keypoints3D` holding a copy of the data."""
        return Keypoints3D(self._data.clone(), False)
|