quaternion.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. # kornia.geometry.quaternion module inspired by Eigen, Sophus-sympy, and PyQuaternion.
  18. # https://github.com/strasdat/Sophus/blob/master/sympy/sophus/quaternion.py
  19. # https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py
  20. # https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/Quaternion.h
  21. from math import pi
  22. from typing import Any, Optional, Tuple, Union
  23. import torch
  24. from kornia.core import Device, Dtype, Module, Parameter, Tensor, concatenate, rand, stack, tensor, where
  25. from kornia.core.check import KORNIA_CHECK_TYPE
  26. from kornia.geometry.conversions import (
  27. axis_angle_to_quaternion,
  28. euler_from_quaternion,
  29. normalize_quaternion,
  30. quaternion_from_euler,
  31. quaternion_to_axis_angle,
  32. quaternion_to_rotation_matrix,
  33. rotation_matrix_to_quaternion,
  34. )
  35. from kornia.geometry.linalg import batched_dot_product
  36. class Quaternion(Module):
  37. r"""Base class to represent a Quaternion.
  38. A quaternion is a four dimensional vector representation of a rotation transformation in 3d.
  39. See more: https://en.wikipedia.org/wiki/Quaternion
  40. The general definition of a quaternion is given by:
  41. .. math::
  42. Q = a + b \cdot \mathbf{i} + c \cdot \mathbf{j} + d \cdot \mathbf{k}
  43. Thus, we represent a rotation quaternion as a contiguous tensor structure to
  44. perform rigid bodies transformations:
  45. .. math::
  46. Q = \begin{bmatrix} q_w & q_x & q_y & q_z \end{bmatrix}
  47. Example:
  48. >>> q = Quaternion.identity(batch_size=4)
  49. >>> q.data
  50. tensor([[1., 0., 0., 0.],
  51. [1., 0., 0., 0.],
  52. [1., 0., 0., 0.],
  53. [1., 0., 0., 0.]])
  54. >>> q.real
  55. tensor([1., 1., 1., 1.])
  56. >>> q.vec
  57. tensor([[0., 0., 0.],
  58. [0., 0., 0.],
  59. [0., 0., 0.],
  60. [0., 0., 0.]])
  61. """
  62. _data: Union[Tensor, Parameter]
  63. def __init__(self, data: Union[Tensor, Parameter]) -> None:
  64. """Construct a quaternion from tensor or parameter data.
  65. Args:
  66. data: tensor or parameter containing the quaternion data with the shape of :math:`(B, 4)`.
  67. Example:
  68. >>> # Create with tensor (no gradients tracked by default)
  69. >>> data = torch.tensor([1., 0., 0., 0.])
  70. >>> q1 = Quaternion(data)
  71. >>> # Create with parameter (gradients tracked)
  72. >>> param_data = torch.nn.Parameter(torch.tensor([1., 0., 0., 0.]))
  73. >>> q2 = Quaternion(param_data)
  74. """
  75. super().__init__()
  76. if not isinstance(data, (Tensor, Parameter)):
  77. raise TypeError(f"Expected Tensor or Parameter, got {type(data)}")
  78. # KORNIA_CHECK_SHAPE(data, ["B", "4"]) # FIXME: resolve shape bugs. @edgarriba
  79. self._data = data
  80. def to(self, *args: Any, **kwargs: Any) -> "Quaternion":
  81. """Move and/or cast the quaternion data.
  82. Args:
  83. *args: Arguments to pass to tensor.to()
  84. **kwargs: Keyword arguments to pass to tensor.to()
  85. Returns:
  86. A new Quaternion with converted data.
  87. """
  88. return Quaternion(self._data.to(*args, **kwargs))
  89. def _to_scalar_quaternion(self, value: Union[Tensor, float]) -> "Quaternion":
  90. """Convert a scalar, tensor, or numeric value to a scalar quaternion.
  91. A scalar quaternion has the form [real, 0, 0, 0] where real is the input value.
  92. Args:
  93. value: The scalar, tensor, or numeric value to convert.
  94. Returns:
  95. A Quaternion object representing the scalar quaternion.
  96. """
  97. if isinstance(value, (int, float)):
  98. value = torch.tensor(value, device=self.data.device, dtype=self.data.dtype)
  99. elif isinstance(value, torch.Tensor):
  100. value = value.to(device=self.data.device, dtype=self.data.dtype)
  101. # Broadcast value to match the shape of self.real
  102. try:
  103. target_shape = torch.broadcast_shapes(self.real.shape, value.shape)
  104. except RuntimeError as e:
  105. raise ValueError(f"Cannot broadcast shapes {self.real.shape} and {value.shape}") from e
  106. broadcasted = self.real.expand(target_shape) + value.expand(target_shape)
  107. # Create scalar quaternion: [value, 0, 0, 0]
  108. # Expand value to match the broadcasted shape, then add quaternion dimension
  109. if value.dim() == 0: # scalar
  110. # Expand to match the broadcasted shape
  111. expanded_value = value.expand_as(broadcasted)
  112. else:
  113. # Use broadcasting to get the right shape
  114. expanded_value = torch.broadcast_to(value, broadcasted.shape)
  115. # Create zeros for the imaginary part
  116. zeros = torch.zeros_like(expanded_value).unsqueeze(-1).expand(*expanded_value.shape, 3)
  117. # Stack real and imaginary parts: [real, 0, 0, 0]
  118. scalar_quat_data = torch.cat([expanded_value.unsqueeze(-1), zeros], dim=-1)
  119. return Quaternion(scalar_quat_data)
  120. def __repr__(self) -> str:
  121. return f"{self.data}"
  122. def __getitem__(self, idx: Union[int, slice]) -> "Quaternion":
  123. return Quaternion(self.data[idx])
  124. def __neg__(self) -> "Quaternion":
  125. """Inverts the sign of the quaternion data.
  126. Example:
  127. >>> q = Quaternion.identity()
  128. >>> -q.data
  129. tensor([-1., -0., -0., -0.])
  130. """
  131. return Quaternion(-self.data)
  132. def __add__(self, right: Union["Quaternion", Tensor, float]) -> "Quaternion":
  133. """Add a given quaternion, scalar, or tensor.
  134. Args:
  135. right: the quaternion, scalar, or tensor to add.
  136. Example:
  137. >>> q1 = Quaternion.identity()
  138. >>> q2 = Quaternion(tensor([2., 0., 1., 1.]))
  139. >>> q3 = q1 + q2
  140. >>> q3.data
  141. tensor([3., 0., 1., 1.])
  142. """
  143. if isinstance(right, Quaternion):
  144. return Quaternion(self.data + right.data)
  145. else:
  146. right_quat = self._to_scalar_quaternion(right)
  147. return Quaternion(self.data + right_quat.data)
  148. def __sub__(self, right: Union["Quaternion", Tensor, float]) -> "Quaternion":
  149. """Subtract a given quaternion, scalar, or tensor.
  150. Args:
  151. right: the quaternion, scalar, or tensor to subtract.
  152. Example:
  153. >>> q1 = Quaternion(tensor([2., 0., 1., 1.]))
  154. >>> q2 = Quaternion.identity()
  155. >>> q3 = q1 - q2
  156. >>> q3.data
  157. tensor([1., 0., 1., 1.])
  158. """
  159. if isinstance(right, Quaternion):
  160. return Quaternion(self.data - right.data)
  161. else:
  162. right_quat = self._to_scalar_quaternion(right)
  163. # For scalar operations, ensure we return a tensor to preserve gradients
  164. result_data = self.data - right_quat.data
  165. if isinstance(result_data, Parameter):
  166. result_data = result_data.data # Convert to tensor to preserve gradients
  167. return Quaternion(result_data)
  168. def __mul__(self, right: Union["Quaternion", Tensor, float]) -> "Quaternion":
  169. # If right is a Quaternion, do quaternion multiplication
  170. if isinstance(right, Quaternion):
  171. new_real = self.real * right.real - batched_dot_product(self.vec, right.vec)
  172. new_vec = (
  173. self.real[..., None] * right.vec
  174. + right.real[..., None] * self.vec
  175. + torch.linalg.cross(self.vec, right.vec, dim=-1)
  176. )
  177. return Quaternion(concatenate((new_real[..., None], new_vec), -1))
  178. # If right is a scalar/tensor, convert to scalar quaternion and multiply
  179. else:
  180. right_quat = self._to_scalar_quaternion(right)
  181. new_real = self.real * right_quat.real - batched_dot_product(self.vec, right_quat.vec)
  182. new_vec = (
  183. self.real[..., None] * right_quat.vec
  184. + right_quat.real[..., None] * self.vec
  185. + torch.linalg.cross(self.vec, right_quat.vec, dim=-1)
  186. )
  187. return Quaternion(concatenate((new_real[..., None], new_vec), -1))
  188. def __rmul__(self, left: Union[Tensor, float]) -> "Quaternion":
  189. """Right multiplication (left * self) where left is a scalar or tensor."""
  190. left_quat = self._to_scalar_quaternion(left)
  191. new_real = left_quat.real * self.real - batched_dot_product(left_quat.vec, self.vec)
  192. new_vec = (
  193. left_quat.real[..., None] * self.vec
  194. + self.real[..., None] * left_quat.vec
  195. + torch.linalg.cross(left_quat.vec, self.vec, dim=-1)
  196. )
  197. return Quaternion(concatenate((new_real[..., None], new_vec), -1))
  198. def __div__(self, right: Union[Tensor, "Quaternion", float]) -> "Quaternion":
  199. if isinstance(right, Quaternion):
  200. return self * right.inv()
  201. else:
  202. # For scalars/tensors, just divide the quaternion data directly
  203. if isinstance(right, (int, float)):
  204. right_tensor = torch.tensor(right, device=self.data.device, dtype=self.data.dtype)
  205. else:
  206. right_tensor = right.to(device=self.data.device, dtype=self.data.dtype)
  207. # For division by scalar, expand to [right, right, right, right] for element-wise division
  208. if right_tensor.dim() == 0: # scalar
  209. divisor = right_tensor.expand_as(self.data[..., 0]).unsqueeze(-1).expand_as(self.data)
  210. else:
  211. # Broadcast the tensor to match the quaternion dimensions
  212. divisor = right_tensor.unsqueeze(-1).expand_as(self.data)
  213. # For scalar operations, ensure we return a tensor to preserve gradients
  214. result_data = self.data / divisor
  215. if isinstance(result_data, Parameter):
  216. result_data = result_data.data # Convert to tensor to preserve gradients
  217. return Quaternion(result_data)
  218. def __truediv__(self, right: Union[Tensor, "Quaternion", float]) -> "Quaternion":
  219. return self.__div__(right)
  220. def __radd__(self, left: Union[Tensor, float]) -> "Quaternion":
  221. """Right addition (left + self) where left is a scalar or tensor."""
  222. left_quat = self._to_scalar_quaternion(left)
  223. return left_quat + self
  224. def __rsub__(self, left: Union[Tensor, float]) -> "Quaternion":
  225. """Right subtraction (left - self) where left is a scalar or tensor."""
  226. left_quat = self._to_scalar_quaternion(left)
  227. return left_quat - self
  228. def __rtruediv__(self, left: Union[Tensor, float]) -> "Quaternion":
  229. """Right division (left / self) where left is a scalar or tensor."""
  230. left_quat = self._to_scalar_quaternion(left)
  231. return left_quat / self
  232. def __rdiv__(self, left: Union[Tensor, float]) -> "Quaternion":
  233. """Right division (left / self) where left is a scalar or tensor."""
  234. return self.__rtruediv__(left)
  235. def __pow__(self, t: float) -> "Quaternion":
  236. """Return the power of a quaternion raised to exponent t.
  237. Args:
  238. t: raised exponent.
  239. Example:
  240. >>> q = Quaternion(tensor([1., .5, 0., 0.]))
  241. >>> q_pow = q**2
  242. """
  243. theta = self.polar_angle[..., None]
  244. vec_norm = self.vec.norm(dim=-1, keepdim=True)
  245. n = where(vec_norm != 0, self.vec / vec_norm, self.vec * 0)
  246. w = (t * theta).cos()
  247. xyz = (t * theta).sin() * n
  248. return Quaternion(concatenate((w, xyz), -1))
  249. @property
  250. def data(self) -> Tensor:
  251. """Return the underlying data with shape :math:`(B, 4)`."""
  252. return self._data
  253. @property
  254. def coeffs(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
  255. """Return a tuple with the underlying coefficients in WXYZ order."""
  256. return self.w, self.x, self.y, self.z
  257. @property
  258. def real(self) -> Tensor:
  259. """Return the real part with shape :math:`(B,)`.
  260. Alias for
  261. :func: `~kornia.geometry.quaternion.Quaternion.w`
  262. """
  263. return self.w
  264. @property
  265. def vec(self) -> Tensor:
  266. """Return the vector with the imaginary part with shape :math:`(B, 3)`."""
  267. return self.data[..., 1:]
  268. @property
  269. def q(self) -> Tensor:
  270. """Return the underlying data with shape :math:`(B, 4)`.
  271. Alias for :func:`~kornia.geometry.quaternion.Quaternion.data`
  272. """
  273. return self.data
  274. @property
  275. def scalar(self) -> Tensor:
  276. """Return a scalar with the real with shape :math:`(B,)`.
  277. Alias for
  278. :func: `~kornia.geometry.quaternion.Quaternion.w`
  279. """
  280. return self.real
  281. @property
  282. def w(self) -> Tensor:
  283. """Return the :math:`q_w` with shape :math:`(B,)`."""
  284. return self.data[..., 0]
  285. @property
  286. def x(self) -> Tensor:
  287. """Return the :math:`q_x` with shape :math:`(B,)`."""
  288. return self.data[..., 1]
  289. @property
  290. def y(self) -> Tensor:
  291. """Return the :math:`q_y` with shape :math:`(B,)`."""
  292. return self.data[..., 2]
  293. @property
  294. def z(self) -> Tensor:
  295. """Return the :math:`q_z` with shape :math:`(B,)`."""
  296. return self.data[..., 3]
  297. @property
  298. def shape(self) -> Tuple[int, ...]:
  299. """Return the shape of the underlying data with shape :math:`(B, 4)`."""
  300. return tuple(self.data.shape)
  301. @property
  302. def polar_angle(self) -> Tensor:
  303. """Return the polar angle with shape :math:`(B,1)`.
  304. Example:
  305. >>> q = Quaternion.identity()
  306. >>> q.polar_angle
  307. tensor(0.)
  308. """
  309. return (self.scalar / self.norm()).acos()
  310. def matrix(self) -> Tensor:
  311. """Convert the quaternion to a rotation matrix of shape :math:`(B, 3, 3)`.
  312. Example:
  313. >>> q = Quaternion.identity()
  314. >>> m = q.matrix()
  315. >>> m
  316. tensor([[1., 0., 0.],
  317. [0., 1., 0.],
  318. [0., 0., 1.]])
  319. """
  320. return quaternion_to_rotation_matrix(self.data)
  321. @classmethod
  322. def from_matrix(cls, matrix: Tensor) -> "Quaternion":
  323. """Create a quaternion from a rotation matrix.
  324. Args:
  325. matrix: the rotation matrix to convert of shape :math:`(B, 3, 3)`.
  326. Example:
  327. >>> m = torch.eye(3)[None]
  328. >>> q = Quaternion.from_matrix(m)
  329. >>> q.data
  330. tensor([[1., 0., 0., 0.]])
  331. """
  332. return cls(rotation_matrix_to_quaternion(matrix))
  333. @classmethod
  334. def from_euler(cls, roll: Tensor, pitch: Tensor, yaw: Tensor) -> "Quaternion":
  335. """Create a quaternion from Euler angles.
  336. Args:
  337. roll: the roll euler angle.
  338. pitch: the pitch euler angle.
  339. yaw: the yaw euler angle.
  340. Example:
  341. >>> roll, pitch, yaw = tensor(0), tensor(1), tensor(0)
  342. >>> q = Quaternion.from_euler(roll, pitch, yaw)
  343. >>> q.data
  344. tensor([0.8776, 0.0000, 0.4794, 0.0000])
  345. """
  346. w, x, y, z = quaternion_from_euler(roll=roll, pitch=pitch, yaw=yaw)
  347. q = stack((w, x, y, z), -1)
  348. return cls(q)
  349. def to_euler(self) -> Tuple[Tensor, Tensor, Tensor]:
  350. """Convert the quaternion to a triple of Euler angles (roll, pitch, yaw).
  351. Example:
  352. >>> q = Quaternion(tensor([2., 0., 1., 1.]))
  353. >>> roll, pitch, yaw = q.to_euler()
  354. >>> roll
  355. tensor(2.0344)
  356. >>> pitch
  357. tensor(1.5708)
  358. >>> yaw
  359. tensor(2.2143)
  360. """
  361. return euler_from_quaternion(self.w, self.x, self.y, self.z)
  362. @classmethod
  363. def from_axis_angle(cls, axis_angle: Tensor) -> "Quaternion":
  364. """Create a quaternion from axis-angle representation.
  365. Args:
  366. axis_angle: rotation vector of shape :math:`(B, 3)`.
  367. Example:
  368. >>> axis_angle = torch.tensor([[1., 0., 0.]])
  369. >>> q = Quaternion.from_axis_angle(axis_angle)
  370. >>> q.data
  371. tensor([[0.8776, 0.4794, 0.0000, 0.0000]])
  372. """
  373. return cls(axis_angle_to_quaternion(axis_angle))
  374. def to_axis_angle(self) -> Tensor:
  375. """Convert the quaternion to an axis-angle representation.
  376. Example:
  377. >>> q = Quaternion.identity()
  378. >>> axis_angle = q.to_axis_angle()
  379. >>> axis_angle
  380. tensor([0., 0., 0.])
  381. """
  382. return quaternion_to_axis_angle(self.data)
  383. @classmethod
  384. def identity(
  385. cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Dtype = None
  386. ) -> "Quaternion":
  387. """Create a quaternion representing an identity rotation.
  388. Args:
  389. batch_size: the batch size of the underlying data.
  390. device: device to place the result on.
  391. dtype: dtype of the result.
  392. Example:
  393. >>> q = Quaternion.identity()
  394. >>> q.data
  395. tensor([1., 0., 0., 0.])
  396. """
  397. data = tensor([1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype)
  398. if batch_size is not None:
  399. data = data.repeat(batch_size, 1)
  400. return cls(data)
  401. @classmethod
  402. def from_coeffs(cls, w: float, x: float, y: float, z: float) -> "Quaternion":
  403. """Create a quaternion from the data coefficients.
  404. Args:
  405. w: a float representing the :math:`q_w` component.
  406. x: a float representing the :math:`q_x` component.
  407. y: a float representing the :math:`q_y` component.
  408. z: a float representing the :math:`q_z` component.
  409. Example:
  410. >>> q = Quaternion.from_coeffs(1., 0., 0., 0.)
  411. >>> q.data
  412. tensor([1., 0., 0., 0.])
  413. """
  414. return cls(tensor([w, x, y, z]))
  415. # TODO: update signature
  416. # def random(cls, shape: Optional[List] = None, device = None, dtype = None) -> 'Quaternion':
  417. @classmethod
  418. def random(
  419. cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Dtype = None
  420. ) -> "Quaternion":
  421. """Create a random unit quaternion of shape :math:`(B, 4)`.
  422. Uniformly distributed across the rotation space as per: http://planning.cs.uiuc.edu/node198.html
  423. Args:
  424. batch_size: the batch size of the underlying data.
  425. device: device to place the result on.
  426. dtype: dtype of the result.
  427. Example:
  428. >>> q = Quaternion.random()
  429. >>> q = Quaternion.random(batch_size=2)
  430. """
  431. rand_shape = (batch_size,) if batch_size is not None else ()
  432. r1, r2, r3 = rand((3, *rand_shape), device=device, dtype=dtype)
  433. q1 = (1.0 - r1).sqrt() * ((2 * pi * r2).sin())
  434. q2 = (1.0 - r1).sqrt() * ((2 * pi * r2).cos())
  435. q3 = r1.sqrt() * (2 * pi * r3).sin()
  436. q4 = r1.sqrt() * (2 * pi * r3).cos()
  437. return cls(stack((q1, q2, q3, q4), -1))
  438. def slerp(self, q1: "Quaternion", t: float) -> "Quaternion":
  439. """Return a unit quaternion spherically interpolated between quaternions self.q and q1.
  440. See more: https://en.wikipedia.org/wiki/Slerp
  441. Args:
  442. q1: second quaternion to be interpolated between.
  443. t: interpolation ratio, range [0-1]
  444. Example:
  445. >>> q0 = Quaternion.identity()
  446. >>> q1 = Quaternion(torch.tensor([1., .5, 0., 0.]))
  447. >>> q2 = q0.slerp(q1, .3)
  448. """
  449. KORNIA_CHECK_TYPE(q1, Quaternion)
  450. q0 = self.normalize()
  451. q1 = q1.normalize()
  452. return q0 * (q0.inv() * q1) ** t
  453. def norm(self, keepdim: bool = False) -> Tensor:
  454. """Compute the norm (magnitude) of the quaternion.
  455. Args:
  456. keepdim: whether to retain the last dimension.
  457. Returns:
  458. The norm of the quaternion(s) as a tensor.
  459. Example:
  460. >>> q = Quaternion.identity()
  461. >>> q.norm()
  462. tensor(1.)
  463. """
  464. # p==2, dim|axis==-1, keepdim
  465. return self.data.norm(2, -1, keepdim)
  466. def normalize(self) -> "Quaternion":
  467. """Return a normalized (unit) quaternion.
  468. Returns:
  469. The normalized quaternion.
  470. Example:
  471. >>> q = Quaternion(tensor([2., 1., 0., 0.]))
  472. >>> q_norm = q.normalize()
  473. """
  474. return Quaternion(normalize_quaternion(self.data))
  475. def conj(self) -> "Quaternion":
  476. """Compute the conjugate of the quaternion.
  477. Returns:
  478. The conjugate quaternion, with the vector part negated.
  479. Example:
  480. >>> q = Quaternion(tensor([1., 2., 3., 4.]))
  481. >>> q_conj = q.conj()
  482. """
  483. return Quaternion(concatenate((self.real[..., None], -self.vec), -1))
  484. def inv(self) -> "Quaternion":
  485. """Compute the inverse of the quaternion.
  486. Returns:
  487. The inverse quaternion.
  488. Example:
  489. >>> q = Quaternion.identity()
  490. >>> q_inv = q.inv()
  491. """
  492. return self.conj() / self.squared_norm()
  493. def squared_norm(self) -> Tensor:
  494. """Compute the squared norm (magnitude) of the quaternion.
  495. Returns:
  496. The squared norm of the quaternion(s) as a tensor.
  497. Example:
  498. >>> q = Quaternion.identity()
  499. >>> q.squared_norm()
  500. tensor(1.)
  501. """
  502. return batched_dot_product(self.vec, self.vec) + self.real**2
  503. def average_quaternions(Q: "Quaternion", w: Optional[torch.Tensor] = None) -> "Quaternion":
  504. """Compute (weighted) average of multiple quaternions.
  505. Args:
  506. Q (Quaternion): quaternion object containing data of shape (M, 4).
  507. w (torch.Tensor, optional): Weights of shape (M,). If None, uniform weights are used.
  508. Returns:
  509. Quaternion: averaged quaternion (shape (4,)), wrapped back in the Quaternion class.
  510. """
  511. data = Q.data
  512. KORNIA_CHECK_TYPE(Q, Quaternion)
  513. M = data.shape[0]
  514. if w is None:
  515. A = (data.T @ data) / M
  516. else:
  517. w = w.to(data.device, dtype=data.dtype)
  518. if w.numel() != M:
  519. raise ValueError(f"weights length {w.numel()} must match number of quaternions {M}")
  520. w = w / w.sum()
  521. A = data.T @ torch.diag(w) @ data
  522. eigenvalues, eigenvectors = torch.linalg.eigh(A)
  523. q_avg = eigenvectors[:, torch.argmax(eigenvalues)]
  524. q_avg = q_avg / q_avg.norm()
  525. return Quaternion(q_avg.unsqueeze(0))