keypoints.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import List, Optional, Tuple, Union, cast
  18. import torch
  19. from torch import Size
  20. from kornia.core import Tensor
  21. from kornia.geometry import transform_points
  22. __all__ = ["Keypoints", "Keypoints3D"]
  23. def _merge_keypoint_list(keypoints: List[Tensor]) -> Tensor:
  24. raise NotImplementedError
  25. class Keypoints:
  26. """2D Keypoints containing Nx2 or BxNx2 points.
  27. Args:
  28. keypoints: Raw tensor or a list of Tensors with the Nx2 coordinates
  29. raise_if_not_floating_point: will raise if the Tensor isn't float
  30. """
  31. def __init__(self, keypoints: Union[Tensor, List[Tensor]], raise_if_not_floating_point: bool = True) -> None:
  32. self._N: Optional[List[int]] = None
  33. if isinstance(keypoints, list):
  34. keypoints, self._N = _merge_keypoint_list(keypoints)
  35. if not isinstance(keypoints, Tensor):
  36. raise TypeError(f"Input keypoints is not a Tensor. Got: {type(keypoints)}.")
  37. if not keypoints.is_floating_point():
  38. if raise_if_not_floating_point:
  39. raise ValueError(f"Coordinates must be in floating point. Got {keypoints.dtype}")
  40. keypoints = keypoints.float()
  41. if len(keypoints.shape) == 0:
  42. # Use reshape, so we don't end up creating a new tensor that does not depend on
  43. # the inputs (and consequently confuses jit)
  44. keypoints = keypoints.reshape((-1, 2))
  45. if not (2 <= keypoints.ndim <= 3 and keypoints.shape[-1:] == (2,)):
  46. raise ValueError(f"Keypoints shape must be (N, 2) or (B, N, 2). Got {keypoints.shape}.")
  47. self._is_batched = False if keypoints.ndim == 2 else True
  48. self._data = keypoints
  49. def __getitem__(self, key: Union[slice, int, Tensor]) -> "Keypoints":
  50. new_obj = type(self)(self._data[key], False)
  51. return new_obj
  52. def __setitem__(self, key: Union[slice, int, Tensor], value: "Keypoints") -> "Keypoints":
  53. self._data[key] = value._data
  54. return self
  55. @property
  56. def shape(self) -> Union[Tuple[int, ...], Size]:
  57. return self.data.shape
  58. @property
  59. def data(self) -> Tensor:
  60. return self._data
  61. @property
  62. def device(self) -> torch.device:
  63. """Returns keypoints device."""
  64. return self._data.device
  65. @property
  66. def dtype(self) -> torch.dtype:
  67. """Returns keypoints dtype."""
  68. return self._data.dtype
  69. def index_put(
  70. self,
  71. indices: Union[Tuple[Tensor, ...], List[Tensor]],
  72. values: Union[Tensor, "Keypoints"],
  73. inplace: bool = False,
  74. ) -> "Keypoints":
  75. if inplace:
  76. _data = self._data
  77. else:
  78. _data = self._data.clone()
  79. if isinstance(values, Keypoints):
  80. _data.index_put_(indices, values.data)
  81. else:
  82. _data.index_put_(indices, values)
  83. if inplace:
  84. return self
  85. obj = self.clone()
  86. obj._data = _data
  87. return obj
  88. def pad(self, padding_size: Tensor) -> "Keypoints":
  89. """Pad a bounding keypoints.
  90. Args:
  91. padding_size: (B, 4)
  92. """
  93. if not (len(padding_size.shape) == 2 and padding_size.size(1) == 4):
  94. raise RuntimeError(f"Expected padding_size as (B, 4). Got {padding_size.shape}.")
  95. self._data[..., 0] += padding_size[..., :1] # left padding
  96. self._data[..., 1] += padding_size[..., 2:3] # top padding
  97. return self
  98. def unpad(self, padding_size: Tensor) -> "Keypoints":
  99. """Pad a bounding keypoints.
  100. Args:
  101. padding_size: (B, 4)
  102. """
  103. if not (len(padding_size.shape) == 2 and padding_size.size(1) == 4):
  104. raise RuntimeError(f"Expected padding_size as (B, 4). Got {padding_size.shape}.")
  105. self._data[..., 0] -= padding_size[..., :1] # left padding
  106. self._data[..., 1] -= padding_size[..., 2:3] # top padding
  107. return self
  108. def transform_keypoints(self, M: Tensor, inplace: bool = False) -> "Keypoints":
  109. r"""Apply a transformation matrix to the 2D keypoints.
  110. Args:
  111. M: The transformation matrix to be applied, shape of :math:`(3, 3)` or :math:`(B, 3, 3)`.
  112. inplace: do transform in-place and return self.
  113. Returns:
  114. The transformed keypoints.
  115. """
  116. if not 2 <= M.ndim <= 3 or M.shape[-2:] != (3, 3):
  117. raise ValueError(f"The transformation matrix shape must be (3, 3) or (B, 3, 3). Got {M.shape}.")
  118. transformed_boxes = transform_points(M, self._data)
  119. if inplace:
  120. self._data = transformed_boxes
  121. return self
  122. return Keypoints(transformed_boxes, False)
  123. def transform_keypoints_(self, M: Tensor) -> "Keypoints":
  124. """Inplace version of :func:`Keypoints.transform_keypoints`."""
  125. return self.transform_keypoints(M, inplace=True)
  126. @classmethod
  127. def from_tensor(cls, keypoints: Tensor) -> "Keypoints":
  128. return cls(keypoints)
  129. def to_tensor(self, as_padded_sequence: bool = False) -> Union[Tensor, List[Tensor]]:
  130. r"""Cast :class:`Keypoints` to a tensor.
  131. ``mode`` controls which 2D keypoints format should be use to represent keypoints in the tensor.
  132. Args:
  133. as_padded_sequence: whether to keep the pads for a list of keypoints. This parameter is only valid
  134. if the keypoints are from a keypoint list.
  135. Returns:
  136. Keypoints tensor :math:`(B, N, 2)`
  137. """
  138. if as_padded_sequence:
  139. raise NotImplementedError
  140. return self._data
  141. def clone(self) -> "Keypoints":
  142. return Keypoints(self._data.clone(), False)
  143. def type(self, dtype: torch.dtype) -> "Keypoints":
  144. self._data = self._data.type(dtype)
  145. return self
  146. class VideoKeypoints(Keypoints):
  147. temporal_channel_size: int
  148. @classmethod
  149. def from_tensor(cls, boxes: Union[Tensor, List[Tensor]], validate_boxes: bool = True) -> "VideoKeypoints":
  150. if isinstance(boxes, (list,)) or (boxes.dim() != 4 or boxes.shape[-1] != 2):
  151. raise ValueError("Input box type is not yet supported. Please input an `BxTxNx2` tensor directly.")
  152. temporal_channel_size = boxes.size(1)
  153. # Due to some torch.jit.script bug (at least <= 1.9), you need to pass all arguments to __init__ when
  154. # constructing the class from inside of a method.
  155. out = cls(boxes.view(boxes.size(0) * boxes.size(1), -1, boxes.size(3)))
  156. out.temporal_channel_size = temporal_channel_size
  157. return out
  158. def to_tensor(self) -> Tensor: # type: ignore[override]
  159. out = super().to_tensor(as_padded_sequence=False)
  160. out = cast(Tensor, out)
  161. return out.view(-1, self.temporal_channel_size, *out.shape[1:])
  162. def transform_keypoints(self, M: Tensor, inplace: bool = False) -> "VideoKeypoints":
  163. out = super().transform_keypoints(M, inplace=inplace)
  164. if inplace:
  165. return self
  166. out = VideoKeypoints(out.data, False)
  167. out.temporal_channel_size = self.temporal_channel_size
  168. return out
  169. def clone(self) -> "VideoKeypoints":
  170. out = VideoKeypoints(self._data.clone(), False)
  171. out.temporal_channel_size = self.temporal_channel_size
  172. return out
  173. class Keypoints3D:
  174. """3D Keypoints containing Nx3 or BxNx3 points.
  175. Args:
  176. keypoints: Raw tensor or a list of Tensors with the Nx3 coordinates
  177. raise_if_not_floating_point: will raise if the Tensor isn't float
  178. """
  179. def __init__(self, keypoints: Union[Tensor, List[Tensor]], raise_if_not_floating_point: bool = True) -> None:
  180. self._N: Optional[List[int]] = None
  181. if isinstance(keypoints, list):
  182. keypoints, self._N = _merge_keypoint_list(keypoints)
  183. if not isinstance(keypoints, Tensor):
  184. raise TypeError(f"Input keypoints is not a Tensor. Got: {type(keypoints)}.")
  185. if not keypoints.is_floating_point():
  186. if raise_if_not_floating_point:
  187. raise ValueError(f"Coordinates must be in floating point. Got {keypoints.dtype}")
  188. keypoints = keypoints.float()
  189. if len(keypoints.shape) == 0:
  190. # Use reshape, so we don't end up creating a new tensor that does not depend on
  191. # the inputs (and consequently confuses jit)
  192. keypoints = keypoints.reshape((-1, 3))
  193. if not (2 <= keypoints.ndim <= 3 and keypoints.shape[-1:] == (3,)):
  194. raise ValueError(f"Keypoints shape must be (N, 3) or (B, N, 3). Got {keypoints.shape}.")
  195. self._is_batched = False if keypoints.ndim == 2 else True
  196. self._data = keypoints
  197. def __getitem__(self, key: Union[slice, int, Tensor]) -> "Keypoints3D":
  198. new_obj = type(self)(self._data[key], False)
  199. return new_obj
  200. def __setitem__(self, key: Union[slice, int, Tensor], value: "Keypoints3D") -> "Keypoints3D":
  201. self._data[key] = value._data
  202. return self
  203. @property
  204. def shape(self) -> Size:
  205. return self.data.shape
  206. @property
  207. def data(self) -> Tensor:
  208. return self._data
  209. def pad(self, padding_size: Tensor) -> "Keypoints3D":
  210. """Pad a bounding keypoints.
  211. Args:
  212. padding_size: (B, 6)
  213. """
  214. raise NotImplementedError
  215. def unpad(self, padding_size: Tensor) -> "Keypoints3D":
  216. """Pad a bounding keypoints.
  217. Args:
  218. padding_size: (B, 6)
  219. """
  220. raise NotImplementedError
  221. def transform_keypoints(self, M: Tensor, inplace: bool = False) -> "Keypoints3D":
  222. r"""Apply a transformation matrix to the 2D keypoints.
  223. Args:
  224. M: The transformation matrix to be applied, shape of :math:`(3, 3)` or :math:`(B, 3, 3)`.
  225. inplace: do transform in-place and return self.
  226. Returns:
  227. The transformed keypoints.
  228. """
  229. raise NotImplementedError
  230. def transform_keypoints_(self, M: Tensor) -> "Keypoints3D":
  231. """Inplace version of :func:`Keypoints.transform_keypoints`."""
  232. return self.transform_keypoints(M, inplace=True)
  233. @classmethod
  234. def from_tensor(cls, keypoints: Tensor) -> "Keypoints3D":
  235. return cls(keypoints)
  236. def to_tensor(self, as_padded_sequence: bool = False) -> Union[Tensor, List[Tensor]]:
  237. r"""Cast :class:`Keypoints` to a tensor.
  238. ``mode`` controls which 2D keypoints format should be use to represent keypoints in the tensor.
  239. Args:
  240. as_padded_sequence: whether to keep the pads for a list of keypoints. This parameter is only valid
  241. if the keypoints are from a keypoint list.
  242. Returns:
  243. Keypoints tensor :math:`(B, N, 3)`
  244. """
  245. if as_padded_sequence:
  246. raise NotImplementedError
  247. return self._data
  248. def clone(self) -> "Keypoints3D":
  249. return Keypoints3D(self._data.clone(), False)