se3.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# kornia.geometry.se3 module inspired by Sophus-sympy.
# https://github.com/strasdat/Sophus/blob/master/sympy/sophus/se3.py
  19. from __future__ import annotations
  20. from typing import Optional
  21. from kornia.core import (
  22. Device,
  23. Dtype,
  24. Module,
  25. Parameter,
  26. Tensor,
  27. concatenate,
  28. eye,
  29. pad,
  30. stack,
  31. tensor,
  32. where,
  33. zeros_like,
  34. )
  35. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SAME_DEVICES
  36. from kornia.geometry.liegroup.so3 import So3
  37. from kornia.geometry.linalg import batched_dot_product
  38. from kornia.geometry.quaternion import Quaternion
  39. from kornia.geometry.vector import Vector3
  40. class Se3(Module):
  41. r"""Base class to represent the Se3 group.
  42. The SE(3) is the group of rigid body transformations about the origin of three-dimensional Euclidean
  43. space :math:`R^3` under the operation of composition.
  44. See more: https://ingmec.ual.es/~jlblanco/papers/jlblanco2010geometry3D_techrep.pdf
  45. Example:
  46. >>> q = Quaternion.identity()
  47. >>> s = Se3(q, torch.ones(3))
  48. >>> s.r
  49. tensor([1., 0., 0., 0.])
  50. >>> s.t
  51. Parameter containing:
  52. tensor([1., 1., 1.], requires_grad=True)
  53. """
  54. def __init__(self, rotation: Quaternion | So3, translation: Vector3 | Tensor) -> None:
  55. """Construct the base class.
  56. Internally represented by a unit quaternion `q` and a translation 3-vector.
  57. Args:
  58. rotation: So3 group encompassing a rotation.
  59. translation: Vector3 or translation tensor with the shape of :math:`(B, 3)`.
  60. Example:
  61. >>> from kornia.geometry.quaternion import Quaternion
  62. >>> q = Quaternion.identity(batch_size=1)
  63. >>> s = Se3(q, torch.ones((1, 3)))
  64. >>> s.r
  65. tensor([[1., 0., 0., 0.]])
  66. >>> s.t
  67. Parameter containing:
  68. tensor([[1., 1., 1.]], requires_grad=True)
  69. """
  70. super().__init__()
  71. # KORNIA_CHECK_TYPE(rotation, (Quaternion, So3))
  72. if not isinstance(rotation, (Quaternion, So3)):
  73. raise TypeError(f"rotation type is {type(rotation)}")
  74. # KORNIA_CHECK_TYPE(translation, (Vector3, Tensor))
  75. if not isinstance(translation, (Vector3, Tensor)):
  76. raise TypeError(f"translation type is {type(translation)}")
  77. # KORNIA_CHECK_SHAPE(t, ["B", "3"]) # FIXME: resolve shape bugs. @edgarriba
  78. self._translation: Vector3 | Parameter
  79. self._rotation: So3
  80. if isinstance(translation, Tensor):
  81. self._translation = Parameter(translation)
  82. else:
  83. self._translation = translation
  84. if isinstance(rotation, Quaternion):
  85. self._rotation = So3(rotation)
  86. else:
  87. self._rotation = rotation
  88. def __repr__(self) -> str:
  89. return f"rotation: {self.r}\ntranslation: {self.t}"
  90. def __getitem__(self, idx: int | slice) -> Se3:
  91. return Se3(self._rotation[idx], self._translation[idx])
  92. def _mul_se3(self, right: Se3) -> Se3:
  93. _r = self.r * right.r
  94. _t = self.t + self.r * right.t
  95. return Se3(_r, _t)
  96. def __mul__(self, right: Se3) -> Se3 | Vector3 | Tensor:
  97. """Compose two Se3 transformations.
  98. Args:
  99. right: the other Se3 transformation.
  100. Return:
  101. The resulting Se3 transformation.
  102. """
  103. so3 = self.so3
  104. t = self.t
  105. if isinstance(right, Se3):
  106. # https://github.com/strasdat/Sophus/blob/master/sympy/sophus/se3.py#L97
  107. return self._mul_se3(right)
  108. elif isinstance(right, (Vector3, Tensor)):
  109. # KORNIA_CHECK_SHAPE(right, ["B", "N"]) # FIXME: resolve shape bugs. @edgarriba
  110. return so3 * right + t.data
  111. else:
  112. raise TypeError(f"Unsupported type: {type(right)}")
  113. @property
  114. def so3(self) -> So3:
  115. """Return the underlying rotation(So3)."""
  116. return self._rotation
  117. @property
  118. def quaternion(self) -> Quaternion:
  119. """Return the underlying rotation(Quaternion)."""
  120. return self._rotation.q
  121. @property
  122. def r(self) -> So3:
  123. """Return the underlying rotation(So3)."""
  124. return self._rotation
  125. @property
  126. def t(self) -> Vector3 | Tensor:
  127. """Return the underlying translation vector of shape :math:`(B,3)`."""
  128. return self._translation
  129. @property
  130. def rotation(self) -> So3:
  131. """Return the underlying `rotation(So3)`."""
  132. return self._rotation
  133. @property
  134. def translation(self) -> Vector3 | Tensor:
  135. """Return the underlying translation vector of shape :math:`(B,3)`."""
  136. return self._translation
  137. @staticmethod
  138. def exp(v: Tensor) -> Se3:
  139. """Convert elements of lie algebra to elements of lie group.
  140. Args:
  141. v: vector of shape :math:`(B, 6)`.
  142. Example:
  143. >>> v = torch.zeros((1, 6))
  144. >>> s = Se3.exp(v)
  145. >>> s.r
  146. tensor([[1., 0., 0., 0.]])
  147. >>> s.t
  148. Parameter containing:
  149. tensor([[0., 0., 0.]], requires_grad=True)
  150. """
  151. # KORNIA_CHECK_SHAPE(v, ["B", "6"]) # FIXME: resolve shape bugs. @edgarriba
  152. upsilon = v[..., :3]
  153. omega = v[..., 3:]
  154. omega_hat = So3.hat(omega)
  155. omega_hat_sq = omega_hat @ omega_hat
  156. theta = batched_dot_product(omega, omega).sqrt()
  157. R = So3.exp(omega)
  158. V = (
  159. eye(3, device=v.device, dtype=v.dtype)
  160. + ((1 - theta.cos()) / (theta**2))[..., None, None] * omega_hat
  161. + ((theta - theta.sin()) / (theta**3))[..., None, None] * omega_hat_sq
  162. )
  163. U = where(theta[..., None] != 0.0, (upsilon[..., None, :] * V).sum(-1), upsilon)
  164. return Se3(R, U)
  165. def log(self) -> Tensor:
  166. """Convert elements of lie group to elements of lie algebra.
  167. Example:
  168. >>> from kornia.geometry.quaternion import Quaternion
  169. >>> q = Quaternion.identity()
  170. >>> Se3(q, torch.zeros(3)).log()
  171. tensor([0., 0., 0., 0., 0., 0.])
  172. """
  173. omega = self.r.log()
  174. theta = batched_dot_product(omega, omega).clamp_min(1e-12).sqrt()
  175. t = self.t.data
  176. omega_hat = So3.hat(omega)
  177. omega_hat_sq = omega_hat @ omega_hat
  178. V_inv = (
  179. eye(3, device=omega.device, dtype=omega.dtype)
  180. - 0.5 * omega_hat
  181. + ((1 - theta * (theta / 2).cos() / (2 * (theta / 2).sin())) / theta.pow(2))[..., None, None] * omega_hat_sq
  182. )
  183. t = where(theta[..., None] != 0.0, (t[..., None, :] * V_inv).sum(-1), t)
  184. return concatenate((t, omega), -1)
  185. @staticmethod
  186. def hat(v: Tensor) -> Tensor:
  187. """Convert elements from vector space to lie algebra.
  188. Args:
  189. v: vector of shape :math:`(B, 6)`.
  190. Returns:
  191. matrix of shape :math:`(B, 4, 4)`.
  192. Example:
  193. >>> v = torch.ones((1, 6))
  194. >>> m = Se3.hat(v)
  195. >>> m
  196. tensor([[[ 0., -1., 1., 1.],
  197. [ 1., 0., -1., 1.],
  198. [-1., 1., 0., 1.],
  199. [ 0., 0., 0., 0.]]])
  200. """
  201. # KORNIA_CHECK_SHAPE(v, ["B", "6"]) # FIXME: resolve shape bugs. @edgarriba
  202. upsilon, omega = v[..., :3], v[..., 3:]
  203. rt = concatenate((So3.hat(omega), upsilon[..., None]), -1)
  204. return pad(rt, (0, 0, 0, 1)) # add zeros bottom
  205. @staticmethod
  206. def vee(omega: Tensor) -> Tensor:
  207. """Convert elements from lie algebra to vector space.
  208. Args:
  209. omega: 4x4-matrix representing lie algebra of shape :math:`(B,4,4)`.
  210. Returns:
  211. vector of shape :math:`(B,6)`.
  212. Example:
  213. >>> v = torch.ones((1, 6))
  214. >>> omega_hat = Se3.hat(v)
  215. >>> Se3.vee(omega_hat)
  216. tensor([[1., 1., 1., 1., 1., 1.]])
  217. """
  218. # KORNIA_CHECK_SHAPE(omega, ["B", "4", "4"]) # FIXME: resolve shape bugs. @edgarriba
  219. head = omega[..., :3, -1]
  220. tail = So3.vee(omega[..., :3, :3])
  221. return concatenate((head, tail), -1)
  222. @classmethod
  223. def identity(cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Dtype = None) -> Se3:
  224. """Create a Se3 group representing an identity rotation and zero translation.
  225. Args:
  226. batch_size: the batch size of the underlying data.
  227. device: device to place the result on.
  228. dtype: dtype of the result.
  229. Example:
  230. >>> s = Se3.identity()
  231. >>> s.r
  232. tensor([1., 0., 0., 0.])
  233. >>> s.t
  234. x: 0.0
  235. y: 0.0
  236. z: 0.0
  237. """
  238. t = tensor([0.0, 0.0, 0.0], device=device, dtype=dtype)
  239. if batch_size is not None:
  240. t = t.repeat(batch_size, 1)
  241. return cls(So3.identity(batch_size, device, dtype), Vector3(t))
  242. def matrix(self) -> Tensor:
  243. """Return the matrix representation of shape :math:`(B, 4, 4)`.
  244. Example:
  245. >>> s = Se3(So3.identity(), torch.ones(3))
  246. >>> s.matrix()
  247. tensor([[1., 0., 0., 1.],
  248. [0., 1., 0., 1.],
  249. [0., 0., 1., 1.],
  250. [0., 0., 0., 1.]])
  251. """
  252. rt = concatenate((self.r.matrix(), self.t.data[..., None]), -1)
  253. rt_4x4 = pad(rt, (0, 0, 0, 1)) # add last row zeros
  254. rt_4x4[..., -1, -1] = 1.0
  255. return rt_4x4
  256. @classmethod
  257. def from_matrix(cls, matrix: Tensor) -> Se3:
  258. """Create a Se3 group from a matrix.
  259. Args:
  260. matrix: tensor of shape :math:`(B, 4, 4)`.
  261. Example:
  262. >>> s = Se3.from_matrix(torch.eye(4))
  263. >>> s.r
  264. tensor([1., 0., 0., 0.])
  265. >>> s.t
  266. Parameter containing:
  267. tensor([0., 0., 0.], requires_grad=True)
  268. """
  269. # KORNIA_CHECK_SHAPE(matrix, ["B", "4", "4"]) # FIXME: resolve shape bugs. @edgarriba
  270. r = So3.from_matrix(matrix[..., :3, :3])
  271. t = matrix[..., :3, -1]
  272. return cls(r, t)
  273. @classmethod
  274. def from_qxyz(cls, qxyz: Tensor) -> Se3:
  275. """Create a Se3 group a quaternion and translation vector.
  276. Args:
  277. qxyz: tensor of shape :math:`(B, 7)`.
  278. Example:
  279. >>> qxyz = torch.tensor([1., 2., 3., 0., 0., 0., 1.])
  280. >>> s = Se3.from_qxyz(qxyz)
  281. >>> s.r
  282. tensor([1., 2., 3., 0.])
  283. >>> s.t
  284. x: 0.0
  285. y: 0.0
  286. z: 1.0
  287. """
  288. # KORNIA_CHECK_SHAPE(qxyz, ["B", "7"]) # FIXME: resolve shape bugs. @edgarriba
  289. q, xyz = qxyz[..., :4], qxyz[..., 4:]
  290. return cls(So3.from_wxyz(q), Vector3(xyz))
  291. def inverse(self) -> Se3:
  292. """Return the inverse transformation.
  293. Example:
  294. >>> s = Se3(So3.identity(), torch.ones(3))
  295. >>> s_inv = s.inverse()
  296. >>> s_inv.r
  297. tensor([1., -0., -0., -0.])
  298. >>> s_inv.t
  299. Parameter containing:
  300. tensor([-1., -1., -1.], requires_grad=True)
  301. """
  302. r_inv = self.r.inverse()
  303. _t = -1 * self.t
  304. if isinstance(_t, int):
  305. raise TypeError("Unexpected integer from `-1 * translation`")
  306. return Se3(r_inv, r_inv * _t)
  307. @classmethod
  308. def random(cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Dtype = None) -> Se3:
  309. """Create a Se3 group representing a random transformation.
  310. Args:
  311. batch_size: the batch size of the underlying data.
  312. device: device to place the result on.
  313. dtype: dtype of the result.
  314. Example:
  315. >>> s = Se3.random()
  316. >>> s = Se3.random(batch_size=3)
  317. """
  318. shape: tuple[int, ...]
  319. if batch_size is None:
  320. shape = ()
  321. else:
  322. KORNIA_CHECK(batch_size >= 1, msg="batch_size must be positive")
  323. shape = (batch_size,)
  324. r = So3.random(batch_size, device, dtype)
  325. t = Vector3.random(shape, device, dtype)
  326. return cls(r, t)
  327. @classmethod
  328. def rot_x(cls, x: Tensor) -> Se3:
  329. """Construct a x-axis rotation.
  330. Args:
  331. x: the x-axis rotation angle.
  332. """
  333. zs = zeros_like(x)
  334. return cls(So3.rot_x(x), stack((zs, zs, zs), -1))
  335. @classmethod
  336. def rot_y(cls, y: Tensor) -> Se3:
  337. """Construct a y-axis rotation.
  338. Args:
  339. y: the y-axis rotation angle.
  340. """
  341. zs = zeros_like(y)
  342. return cls(So3.rot_y(y), stack((zs, zs, zs), -1))
  343. @classmethod
  344. def rot_z(cls, z: Tensor) -> Se3:
  345. """Construct a z-axis rotation.
  346. Args:
  347. z: the z-axis rotation angle.
  348. """
  349. zs = zeros_like(z)
  350. return cls(So3.rot_z(z), stack((zs, zs, zs), -1))
  351. @classmethod
  352. def trans(cls, x: Tensor, y: Tensor, z: Tensor) -> Se3:
  353. """Construct a translation only Se3 instance.
  354. Args:
  355. x: the x-axis translation.
  356. y: the y-axis translation.
  357. z: the z-axis translation.
  358. """
  359. KORNIA_CHECK(x.shape == y.shape)
  360. KORNIA_CHECK(y.shape == z.shape)
  361. KORNIA_CHECK_SAME_DEVICES([x, y, z])
  362. batch_size = x.shape[0] if len(x.shape) > 0 else None
  363. rotation = So3.identity(batch_size, x.device, x.dtype)
  364. return cls(rotation, stack((x, y, z), -1))
  365. @classmethod
  366. def trans_x(cls, x: Tensor) -> Se3:
  367. """Construct a x-axis translation.
  368. Args:
  369. x: the x-axis translation.
  370. """
  371. zs = zeros_like(x)
  372. return cls.trans(x, zs, zs)
  373. @classmethod
  374. def trans_y(cls, y: Tensor) -> Se3:
  375. """Construct a y-axis translation.
  376. Args:
  377. y: the y-axis translation.
  378. """
  379. zs = zeros_like(y)
  380. return cls.trans(zs, y, zs)
  381. @classmethod
  382. def trans_z(cls, z: Tensor) -> Se3:
  383. """Construct a z-axis translation.
  384. Args:
  385. z: the z-axis translation.
  386. """
  387. zs = zeros_like(z)
  388. return cls.trans(zs, zs, z)
  389. def adjoint(self) -> Tensor:
  390. """Return the adjoint matrix of shape :math:`(B, 6, 6)`.
  391. Example:
  392. >>> s = Se3.identity()
  393. >>> s.adjoint()
  394. tensor([[1., 0., 0., 0., 0., 0.],
  395. [0., 1., 0., 0., 0., 0.],
  396. [0., 0., 1., 0., 0., 0.],
  397. [0., 0., 0., 1., 0., 0.],
  398. [0., 0., 0., 0., 1., 0.],
  399. [0., 0., 0., 0., 0., 1.]])
  400. """
  401. R = self.so3.matrix()
  402. z = zeros_like(R)
  403. row0 = concatenate((R, So3.hat(self.t) @ R), -1)
  404. row1 = concatenate((z, R), -1)
  405. return concatenate((row0, row1), -2)