se2.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. # kornia.geometry.se2 module inspired by Sophus-sympy.
  18. # https://github.com/strasdat/Sophus/blob/master/sympy/sophus/se2.py
  19. from __future__ import annotations
  20. from typing import Optional, overload
  21. from kornia.core import (
  22. Device,
  23. Dtype,
  24. Module,
  25. Parameter,
  26. Tensor,
  27. concatenate,
  28. pad,
  29. rand,
  30. stack,
  31. tensor,
  32. where,
  33. zeros_like,
  34. )
  35. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SAME_DEVICES, KORNIA_CHECK_TYPE
  36. from kornia.geometry.liegroup._utils import check_se2_omega_shape, check_se2_t_shape, check_v_shape
  37. from kornia.geometry.liegroup.so2 import So2
  38. from kornia.geometry.vector import Vector2
  39. def _check_se2_r_t_shape(r: So2, t: Tensor) -> None:
  40. z_shape = r.z.shape
  41. if ((len(z_shape) == 1) and (len(t.shape) == 2)) or ((len(z_shape) == 0) and len(t.shape) == 1):
  42. check_se2_t_shape(t)
  43. else:
  44. raise ValueError(
  45. f"Invalid input, both the inputs should be either batched or unbatched. Got: {r.z.shape} and {t.shape}"
  46. )
  47. class Se2(Module):
  48. r"""Base class to represent the Se2 group.
  49. The SE(2) is the group of rigid body transformations about the origin of two-dimensional Euclidean
  50. space :math:`R^2` under the operation of composition.
  51. See more:
  52. Example:
  53. >>> so2 = So2.identity(1)
  54. >>> t = torch.ones((1, 2))
  55. >>> se2 = Se2(so2, t)
  56. >>> se2
  57. rotation: Parameter containing:
  58. tensor([1.+0.j], requires_grad=True)
  59. translation: Parameter containing:
  60. tensor([[1., 1.]], requires_grad=True)
  61. """
  62. def __init__(self, rotation: So2, translation: Vector2 | Tensor) -> None:
  63. """Construct the base class.
  64. Internally represented by a complex number `z` and a translation 2-vector.
  65. Args:
  66. rotation: So2 group encompassing a rotation.
  67. translation: translation vector with the shape of :math:`(B, 2)`.
  68. Example:
  69. >>> so2 = So2.identity(1)
  70. >>> t = torch.ones((1, 2))
  71. >>> se2 = Se2(so2, t)
  72. >>> se2
  73. rotation: Parameter containing:
  74. tensor([1.+0.j], requires_grad=True)
  75. translation: Parameter containing:
  76. tensor([[1., 1.]], requires_grad=True)
  77. """
  78. super().__init__()
  79. KORNIA_CHECK_TYPE(rotation, So2)
  80. # TODO change to KORNIA_CHECK_SHAPE once there is multiple shape support
  81. # KORNIA_CHECK_TYPE(translation, (Vector3, Tensor))
  82. if not isinstance(translation, (Vector2, Tensor)):
  83. raise TypeError(f"translation type is {type(translation)}")
  84. self._translation: Vector2 | Parameter
  85. self._rotation: So2 = rotation
  86. if isinstance(translation, Tensor):
  87. _check_se2_r_t_shape(rotation, translation) # TODO remove
  88. self._translation = Parameter(translation)
  89. else:
  90. self._translation = translation
  91. def __repr__(self) -> str:
  92. return f"rotation: {self.r}\ntranslation: {self.t}"
  93. def __getitem__(self, idx: int | slice) -> Se2:
  94. return Se2(self._rotation[idx], self._translation[idx])
  95. def _mul_se2(self, right: Se2) -> Se2:
  96. so2 = self.so2
  97. t = self.t
  98. _r = so2 * right.so2
  99. _t = t + so2 * right.t
  100. return Se2(_r, _t)
  101. @overload
  102. def __mul__(self, right: Se2) -> Se2: ...
  103. @overload
  104. def __mul__(self, right: Tensor) -> Tensor: ...
  105. def __mul__(self, right: Se2 | Tensor) -> Se2 | Tensor:
  106. """Compose two Se2 transformations.
  107. Args:
  108. right: the other Se2 transformation.
  109. Return:
  110. The resulting Se2 transformation.
  111. """
  112. so2 = self.so2
  113. t = self.t
  114. if isinstance(right, Se2):
  115. KORNIA_CHECK_TYPE(right, Se2)
  116. return self._mul_se2(right)
  117. elif isinstance(right, (Vector2, Tensor)):
  118. # TODO change to KORNIA_CHECK_SHAPE once there is multiple shape support
  119. # _check_se2_r_t_shape(so2, risght)
  120. return so2 * right + t
  121. else:
  122. raise TypeError(f"Unsupported type: {type(right)}")
  123. @property
  124. def so2(self) -> So2:
  125. """Return the underlying `rotation(So2)`."""
  126. return self._rotation
  127. @property
  128. def r(self) -> So2:
  129. """Return the underlying `rotation(So2)`."""
  130. return self._rotation
  131. @property
  132. def t(self) -> Vector2 | Parameter:
  133. """Return the underlying translation vector of shape :math:`(B,2)`."""
  134. return self._translation
  135. @property
  136. def rotation(self) -> So2:
  137. """Return the underlying `rotation(So2)`."""
  138. return self._rotation
  139. @property
  140. def translation(self) -> Vector2 | Parameter:
  141. """Return the underlying translation vector of shape :math:`(B,2)`."""
  142. return self._translation
  143. @staticmethod
  144. def exp(v: Tensor) -> Se2:
  145. """Convert elements of lie algebra to elements of lie group.
  146. Args:
  147. v: vector of shape :math:`(B, 3)`.
  148. Example:
  149. >>> v = torch.ones((1, 3))
  150. >>> s = Se2.exp(v)
  151. >>> s.r
  152. Parameter containing:
  153. tensor([0.5403+0.8415j], requires_grad=True)
  154. >>> s.t
  155. Parameter containing:
  156. tensor([[0.3818, 1.3012]], requires_grad=True)
  157. """
  158. check_v_shape(v)
  159. theta = v[..., 2]
  160. so2 = So2.exp(theta)
  161. z = tensor(0.0, device=v.device, dtype=v.dtype)
  162. theta_nonzeros = theta != 0.0
  163. a = where(theta_nonzeros, so2.z.imag / theta, z)
  164. b = where(theta_nonzeros, (1.0 - so2.z.real) / theta, z)
  165. x = v[..., 0]
  166. y = v[..., 1]
  167. t = stack((a * x - b * y, b * x + a * y), -1)
  168. return Se2(so2, t)
  169. def log(self) -> Tensor:
  170. """Convert elements of lie group to elements of lie algebra.
  171. Example:
  172. >>> v = torch.ones((1, 3))
  173. >>> s = Se2.exp(v).log()
  174. >>> s
  175. tensor([[1.0000, 1.0000, 1.0000]], grad_fn=<StackBackward0>)
  176. """
  177. theta = self.so2.log()
  178. half_theta = 0.5 * theta
  179. denom = self.so2.z.real - 1
  180. a = where(
  181. denom != 0, -(half_theta * self.so2.z.imag) / denom, tensor(0.0, device=theta.device, dtype=theta.dtype)
  182. )
  183. row0 = stack((a, half_theta), -1)
  184. row1 = stack((-half_theta, a), -1)
  185. V_inv = stack((row0, row1), -2)
  186. upsilon = V_inv @ self.t.data[..., None]
  187. return stack((upsilon[..., 0, 0], upsilon[..., 1, 0], theta), -1)
  188. @staticmethod
  189. def hat(v: Tensor) -> Tensor:
  190. """Convert elements from vector space to lie algebra. Returns matrix of shape :math:`(B, 3, 3)`.
  191. Args:
  192. v: vector of shape:math:`(B, 3)`.
  193. Example:
  194. >>> theta = torch.tensor(3.1415/2)
  195. >>> So2.hat(theta)
  196. tensor([[0.0000, 1.5707],
  197. [1.5707, 0.0000]])
  198. """
  199. # TODO change to KORNIA_CHECK_SHAPE once there is multiple shape support
  200. check_v_shape(v)
  201. upsilon = stack((v[..., 0], v[..., 1]), -1)
  202. theta = v[..., 2]
  203. col0 = concatenate((So2.hat(theta), upsilon.unsqueeze(-2)), -2)
  204. return pad(col0, (0, 1))
  205. @staticmethod
  206. def vee(omega: Tensor) -> Tensor:
  207. """Convert elements from lie algebra to vector space.
  208. Args:
  209. omega: 3x3-matrix representing lie algebra of shape :math:`(B, 3, 3)`.
  210. Returns:
  211. vector of shape :math:`(B, 3)`.
  212. Example:
  213. >>> v = torch.ones(3)
  214. >>> omega_hat = Se2.hat(v)
  215. >>> Se2.vee(omega_hat)
  216. tensor([1., 1., 1.])
  217. """
  218. # TODO change to KORNIA_CHECK_SHAPE once there is multiple shape support
  219. check_se2_omega_shape(omega)
  220. upsilon = omega[..., 2, :2]
  221. theta = So2.vee(omega[..., :2, :2])
  222. return concatenate((upsilon, theta[..., None]), -1)
  223. @classmethod
  224. def identity(cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Dtype = None) -> Se2:
  225. """Create a Se2 group representing an identity rotation and zero translation.
  226. Args:
  227. batch_size: the batch size of the underlying data.
  228. device: device to place the result on.
  229. dtype: dtype of the result.
  230. Example:
  231. >>> s = Se2.identity(1)
  232. >>> s.r
  233. Parameter containing:
  234. tensor([1.+0.j], requires_grad=True)
  235. >>> s.t
  236. x: tensor([0.])
  237. y: tensor([0.])
  238. """
  239. t: Tensor = tensor([0.0, 0.0], device=device, dtype=dtype)
  240. if batch_size is not None:
  241. KORNIA_CHECK(batch_size >= 1, msg="batch_size must be positive")
  242. t = t.repeat(batch_size, 1)
  243. return cls(So2.identity(batch_size, device, dtype), Vector2(t))
  244. def matrix(self) -> Tensor:
  245. """Return the matrix representation of shape :math:`(B, 3, 3)`.
  246. Example:
  247. >>> s = Se2(So2.identity(1), torch.ones(1, 2))
  248. >>> s.matrix()
  249. tensor([[[1., -0., 1.],
  250. [0., 1., 1.],
  251. [0., 0., 1.]]], grad_fn=<CopySlices>)
  252. """
  253. rt = concatenate((self.r.matrix(), self.t.data[..., None]), -1)
  254. rt_3x3 = pad(rt, (0, 0, 0, 1)) # add last row zeros
  255. rt_3x3[..., -1, -1] = 1.0
  256. return rt_3x3
  257. @classmethod
  258. def from_matrix(cls, matrix: Tensor) -> Se2:
  259. """Create an Se2 group from a matrix.
  260. Args:
  261. matrix: tensor of shape :math:`(B, 3, 3)`.
  262. Example:
  263. >>> s = Se2.from_matrix(torch.eye(3).repeat(2, 1, 1))
  264. >>> s.r
  265. Parameter containing:
  266. tensor([1.+0.j, 1.+0.j], requires_grad=True)
  267. >>> s.t
  268. Parameter containing:
  269. tensor([[0., 0.],
  270. [0., 0.]], requires_grad=True)
  271. """
  272. # KORNIA_CHECK_SHAPE(matrix, ["B", "3", "3"]) # FIXME: resolve shape bugs. @edgarriba
  273. r = So2.from_matrix(matrix[..., :2, :2])
  274. t = matrix[..., :2, -1]
  275. return cls(r, t)
  276. def inverse(self) -> Se2:
  277. """Return the inverse transformation.
  278. Example:
  279. >>> s = Se2(So2.identity(1), torch.ones(1,2))
  280. >>> s_inv = s.inverse()
  281. >>> s_inv.r
  282. Parameter containing:
  283. tensor([1.+0.j], requires_grad=True)
  284. >>> s_inv.t
  285. Parameter containing:
  286. tensor([[-1., -1.]], requires_grad=True)
  287. """
  288. r_inv: So2 = self.r.inverse()
  289. _t = -1 * self.t
  290. if isinstance(_t, int):
  291. raise TypeError("Unexpected integer from `-1 * translation`")
  292. return Se2(r_inv, r_inv * _t)
  293. @classmethod
  294. def random(cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Dtype = None) -> Se2:
  295. """Create a Se2 group representing a random transformation.
  296. Args:
  297. batch_size: the batch size of the underlying data.
  298. device: device to place the result on.
  299. dtype: dtype of the result.
  300. Example:
  301. >>> s = Se2.random()
  302. >>> s = Se2.random(batch_size=3)
  303. """
  304. r = So2.random(batch_size, device, dtype)
  305. shape: tuple[int, ...]
  306. if batch_size is None:
  307. shape = (2,)
  308. else:
  309. KORNIA_CHECK(batch_size >= 1, msg="batch_size must be positive")
  310. shape = (batch_size, 2)
  311. return cls(r, Vector2(rand(shape, device=device, dtype=dtype)))
  312. @classmethod
  313. def trans(cls, x: Tensor, y: Tensor) -> Se2:
  314. """Construct a translation only Se2 instance.
  315. Args:
  316. x: the x-axis translation.
  317. y: the y-axis translation.
  318. """
  319. KORNIA_CHECK(x.shape == y.shape)
  320. KORNIA_CHECK_SAME_DEVICES([x, y])
  321. batch_size = x.shape[0] if len(x.shape) > 0 else None
  322. rotation = So2.identity(batch_size, x.device, x.dtype)
  323. return cls(rotation, stack((x, y), -1))
  324. @classmethod
  325. def trans_x(cls, x: Tensor) -> Se2:
  326. """Construct a x-axis translation.
  327. Args:
  328. x: the x-axis translation.
  329. """
  330. zs = zeros_like(x)
  331. return cls.trans(x, zs)
  332. @classmethod
  333. def trans_y(cls, y: Tensor) -> Se2:
  334. """Construct a y-axis translation.
  335. Args:
  336. y: the y-axis translation.
  337. """
  338. zs = zeros_like(y)
  339. return cls.trans(zs, y)
  340. def adjoint(self) -> Tensor:
  341. """Return the adjoint matrix of shape :math:`(B, 3, 3)`.
  342. Example:
  343. >>> s = Se2.identity()
  344. >>> s.adjoint()
  345. tensor([[1., -0., 0.],
  346. [0., 1., -0.],
  347. [0., 0., 1.]], grad_fn=<CopySlices>)
  348. """
  349. rt = self.matrix()
  350. rt[..., 0:2, 2] = stack((self.t.data[..., 1], -self.t.data[..., 0]), -1)
  351. return rt