so3.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. # kornia.geometry.so3 module inspired by Sophus-sympy.
  18. # https://github.com/strasdat/Sophus/blob/master/sympy/sophus/so3.py
  19. from __future__ import annotations
  20. from typing import Optional
  21. import torch
  22. from kornia.core import Device, Dtype, Module, Tensor, concatenate, eye, stack, where, zeros, zeros_like
  23. from kornia.core.check import KORNIA_CHECK_TYPE
  24. from kornia.geometry.conversions import vector_to_skew_symmetric_matrix
  25. from kornia.geometry.linalg import batched_dot_product
  26. from kornia.geometry.quaternion import Quaternion
  27. from kornia.geometry.vector import Vector3
  28. class So3(Module):
  29. r"""Base class to represent the So3 group.
  30. The SO(3) is the group of all rotations about the origin of three-dimensional Euclidean space
  31. :math:`R^3` under the operation of composition.
  32. See more: https://en.wikipedia.org/wiki/3D_rotation_group
  33. We internally represent the rotation by a unit quaternion.
  34. Example:
  35. >>> q = Quaternion.identity()
  36. >>> s = So3(q)
  37. >>> s.q
  38. tensor([1., 0., 0., 0.])
  39. """
  40. def __init__(self, q: Quaternion) -> None:
  41. """Construct the base class.
  42. Internally represented by a unit quaternion `q`.
  43. Args:
  44. q: Quaternion with the shape of :math:`(B, 4)`.
  45. Example:
  46. >>> data = torch.ones((2, 4))
  47. >>> q = Quaternion(data)
  48. >>> So3(q)
  49. tensor([[1., 1., 1., 1.],
  50. [1., 1., 1., 1.]])
  51. """
  52. super().__init__()
  53. KORNIA_CHECK_TYPE(q, Quaternion)
  54. self._q = q
  55. def __repr__(self) -> str:
  56. return f"{self.q}"
  57. def __getitem__(self, idx: int | slice) -> So3:
  58. return So3(self._q[idx])
  59. def __mul__(self, right: So3) -> So3:
  60. """Compose two So3 transformations.
  61. Args:
  62. right: the other So3 transformation.
  63. Return:
  64. The resulting So3 transformation.
  65. """
  66. # https://github.com/strasdat/Sophus/blob/master/sympy/sophus/so3.py#L98
  67. if isinstance(right, So3):
  68. return So3(self.q * right.q)
  69. elif isinstance(right, (Tensor, Vector3)):
  70. # KORNIA_CHECK_SHAPE(right, ["B", "3"]) # FIXME: resolve shape bugs. @edgarriba
  71. w = zeros(*right.shape[:-1], 1, device=right.device, dtype=right.dtype)
  72. quat = Quaternion(concatenate((w, right.data), -1))
  73. out = (self.q * quat * self.q.conj()).vec
  74. if isinstance(right, Tensor):
  75. return out
  76. elif isinstance(right, Vector3):
  77. return Vector3(out)
  78. else:
  79. raise TypeError(f"Not So3 or Tensor type. Got: {type(right)}")
  80. @property
  81. def q(self) -> Quaternion:
  82. """Return the underlying data with shape :math:`(B,4)`."""
  83. return self._q
  84. @staticmethod
  85. def exp(v: Tensor) -> So3:
  86. """Convert elements of lie algebra to elements of lie group.
  87. See more: https://vision.in.tum.de/_media/members/demmeln/nurlanov2021so3log.pdf
  88. Args:
  89. v: vector of shape :math:`(B,3)`.
  90. Example:
  91. >>> v = torch.zeros((2, 3))
  92. >>> s = So3.exp(v)
  93. >>> s
  94. tensor([[1., 0., 0., 0.],
  95. [1., 0., 0., 0.]])
  96. """
  97. # KORNIA_CHECK_SHAPE(v, ["B", "3"]) # FIXME: resolve shape bugs. @edgarriba
  98. theta = v.norm(dim=-1, keepdim=True)
  99. theta_half = 0.5 * theta
  100. w = torch.cos(theta_half)
  101. eps = torch.finfo(v.dtype).eps * 1e3
  102. small_mask = theta <= eps
  103. b_large = torch.sin(theta_half) / theta
  104. b_small = 0.5 - (theta * theta) / 48.0
  105. b = torch.where(small_mask, b_small, b_large)
  106. xyz = b * v
  107. q = torch.cat((w, xyz), dim=-1)
  108. return So3(Quaternion(q))
  109. def log(self) -> Tensor:
  110. """Convert elements of lie group to elements of lie algebra.
  111. Example:
  112. >>> data = torch.ones((2, 4))
  113. >>> q = Quaternion(data)
  114. >>> So3(q).log()
  115. tensor([[0., 0., 0.],
  116. [0., 0., 0.]])
  117. """
  118. theta = batched_dot_product(self.q.vec, self.q.vec).sqrt()
  119. # NOTE: this differs from https://github.com/strasdat/Sophus/blob/master/sympy/sophus/so3.py#L33
  120. omega = where(
  121. theta[..., None] != 0,
  122. 2 * self.q.real[..., None].acos() * self.q.vec / theta[..., None],
  123. 2 * self.q.vec / self.q.real[..., None],
  124. )
  125. return omega
  126. @staticmethod
  127. def hat(v: Vector3 | Tensor) -> Tensor:
  128. """Convert elements from vector space to lie algebra. Returns matrix of shape :math:`(B,3,3)`.
  129. Args:
  130. v: Vector3 or tensor of shape :math:`(B,3)`.
  131. Example:
  132. >>> v = torch.ones((1,3))
  133. >>> m = So3.hat(v)
  134. >>> m
  135. tensor([[[ 0., -1., 1.],
  136. [ 1., 0., -1.],
  137. [-1., 1., 0.]]])
  138. """
  139. # KORNIA_CHECK_SHAPE(v, ["B", "3"]) # FIXME: resolve shape bugs. @edgarriba
  140. if isinstance(v, Tensor):
  141. # TODO: Figure out why mypy think `v` can be a Vector3 which didn't allow ellipsis on index
  142. a, b, c = v[..., 0], v[..., 1], v[..., 2] # type: ignore[index]
  143. else:
  144. a, b, c = v.x, v.y, v.z
  145. z = zeros_like(a)
  146. row0 = stack((z, -c, b), -1)
  147. row1 = stack((c, z, -a), -1)
  148. row2 = stack((-b, a, z), -1)
  149. return stack((row0, row1, row2), -2)
  150. @staticmethod
  151. def vee(omega: Tensor) -> Tensor:
  152. r"""Convert elements from lie algebra to vector space. Returns vector of shape :math:`(B,3)`.
  153. .. math::
  154. omega = \begin{bmatrix} 0 & -c & b \\
  155. c & 0 & -a \\
  156. -b & a & 0\end{bmatrix}
  157. Args:
  158. omega: 3x3-matrix representing lie algebra.
  159. Example:
  160. >>> v = torch.ones((1,3))
  161. >>> omega = So3.hat(v)
  162. >>> So3.vee(omega)
  163. tensor([[1., 1., 1.]])
  164. """
  165. # KORNIA_CHECK_SHAPE(omega, ["B", "3", "3"]) # FIXME: resolve shape bugs. @edgarriba
  166. a, b, c = omega[..., 2, 1], omega[..., 0, 2], omega[..., 1, 0]
  167. return stack((a, b, c), -1)
  168. def matrix(self) -> Tensor:
  169. r"""Convert the quaternion to a rotation matrix of shape :math:`(B,3,3)`.
  170. The matrix is of the form:
  171. .. math::
  172. \begin{bmatrix} 1-2y^2-2z^2 & 2xy-2zw & 2xy+2yw \\
  173. 2xy+2zw & 1-2x^2-2z^2 & 2yz-2xw \\
  174. 2xz-2yw & 2yz+2xw & 1-2x^2-2y^2\end{bmatrix}
  175. Example:
  176. >>> s = So3.identity()
  177. >>> m = s.matrix()
  178. >>> m
  179. tensor([[1., 0., 0.],
  180. [0., 1., 0.],
  181. [0., 0., 1.]])
  182. """
  183. w = self.q.w[..., None]
  184. x, y, z = self.q.x[..., None], self.q.y[..., None], self.q.z[..., None]
  185. q0 = 1 - 2 * y**2 - 2 * z**2
  186. q1 = 2 * x * y - 2 * z * w
  187. q2 = 2 * x * z + 2 * y * w
  188. row0 = concatenate((q0, q1, q2), -1)
  189. q0 = 2 * x * y + 2 * z * w
  190. q1 = 1 - 2 * x**2 - 2 * z**2
  191. q2 = 2 * y * z - 2 * x * w
  192. row1 = concatenate((q0, q1, q2), -1)
  193. q0 = 2 * x * z - 2 * y * w
  194. q1 = 2 * y * z + 2 * x * w
  195. q2 = 1 - 2 * x**2 - 2 * y**2
  196. row2 = concatenate((q0, q1, q2), -1)
  197. return stack((row0, row1, row2), -2)
  198. @classmethod
  199. def from_matrix(cls, matrix: Tensor) -> So3:
  200. """Create So3 from a rotation matrix.
  201. Args:
  202. matrix: the rotation matrix to convert of shape :math:`(B,3,3)`.
  203. Example:
  204. >>> m = torch.eye(3)
  205. >>> s = So3.from_matrix(m)
  206. >>> s
  207. tensor([1., 0., 0., 0.])
  208. """
  209. return cls(Quaternion.from_matrix(matrix))
  210. @classmethod
  211. def from_wxyz(cls, wxyz: Tensor) -> So3:
  212. """Create So3 from a tensor representing a quaternion.
  213. Args:
  214. wxyz: the quaternion to convert of shape :math:`(B,4)`.
  215. Example:
  216. >>> q = torch.tensor([1., 0., 0., 0.])
  217. >>> s = So3.from_wxyz(q)
  218. >>> s
  219. tensor([1., 0., 0., 0.])
  220. """
  221. # KORNIA_CHECK_SHAPE(wxyz, ["B", "4"]) # FIXME: resolve shape bugs. @edgarriba
  222. return cls(Quaternion(wxyz))
  223. @classmethod
  224. def identity(
  225. cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Optional[Dtype] = None
  226. ) -> So3:
  227. """Create a So3 group representing an identity rotation.
  228. Args:
  229. batch_size: the batch size of the underlying data.
  230. device: device to place the result on.
  231. dtype: dtype of the result.
  232. Example:
  233. >>> s = So3.identity()
  234. >>> s
  235. tensor([1., 0., 0., 0.])
  236. >>> s = So3.identity(batch_size=2)
  237. >>> s
  238. tensor([[1., 0., 0., 0.],
  239. [1., 0., 0., 0.]])
  240. """
  241. return cls(Quaternion.identity(batch_size, device, dtype))
  242. def inverse(self) -> So3:
  243. """Return the inverse transformation.
  244. Example:
  245. >>> s = So3.identity()
  246. >>> s.inverse()
  247. tensor([1., -0., -0., -0.])
  248. """
  249. return So3(self.q.conj())
  250. @classmethod
  251. def random(
  252. cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Optional[Dtype] = None
  253. ) -> So3:
  254. """Create a So3 group representing a random rotation.
  255. Args:
  256. batch_size: the batch size of the underlying data.
  257. device: device to place the result on.
  258. dtype: dtype of the result.
  259. Example:
  260. >>> s = So3.random()
  261. >>> s = So3.random(batch_size=3)
  262. """
  263. return cls(Quaternion.random(batch_size, device, dtype))
  264. @classmethod
  265. def rot_x(cls, x: Tensor) -> So3:
  266. """Construct a x-axis rotation.
  267. Args:
  268. x: the x-axis rotation angle.
  269. """
  270. zs = zeros_like(x)
  271. return cls.exp(stack((x, zs, zs), -1))
  272. @classmethod
  273. def rot_y(cls, y: Tensor) -> So3:
  274. """Construct a z-axis rotation.
  275. Args:
  276. y: the y-axis rotation angle.
  277. """
  278. zs = zeros_like(y)
  279. return cls.exp(stack((zs, y, zs), -1))
  280. @classmethod
  281. def rot_z(cls, z: Tensor) -> So3:
  282. """Construct a z-axis rotation.
  283. Args:
  284. z: the z-axis rotation angle.
  285. """
  286. zs = zeros_like(z)
  287. return cls.exp(stack((zs, zs, z), -1))
  288. def adjoint(self) -> Tensor:
  289. """Return the adjoint matrix of shape :math:`(B, 3, 3)`.
  290. Example:
  291. >>> s = So3.identity()
  292. >>> s.adjoint()
  293. tensor([[1., 0., 0.],
  294. [0., 1., 0.],
  295. [0., 0., 1.]])
  296. """
  297. return self.matrix()
  298. @staticmethod
  299. def right_jacobian(vec: Tensor) -> Tensor:
  300. """Compute the right Jacobian of So3.
  301. Args:
  302. vec: the input point of shape :math:`(B, 3)`.
  303. Example:
  304. >>> vec = torch.tensor([1., 2., 3.])
  305. >>> So3.right_jacobian(vec)
  306. tensor([[-0.0687, 0.5556, -0.0141],
  307. [-0.2267, 0.1779, 0.6236],
  308. [ 0.5074, 0.3629, 0.5890]])
  309. """
  310. # KORNIA_CHECK_SHAPE(vec, ["B", "3"]) # FIXME: resolve shape bugs. @edgarriba
  311. R_skew = vector_to_skew_symmetric_matrix(vec)
  312. theta = vec.norm(dim=-1, keepdim=True)[..., None]
  313. I = eye(3, device=vec.device, dtype=vec.dtype) # noqa: E741
  314. Jr = I - ((1 - theta.cos()) / theta**2) * R_skew + ((theta - theta.sin()) / theta**3) * (R_skew @ R_skew)
  315. return Jr
  316. @staticmethod
  317. def Jr(vec: Tensor) -> Tensor:
  318. """Alias for right jacobian.
  319. Args:
  320. vec: the input point of shape :math:`(B, 3)`.
  321. """
  322. return So3.right_jacobian(vec)
  323. @staticmethod
  324. def left_jacobian(vec: Tensor) -> Tensor:
  325. """Compute the left Jacobian of So3.
  326. Args:
  327. vec: the input point of shape :math:`(B, 3)`.
  328. Example:
  329. >>> vec = torch.tensor([1., 2., 3.])
  330. >>> So3.left_jacobian(vec)
  331. tensor([[-0.0687, -0.2267, 0.5074],
  332. [ 0.5556, 0.1779, 0.3629],
  333. [-0.0141, 0.6236, 0.5890]])
  334. """
  335. # KORNIA_CHECK_SHAPE(vec, ["B", "3"]) # FIXME: resolve shape bugs. @edgarriba
  336. R_skew = vector_to_skew_symmetric_matrix(vec)
  337. theta = vec.norm(dim=-1, keepdim=True)[..., None]
  338. I = eye(3, device=vec.device, dtype=vec.dtype) # noqa: E741
  339. Jl = I + ((1 - theta.cos()) / theta**2) * R_skew + ((theta - theta.sin()) / theta**3) * (R_skew @ R_skew)
  340. return Jl
  341. @staticmethod
  342. def Jl(vec: Tensor) -> Tensor:
  343. """Alias for left jacobian.
  344. Args:
  345. vec: the input point of shape :math:`(B, 3)`.
  346. """
  347. return So3.left_jacobian(vec)