| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- # kornia.geometry.quaternion module inspired by Eigen, Sophus-sympy, and PyQuaternion.
- # https://github.com/strasdat/Sophus/blob/master/sympy/sophus/quaternion.py
- # https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py
- # https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/Quaternion.h
- from math import pi
- from typing import Any, Optional, Tuple, Union
- import torch
- from kornia.core import Device, Dtype, Module, Parameter, Tensor, concatenate, rand, stack, tensor, where
- from kornia.core.check import KORNIA_CHECK_TYPE
- from kornia.geometry.conversions import (
- axis_angle_to_quaternion,
- euler_from_quaternion,
- normalize_quaternion,
- quaternion_from_euler,
- quaternion_to_axis_angle,
- quaternion_to_rotation_matrix,
- rotation_matrix_to_quaternion,
- )
- from kornia.geometry.linalg import batched_dot_product
class Quaternion(Module):
    r"""Base class to represent a Quaternion.

    A quaternion is a four dimensional vector representation of a rotation transformation in 3d.
    See more: https://en.wikipedia.org/wiki/Quaternion

    The general definition of a quaternion is given by:

    .. math::

        Q = a + b \cdot \mathbf{i} + c \cdot \mathbf{j} + d \cdot \mathbf{k}

    Thus, we represent a rotation quaternion as a contiguous tensor structure to
    perform rigid bodies transformations:

    .. math::

        Q = \begin{bmatrix} q_w & q_x & q_y & q_z \end{bmatrix}

    The coefficients are stored in WXYZ order along the last dimension of the data.

    Example:
        >>> q = Quaternion.identity(batch_size=4)
        >>> q.data
        tensor([[1., 0., 0., 0.],
                [1., 0., 0., 0.],
                [1., 0., 0., 0.],
                [1., 0., 0., 0.]])
        >>> q.real
        tensor([1., 1., 1., 1.])
        >>> q.vec
        tensor([[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]])
    """

    _data: Union[Tensor, Parameter]

    def __init__(self, data: Union[Tensor, Parameter]) -> None:
        """Construct a quaternion from tensor or parameter data.

        Args:
            data: tensor or parameter containing the quaternion data with the shape of :math:`(B, 4)`.

        Raises:
            TypeError: if ``data`` is neither a tensor nor a parameter.

        Example:
            >>> # Create with tensor (no gradients tracked by default)
            >>> data = torch.tensor([1., 0., 0., 0.])
            >>> q1 = Quaternion(data)
            >>> # Create with parameter (gradients tracked)
            >>> param_data = torch.nn.Parameter(torch.tensor([1., 0., 0., 0.]))
            >>> q2 = Quaternion(param_data)
        """
        super().__init__()
        if not isinstance(data, (Tensor, Parameter)):
            raise TypeError(f"Expected Tensor or Parameter, got {type(data)}")
        # KORNIA_CHECK_SHAPE(data, ["B", "4"])  # FIXME: resolve shape bugs. @edgarriba
        self._data = data

    def to(self, *args: Any, **kwargs: Any) -> "Quaternion":
        """Move and/or cast the quaternion data.

        Args:
            *args: Arguments to pass to tensor.to()
            **kwargs: Keyword arguments to pass to tensor.to()

        Returns:
            A new Quaternion with converted data.
        """
        return Quaternion(self._data.to(*args, **kwargs))

    def _to_scalar_quaternion(self, value: Union[Tensor, float]) -> "Quaternion":
        """Convert a scalar, tensor, or numeric value to a scalar quaternion.

        A scalar quaternion has the form [real, 0, 0, 0] where real is the input value.
        The value is moved to the device/dtype of this quaternion and broadcast against
        the shape of ``self.real``.

        Args:
            value: The scalar, tensor, or numeric value to convert.

        Returns:
            A Quaternion object representing the scalar quaternion.

        Raises:
            TypeError: if ``value`` is neither a tensor nor a Python number.
            ValueError: if ``value`` cannot be broadcast against ``self.real``.
        """
        if isinstance(value, Tensor):
            value = value.to(device=self.data.device, dtype=self.data.dtype)
        elif isinstance(value, (int, float)):
            value = torch.tensor(value, device=self.data.device, dtype=self.data.dtype)
        else:
            raise TypeError(f"Expected Tensor, int or float, got {type(value)}")

        try:
            target_shape = torch.broadcast_shapes(self.real.shape, value.shape)
        except RuntimeError as e:
            raise ValueError(f"Cannot broadcast shapes {self.real.shape} and {value.shape}") from e

        # Real part gets a trailing axis so it can be concatenated with the imaginary zeros.
        real = value.expand(target_shape).unsqueeze(-1)
        imag = real.new_zeros((*target_shape, 3))
        return Quaternion(torch.cat((real, imag), dim=-1))

    @staticmethod
    def _hamilton_product(left: "Quaternion", right: "Quaternion") -> "Quaternion":
        """Return the Hamilton product ``left * right`` of two quaternions."""
        new_real = left.real * right.real - batched_dot_product(left.vec, right.vec)
        new_vec = (
            left.real[..., None] * right.vec
            + right.real[..., None] * left.vec
            + torch.linalg.cross(left.vec, right.vec, dim=-1)
        )
        return Quaternion(concatenate((new_real[..., None], new_vec), -1))

    def __repr__(self) -> str:
        return f"{self.data}"

    def __getitem__(self, idx: Union[int, slice]) -> "Quaternion":
        """Index the underlying data and wrap the result in a new Quaternion."""
        return Quaternion(self.data[idx])

    def __neg__(self) -> "Quaternion":
        """Inverts the sign of the quaternion data.

        Example:
            >>> q = Quaternion.identity()
            >>> (-q).data
            tensor([-1., -0., -0., -0.])
        """
        return Quaternion(-self.data)

    def __add__(self, right: Union["Quaternion", Tensor, float]) -> "Quaternion":
        """Add a given quaternion, scalar, or tensor.

        A scalar or tensor operand is treated as a scalar quaternion, so it only
        affects the real part.

        Args:
            right: the quaternion, scalar, or tensor to add.

        Example:
            >>> q1 = Quaternion.identity()
            >>> q2 = Quaternion(tensor([2., 0., 1., 1.]))
            >>> q3 = q1 + q2
            >>> q3.data
            tensor([3., 0., 1., 1.])
        """
        if not isinstance(right, Quaternion):
            right = self._to_scalar_quaternion(right)
        return Quaternion(self.data + right.data)

    def __sub__(self, right: Union["Quaternion", Tensor, float]) -> "Quaternion":
        """Subtract a given quaternion, scalar, or tensor.

        A scalar or tensor operand is treated as a scalar quaternion, so it only
        affects the real part.

        Args:
            right: the quaternion, scalar, or tensor to subtract.

        Example:
            >>> q1 = Quaternion(tensor([2., 0., 1., 1.]))
            >>> q2 = Quaternion.identity()
            >>> q3 = q1 - q2
            >>> q3.data
            tensor([1., 0., 1., 1.])
        """
        if not isinstance(right, Quaternion):
            right = self._to_scalar_quaternion(right)
        # The difference is a regular tensor that stays attached to the autograd graph.
        return Quaternion(self.data - right.data)

    def __mul__(self, right: Union["Quaternion", Tensor, float]) -> "Quaternion":
        """Multiply by a quaternion (Hamilton product), scalar, or tensor.

        A scalar or tensor operand is first converted to a scalar quaternion.

        Args:
            right: the quaternion, scalar, or tensor to multiply with.
        """
        if not isinstance(right, Quaternion):
            right = self._to_scalar_quaternion(right)
        return self._hamilton_product(self, right)

    def __rmul__(self, left: Union[Tensor, float]) -> "Quaternion":
        """Right multiplication (left * self) where left is a scalar or tensor."""
        return self._hamilton_product(self._to_scalar_quaternion(left), self)

    def __div__(self, right: Union[Tensor, "Quaternion", float]) -> "Quaternion":
        """Divide by a quaternion, scalar, or tensor.

        Division by a quaternion multiplies by its inverse. Division by a scalar or
        tensor divides every coefficient, broadcasting the divisor over the batch.

        Args:
            right: the quaternion, scalar, or tensor to divide by.
        """
        if isinstance(right, Quaternion):
            return self * right.inv()

        if isinstance(right, (int, float)):
            divisor = torch.tensor(right, device=self.data.device, dtype=self.data.dtype)
        else:
            divisor = right.to(device=self.data.device, dtype=self.data.dtype)

        # The trailing axis lets the divisor broadcast across the four coefficients.
        # The quotient is a regular tensor that stays attached to the autograd graph.
        return Quaternion(self.data / divisor.unsqueeze(-1))

    def __truediv__(self, right: Union[Tensor, "Quaternion", float]) -> "Quaternion":
        """True division (self / right); see :meth:`__div__`."""
        return self.__div__(right)

    def __radd__(self, left: Union[Tensor, float]) -> "Quaternion":
        """Right addition (left + self) where left is a scalar or tensor."""
        left_quat = self._to_scalar_quaternion(left)
        return left_quat + self

    def __rsub__(self, left: Union[Tensor, float]) -> "Quaternion":
        """Right subtraction (left - self) where left is a scalar or tensor."""
        left_quat = self._to_scalar_quaternion(left)
        return left_quat - self

    def __rtruediv__(self, left: Union[Tensor, float]) -> "Quaternion":
        """Right division (left / self) where left is a scalar or tensor."""
        left_quat = self._to_scalar_quaternion(left)
        return left_quat / self

    def __rdiv__(self, left: Union[Tensor, float]) -> "Quaternion":
        """Right division (left / self) where left is a scalar or tensor."""
        return self.__rtruediv__(left)

    def __pow__(self, t: float) -> "Quaternion":
        """Return the power of a quaternion raised to exponent t.

        Only the polar angle and the direction of the imaginary part are used; the
        magnitude of the quaternion is not carried into the result.

        Args:
            t: raised exponent.

        Example:
            >>> q = Quaternion(tensor([1., .5, 0., 0.]))
            >>> q_pow = q**2
        """
        theta = self.polar_angle[..., None]
        vec_norm = self.vec.norm(dim=-1, keepdim=True)
        # Guard the axis against a zero imaginary part to avoid a division by zero.
        n = where(vec_norm != 0, self.vec / vec_norm, self.vec * 0)
        w = (t * theta).cos()
        xyz = (t * theta).sin() * n
        return Quaternion(concatenate((w, xyz), -1))

    @property
    def data(self) -> Tensor:
        """Return the underlying data with shape :math:`(B, 4)`."""
        return self._data

    @property
    def coeffs(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return a tuple with the underlying coefficients in WXYZ order."""
        return self.w, self.x, self.y, self.z

    @property
    def real(self) -> Tensor:
        """Return the real part with shape :math:`(B,)`.

        Alias for :func:`~kornia.geometry.quaternion.Quaternion.w`
        """
        return self.w

    @property
    def vec(self) -> Tensor:
        """Return the vector with the imaginary part with shape :math:`(B, 3)`."""
        return self.data[..., 1:]

    @property
    def q(self) -> Tensor:
        """Return the underlying data with shape :math:`(B, 4)`.

        Alias for :func:`~kornia.geometry.quaternion.Quaternion.data`
        """
        return self.data

    @property
    def scalar(self) -> Tensor:
        """Return a scalar with the real with shape :math:`(B,)`.

        Alias for :func:`~kornia.geometry.quaternion.Quaternion.w`
        """
        return self.real

    @property
    def w(self) -> Tensor:
        """Return the :math:`q_w` with shape :math:`(B,)`."""
        return self.data[..., 0]

    @property
    def x(self) -> Tensor:
        """Return the :math:`q_x` with shape :math:`(B,)`."""
        return self.data[..., 1]

    @property
    def y(self) -> Tensor:
        """Return the :math:`q_y` with shape :math:`(B,)`."""
        return self.data[..., 2]

    @property
    def z(self) -> Tensor:
        """Return the :math:`q_z` with shape :math:`(B,)`."""
        return self.data[..., 3]

    @property
    def shape(self) -> Tuple[int, ...]:
        """Return the shape of the underlying data as a tuple."""
        return tuple(self.data.shape)

    @property
    def polar_angle(self) -> Tensor:
        """Return the polar angle with shape :math:`(B,)`.

        The angle is the arc-cosine of the real part divided by the norm.

        Example:
            >>> q = Quaternion.identity()
            >>> q.polar_angle
            tensor(0.)
        """
        # Rounding can push the ratio marginally outside [-1, 1], where acos yields NaN.
        cos_theta = (self.scalar / self.norm()).clamp(-1.0, 1.0)
        return cos_theta.acos()

    def matrix(self) -> Tensor:
        """Convert the quaternion to a rotation matrix of shape :math:`(B, 3, 3)`.

        Example:
            >>> q = Quaternion.identity()
            >>> m = q.matrix()
            >>> m
            tensor([[1., 0., 0.],
                    [0., 1., 0.],
                    [0., 0., 1.]])
        """
        return quaternion_to_rotation_matrix(self.data)

    @classmethod
    def from_matrix(cls, matrix: Tensor) -> "Quaternion":
        """Create a quaternion from a rotation matrix.

        Args:
            matrix: the rotation matrix to convert of shape :math:`(B, 3, 3)`.

        Example:
            >>> m = torch.eye(3)[None]
            >>> q = Quaternion.from_matrix(m)
            >>> q.data
            tensor([[1., 0., 0., 0.]])
        """
        return cls(rotation_matrix_to_quaternion(matrix))

    @classmethod
    def from_euler(cls, roll: Tensor, pitch: Tensor, yaw: Tensor) -> "Quaternion":
        """Create a quaternion from Euler angles.

        Args:
            roll: the roll euler angle.
            pitch: the pitch euler angle.
            yaw: the yaw euler angle.

        Example:
            >>> roll, pitch, yaw = tensor(0), tensor(1), tensor(0)
            >>> q = Quaternion.from_euler(roll, pitch, yaw)
            >>> q.data
            tensor([0.8776, 0.0000, 0.4794, 0.0000])
        """
        w, x, y, z = quaternion_from_euler(roll=roll, pitch=pitch, yaw=yaw)
        q = stack((w, x, y, z), -1)
        return cls(q)

    def to_euler(self) -> Tuple[Tensor, Tensor, Tensor]:
        """Convert the quaternion to a triple of Euler angles (roll, pitch, yaw).

        Example:
            >>> q = Quaternion(tensor([2., 0., 1., 1.]))
            >>> roll, pitch, yaw = q.to_euler()
            >>> roll
            tensor(2.0344)
            >>> pitch
            tensor(1.5708)
            >>> yaw
            tensor(2.2143)
        """
        return euler_from_quaternion(self.w, self.x, self.y, self.z)

    @classmethod
    def from_axis_angle(cls, axis_angle: Tensor) -> "Quaternion":
        """Create a quaternion from axis-angle representation.

        Args:
            axis_angle: rotation vector of shape :math:`(B, 3)`.

        Example:
            >>> axis_angle = torch.tensor([[1., 0., 0.]])
            >>> q = Quaternion.from_axis_angle(axis_angle)
            >>> q.data
            tensor([[0.8776, 0.4794, 0.0000, 0.0000]])
        """
        return cls(axis_angle_to_quaternion(axis_angle))

    def to_axis_angle(self) -> Tensor:
        """Convert the quaternion to an axis-angle representation.

        Example:
            >>> q = Quaternion.identity()
            >>> axis_angle = q.to_axis_angle()
            >>> axis_angle
            tensor([0., 0., 0.])
        """
        return quaternion_to_axis_angle(self.data)

    @classmethod
    def identity(
        cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Dtype = None
    ) -> "Quaternion":
        """Create a quaternion representing an identity rotation.

        Args:
            batch_size: the batch size of the underlying data. If None, the data has shape :math:`(4,)`.
            device: device to place the result on.
            dtype: dtype of the result.

        Example:
            >>> q = Quaternion.identity()
            >>> q.data
            tensor([1., 0., 0., 0.])
        """
        data = tensor([1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype)
        if batch_size is not None:
            data = data.repeat(batch_size, 1)
        return cls(data)

    @classmethod
    def from_coeffs(cls, w: float, x: float, y: float, z: float) -> "Quaternion":
        """Create a quaternion from the data coefficients.

        Args:
            w: a float representing the :math:`q_w` component.
            x: a float representing the :math:`q_x` component.
            y: a float representing the :math:`q_y` component.
            z: a float representing the :math:`q_z` component.

        Example:
            >>> q = Quaternion.from_coeffs(1., 0., 0., 0.)
            >>> q.data
            tensor([1., 0., 0., 0.])
        """
        return cls(tensor([w, x, y, z]))

    # TODO: update signature
    # def random(cls, shape: Optional[List] = None, device = None, dtype = None) -> 'Quaternion':
    @classmethod
    def random(
        cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Dtype = None
    ) -> "Quaternion":
        """Create a random unit quaternion of shape :math:`(B, 4)`.

        Uniformly distributed across the rotation space as per: http://planning.cs.uiuc.edu/node198.html

        Args:
            batch_size: the batch size of the underlying data. If None, the data has shape :math:`(4,)`.
            device: device to place the result on.
            dtype: dtype of the result.

        Example:
            >>> q = Quaternion.random()
            >>> q = Quaternion.random(batch_size=2)
        """
        rand_shape = (batch_size,) if batch_size is not None else ()
        # Three independent uniform samples, unpacked along the leading dimension.
        r1, r2, r3 = rand((3, *rand_shape), device=device, dtype=dtype)
        q1 = (1.0 - r1).sqrt() * ((2 * pi * r2).sin())
        q2 = (1.0 - r1).sqrt() * ((2 * pi * r2).cos())
        q3 = r1.sqrt() * (2 * pi * r3).sin()
        q4 = r1.sqrt() * (2 * pi * r3).cos()
        return cls(stack((q1, q2, q3, q4), -1))

    def slerp(self, q1: "Quaternion", t: float) -> "Quaternion":
        """Return a unit quaternion spherically interpolated between quaternions self.q and q1.

        Both quaternions are normalized before interpolating.

        See more: https://en.wikipedia.org/wiki/Slerp

        Args:
            q1: second quaternion to be interpolated between.
            t: interpolation ratio, range [0-1]

        Example:
            >>> q0 = Quaternion.identity()
            >>> q1 = Quaternion(torch.tensor([1., .5, 0., 0.]))
            >>> q2 = q0.slerp(q1, .3)
        """
        KORNIA_CHECK_TYPE(q1, Quaternion)
        q0 = self.normalize()
        q1 = q1.normalize()
        return q0 * (q0.inv() * q1) ** t

    def norm(self, keepdim: bool = False) -> Tensor:
        """Compute the norm (magnitude) of the quaternion.

        Args:
            keepdim: whether to retain the last dimension.

        Returns:
            The norm of the quaternion(s) as a tensor.

        Example:
            >>> q = Quaternion.identity()
            >>> q.norm()
            tensor(1.)
        """
        # p==2, dim|axis==-1, keepdim
        return self.data.norm(2, -1, keepdim)

    def normalize(self) -> "Quaternion":
        """Return a normalized (unit) quaternion.

        Returns:
            The normalized quaternion.

        Example:
            >>> q = Quaternion(tensor([2., 1., 0., 0.]))
            >>> q_norm = q.normalize()
        """
        return Quaternion(normalize_quaternion(self.data))

    def conj(self) -> "Quaternion":
        """Compute the conjugate of the quaternion.

        Returns:
            The conjugate quaternion, with the vector part negated.

        Example:
            >>> q = Quaternion(tensor([1., 2., 3., 4.]))
            >>> q_conj = q.conj()
        """
        return Quaternion(concatenate((self.real[..., None], -self.vec), -1))

    def inv(self) -> "Quaternion":
        """Compute the inverse of the quaternion.

        The inverse is the conjugate divided by the squared norm.

        Returns:
            The inverse quaternion.

        Example:
            >>> q = Quaternion.identity()
            >>> q_inv = q.inv()
        """
        return self.conj() / self.squared_norm()

    def squared_norm(self) -> Tensor:
        """Compute the squared norm (magnitude) of the quaternion.

        Returns:
            The squared norm of the quaternion(s) as a tensor.

        Example:
            >>> q = Quaternion.identity()
            >>> q.squared_norm()
            tensor(1.)
        """
        return batched_dot_product(self.vec, self.vec) + self.real**2
def average_quaternions(Q: "Quaternion", w: Optional[torch.Tensor] = None) -> "Quaternion":
    """Compute (weighted) average of multiple quaternions.

    The average is taken as the eigenvector associated with the largest eigenvalue of the
    (weighted) sum of outer products of the input quaternions.

    Args:
        Q (Quaternion): quaternion object containing data of shape (M, 4).
        w (torch.Tensor, optional): Weights with M elements. If None, uniform weights are used.
            The weights are normalized to sum to one.

    Returns:
        Quaternion: averaged unit quaternion with data of shape (1, 4).

    Raises:
        ValueError: if the number of weights does not match the number of quaternions.
    """
    # Validate the input type before touching any of its attributes.
    KORNIA_CHECK_TYPE(Q, Quaternion)
    data = Q.data
    M = data.shape[0]

    if w is None:
        A = (data.T @ data) / M
    else:
        w = w.to(data.device, dtype=data.dtype)
        if w.numel() != M:
            raise ValueError(f"weights length {w.numel()} must match number of quaternions {M}")
        w = w.reshape(-1)
        w = w / w.sum()
        # Scale each row by its weight instead of materializing an (M, M) diagonal matrix.
        A = (data * w[:, None]).T @ data

    eigenvalues, eigenvectors = torch.linalg.eigh(A)
    q_avg = eigenvectors[:, torch.argmax(eigenvalues)]
    q_avg = q_avg / q_avg.norm()
    return Quaternion(q_avg.unsqueeze(0))
|