so2.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. # kornia.geometry.so2 module inspired by Sophus-sympy.
  18. # https://github.com/strasdat/Sophus/blob/master/sympy/sophus/so2.py
  19. from __future__ import annotations
  20. from typing import Optional, overload
  21. from kornia.core import Device, Dtype, Module, Parameter, Tensor, complex, rand, stack, tensor, zeros_like
  22. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR
  23. from kornia.geometry.liegroup._utils import (
  24. check_so2_matrix,
  25. check_so2_matrix_shape,
  26. check_so2_t_shape,
  27. check_so2_theta_shape,
  28. check_so2_z_shape,
  29. )
  30. from kornia.geometry.vector import Vector2
  31. class So2(Module):
  32. r"""Base class to represent the So2 group.
  33. The SO(2) is the group of all rotations about the origin of two-dimensional Euclidean space
  34. :math:`R^2` under the operation of composition.
  35. See more: https://en.wikipedia.org/wiki/Orthogonal_group#Special_orthogonal_group
  36. We internally represent the rotation by a complex number.
  37. Example:
  38. >>> real = torch.tensor([1.0])
  39. >>> imag = torch.tensor([2.0])
  40. >>> So2(torch.complex(real, imag))
  41. Parameter containing:
  42. tensor([1.+2.j], requires_grad=True)
  43. """
  44. def __init__(self, z: Tensor) -> None:
  45. """Construct the base class.
  46. Internally represented by complex number `z`.
  47. Args:
  48. z: Complex number with the shape of :math:`(B, 1)` or :math:`(B)`.
  49. Example:
  50. >>> real = torch.tensor(1.0)
  51. >>> imag = torch.tensor(2.0)
  52. >>> So2(torch.complex(real, imag)).z
  53. Parameter containing:
  54. tensor(1.+2.j, requires_grad=True)
  55. """
  56. super().__init__()
  57. KORNIA_CHECK_IS_TENSOR(z)
  58. # TODO change to KORNIA_CHECK_SHAPE once there is multiple shape support
  59. check_so2_z_shape(z)
  60. self._z = Parameter(z)
  61. def __repr__(self) -> str:
  62. return f"{self.z}"
  63. def __getitem__(self, idx: int | slice) -> So2:
  64. return So2(self._z[idx])
  65. @overload
  66. def __mul__(self, right: So2) -> So2: ...
  67. @overload
  68. def __mul__(self, right: Tensor) -> Tensor: ...
  69. def __mul__(self, right: So2 | Tensor) -> So2 | Tensor:
  70. """Perform a left-multiplication either rotation concatenation or point-transform.
  71. Args:
  72. right: the other So2 transformation.
  73. Return:
  74. The resulting So2 transformation.
  75. """
  76. z = self.z
  77. if isinstance(right, So2):
  78. return So2(z * right.z)
  79. elif isinstance(right, (Vector2, Tensor)):
  80. # TODO change to KORNIA_CHECK_SHAPE once there is multiple shape support
  81. if isinstance(right, Tensor):
  82. check_so2_t_shape(right)
  83. x = right.data[..., 0]
  84. y = right.data[..., 1]
  85. real = z.real
  86. imag = z.imag
  87. out = stack((real * x - imag * y, imag * x + real * y), -1)
  88. if isinstance(right, Tensor):
  89. return out
  90. else:
  91. return Vector2(out)
  92. else:
  93. raise TypeError(f"Not So2 or Tensor type. Got: {type(right)}")
  94. @property
  95. def z(self) -> Tensor:
  96. """Return the underlying data with shape :math:`(B, 1)`."""
  97. return self._z
  98. @staticmethod
  99. def exp(theta: Tensor) -> So2:
  100. """Convert elements of lie algebra to elements of lie group.
  101. Args:
  102. theta: angle in radians of shape :math:`(B, 1)` or :math:`(B)`.
  103. Example:
  104. >>> v = torch.tensor([3.1415/2])
  105. >>> s = So2.exp(v)
  106. >>> s
  107. Parameter containing:
  108. tensor([4.6329e-05+1.j], requires_grad=True)
  109. """
  110. # TODO change to KORNIA_CHECK_SHAPE once there is multiple shape support
  111. check_so2_theta_shape(theta)
  112. return So2(complex(theta.cos(), theta.sin()))
  113. def log(self) -> Tensor:
  114. """Convert elements of lie group to elements of lie algebra.
  115. Example:
  116. >>> real = torch.tensor([1.0])
  117. >>> imag = torch.tensor([3.0])
  118. >>> So2(torch.complex(real, imag)).log()
  119. tensor([1.2490], grad_fn=<Atan2Backward0>)
  120. """
  121. return self.z.imag.atan2(self.z.real)
  122. @staticmethod
  123. def hat(theta: Tensor) -> Tensor:
  124. """Convert elements from vector space to lie algebra. Returns matrix of shape :math:`(B, 2, 2)`.
  125. Args:
  126. theta: angle in radians of shape :math:`(B)`.
  127. Example:
  128. >>> theta = torch.tensor(3.1415/2)
  129. >>> So2.hat(theta)
  130. tensor([[0.0000, 1.5707],
  131. [1.5707, 0.0000]])
  132. """
  133. # TODO change to KORNIA_CHECK_SHAPE once there is multiple shape support
  134. check_so2_theta_shape(theta)
  135. z = zeros_like(theta)
  136. row0 = stack((z, theta), -1)
  137. row1 = stack((theta, z), -1)
  138. return stack((row0, row1), -1)
  139. @staticmethod
  140. def vee(omega: Tensor) -> Tensor:
  141. """Convert elements from lie algebra to vector space. Returns vector of shape :math:`(B,)`.
  142. Args:
  143. omega: 2x2-matrix representing lie algebra.
  144. Example:
  145. >>> v = torch.ones(3)
  146. >>> omega = So2.hat(v)
  147. >>> So2.vee(omega)
  148. tensor([1., 1., 1.])
  149. """
  150. # TODO change to KORNIA_CHECK_SHAPE once there is multiple shape support
  151. check_so2_matrix_shape(omega)
  152. return omega[..., 0, 1]
  153. def matrix(self) -> Tensor:
  154. """Convert the complex number to a rotation matrix of shape :math:`(B, 2, 2)`.
  155. Example:
  156. >>> s = So2.identity()
  157. >>> m = s.matrix()
  158. >>> m
  159. tensor([[1., -0.],
  160. [0., 1.]], grad_fn=<StackBackward0>)
  161. """
  162. row0 = stack((self.z.real, -self.z.imag), -1)
  163. row1 = stack((self.z.imag, self.z.real), -1)
  164. return stack((row0, row1), -2)
  165. @classmethod
  166. def from_matrix(cls, matrix: Tensor) -> So2:
  167. """Create So2 from a rotation matrix.
  168. Args:
  169. matrix: the rotation matrix to convert of shape :math:`(B, 2, 2)`.
  170. Example:
  171. >>> m = torch.eye(2)
  172. >>> s = So2.from_matrix(m)
  173. >>> s.z
  174. Parameter containing:
  175. tensor(1.+0.j, requires_grad=True)
  176. """
  177. # TODO change to KORNIA_CHECK_SHAPE once there is multiple shape support
  178. check_so2_matrix_shape(matrix)
  179. check_so2_matrix(matrix)
  180. z = complex(matrix[..., 0, 0], matrix[..., 1, 0])
  181. return cls(z)
  182. @classmethod
  183. def identity(
  184. cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Optional[Dtype] = None
  185. ) -> So2:
  186. """Create a So2 group representing an identity rotation.
  187. Args:
  188. batch_size: the batch size of the underlying data.
  189. device: device to place the result on.
  190. dtype: dtype of the result.
  191. Example:
  192. >>> s = So2.identity(batch_size=2)
  193. >>> s
  194. Parameter containing:
  195. tensor([1.+0.j, 1.+0.j], requires_grad=True)
  196. """
  197. real_data = tensor(1.0, device=device, dtype=dtype)
  198. imag_data = tensor(0.0, device=device, dtype=dtype)
  199. if batch_size is not None:
  200. KORNIA_CHECK(batch_size >= 1, msg="batch_size must be positive")
  201. real_data = real_data.repeat(batch_size)
  202. imag_data = imag_data.repeat(batch_size)
  203. return cls(complex(real_data, imag_data))
  204. def inverse(self) -> So2:
  205. """Return the inverse transformation.
  206. Example:
  207. >>> s = So2.identity()
  208. >>> s.inverse().z
  209. Parameter containing:
  210. tensor(1.+0.j, requires_grad=True)
  211. """
  212. return So2(1 / self.z)
  213. @classmethod
  214. def random(
  215. cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Optional[Dtype] = None
  216. ) -> So2:
  217. """Create a So2 group representing a random rotation.
  218. Args:
  219. batch_size: the batch size of the underlying data.
  220. device: device to place the result on.
  221. dtype: dtype of the result.
  222. Example:
  223. >>> s = So2.random()
  224. >>> s = So2.random(batch_size=3)
  225. """
  226. if batch_size is not None:
  227. KORNIA_CHECK(batch_size >= 1, msg="batch_size must be positive")
  228. real_data = rand((batch_size,), device=device, dtype=dtype)
  229. imag_data = rand((batch_size,), device=device, dtype=dtype)
  230. else:
  231. real_data = rand((), device=device, dtype=dtype)
  232. imag_data = rand((), device=device, dtype=dtype)
  233. return cls(complex(real_data, imag_data))
  234. def adjoint(self) -> Tensor:
  235. """Return the adjoint matrix of shape :math:`(B, 2, 2)`.
  236. Example:
  237. >>> s = So2.identity()
  238. >>> s.adjoint()
  239. tensor([[1., -0.],
  240. [0., 1.]], grad_fn=<StackBackward0>)
  241. """
  242. batch_size = len(self.z) if len(self.z.shape) > 0 else None
  243. return self.identity(batch_size, self.z.device, self.z.real.dtype).matrix()