| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- # kornia.geometry.so3 module inspired by Sophus-sympy.
- # https://github.com/strasdat/Sophus/blob/master/sympy/sophus/so3.py
- from __future__ import annotations
- from typing import Optional
- import torch
- from kornia.core import Device, Dtype, Module, Tensor, concatenate, eye, stack, where, zeros, zeros_like
- from kornia.core.check import KORNIA_CHECK_TYPE
- from kornia.geometry.conversions import vector_to_skew_symmetric_matrix
- from kornia.geometry.linalg import batched_dot_product
- from kornia.geometry.quaternion import Quaternion
- from kornia.geometry.vector import Vector3
class So3(Module):
    r"""Base class to represent the So3 group.

    The SO(3) is the group of all rotations about the origin of three-dimensional Euclidean space
    :math:`R^3` under the operation of composition.
    See more: https://en.wikipedia.org/wiki/3D_rotation_group

    We internally represent the rotation by a unit quaternion.

    Example:
        >>> q = Quaternion.identity()
        >>> s = So3(q)
        >>> s.q
        tensor([1., 0., 0., 0.])

    """

    def __init__(self, q: Quaternion) -> None:
        """Construct the base class.

        Internally represented by a unit quaternion `q`. The quaternion is stored as given; it is
        not normalized here (see the example below, which wraps an all-ones quaternion).

        Args:
            q: Quaternion with the shape of :math:`(B, 4)`.

        Raises:
            Fails the ``KORNIA_CHECK_TYPE`` check if ``q`` is not a :class:`Quaternion`.

        Example:
            >>> data = torch.ones((2, 4))
            >>> q = Quaternion(data)
            >>> So3(q)
            tensor([[1., 1., 1., 1.],
                    [1., 1., 1., 1.]])

        """
        super().__init__()
        KORNIA_CHECK_TYPE(q, Quaternion)
        self._q = q

    def __repr__(self) -> str:
        # Print the underlying quaternion, e.g. ``tensor([1., 0., 0., 0.])``.
        return f"{self.q}"

    def __getitem__(self, idx: int | slice) -> So3:
        """Index into the underlying quaternion and wrap the result in a new So3."""
        return So3(self._q[idx])

    def __mul__(self, right: So3) -> So3:
        """Compose two So3 transformations, or rotate points.

        Args:
            right: the other So3 transformation, or a Tensor / Vector3 of points (trailing
                dimension of size 3) to be rotated by this transformation.

        Return:
            The composed So3 when ``right`` is a So3. Otherwise the rotated points: a Tensor for
            Tensor input, a Vector3 for Vector3 input.

        Raises:
            TypeError: if ``right`` is neither a So3, a Tensor nor a Vector3.

        """
        # https://github.com/strasdat/Sophus/blob/master/sympy/sophus/so3.py#L98
        if isinstance(right, So3):
            return So3(self.q * right.q)
        if isinstance(right, (Tensor, Vector3)):
            # KORNIA_CHECK_SHAPE(right, ["B", "3"])  # FIXME: resolve shape bugs. @edgarriba
            # A plain Tensor is used as-is: ``Tensor.data`` returns a view detached from autograd,
            # which would silently drop gradients w.r.t. the points. Only Vector3 needs unwrapping.
            points = right if isinstance(right, Tensor) else right.data
            # Embed the points as pure quaternions (w = 0) and apply q * p * conj(q).
            w = zeros(*points.shape[:-1], 1, device=points.device, dtype=points.dtype)
            quat = Quaternion(concatenate((w, points), -1))
            out = (self.q * quat * self.q.conj()).vec
            return out if isinstance(right, Tensor) else Vector3(out)
        raise TypeError(f"Not So3 or Tensor type. Got: {type(right)}")

    @property
    def q(self) -> Quaternion:
        """Return the underlying data with shape :math:`(B,4)`."""
        return self._q

    @staticmethod
    def exp(v: Tensor) -> So3:
        """Convert elements of lie algebra to elements of lie group.

        Maps ``v`` with norm ``theta`` to the quaternion
        ``(cos(theta / 2), sin(theta / 2) / theta * v)``. Near ``theta = 0`` the factor
        ``sin(theta / 2) / theta`` is replaced by its Taylor expansion ``1/2 - theta^2 / 48``.

        See more: https://vision.in.tum.de/_media/members/demmeln/nurlanov2021so3log.pdf

        Args:
            v: vector of shape :math:`(B,3)`.

        Example:
            >>> v = torch.zeros((2, 3))
            >>> s = So3.exp(v)
            >>> s
            tensor([[1., 0., 0., 0.],
                    [1., 0., 0., 0.]])

        """
        # KORNIA_CHECK_SHAPE(v, ["B", "3"])  # FIXME: resolve shape bugs. @edgarriba
        theta = v.norm(dim=-1, keepdim=True)
        w = (0.5 * theta).cos()
        small = theta <= torch.finfo(v.dtype).eps * 1e3
        # ``where`` evaluates both branches and back-propagates through both, so the division must
        # never see theta == 0 even when its result is discarded (otherwise gradients become NaN).
        theta_safe = where(small, torch.ones_like(theta), theta)
        b_large = (0.5 * theta_safe).sin() / theta_safe
        b_small = 0.5 - (theta * theta) / 48.0
        b = where(small, b_small, b_large)
        return So3(Quaternion(concatenate((w, b * v), -1)))

    def log(self) -> Tensor:
        """Convert elements of lie group to elements of lie algebra.

        Returns ``2 * acos(w) * xyz / |xyz|`` when the vector part ``xyz`` is non-zero, and
        ``2 * xyz / w`` otherwise, as a tensor of shape :math:`(B,3)`.

        Example:
            >>> data = torch.ones((2, 4))
            >>> q = Quaternion(data)
            >>> So3(q).log()
            tensor([[0., 0., 0.],
                    [0., 0., 0.]])

        """
        vec = self.q.vec
        w = self.q.real[..., None]
        theta_sq = batched_dot_product(vec, vec)[..., None]
        nonzero = theta_sq != 0
        # Double-``where`` trick: feed harmless values into sqrt/acos/division on the branch that is
        # discarded, so sqrt'(0), acos'(1) and x / 0 cannot inject NaN into the backward pass.
        # Forward values are identical to evaluating each formula directly.
        theta = where(nonzero, theta_sq, torch.ones_like(theta_sq)).sqrt()
        angle = 2 * where(nonzero, w, zeros_like(w)).acos()
        w_safe = where(nonzero, torch.ones_like(w), w)
        # NOTE: this differs from https://github.com/strasdat/Sophus/blob/master/sympy/sophus/so3.py#L33
        omega = where(nonzero, angle * vec / theta, 2 * vec / w_safe)
        return omega

    @staticmethod
    def hat(v: Vector3 | Tensor) -> Tensor:
        """Convert elements from vector space to lie algebra. Returns matrix of shape :math:`(B,3,3)`.

        Args:
            v: Vector3 or tensor of shape :math:`(B,3)`.

        Example:
            >>> v = torch.ones((1,3))
            >>> m = So3.hat(v)
            >>> m
            tensor([[[ 0., -1.,  1.],
                     [ 1.,  0., -1.],
                     [-1.,  1.,  0.]]])

        """
        # KORNIA_CHECK_SHAPE(v, ["B", "3"])  # FIXME: resolve shape bugs. @edgarriba
        if isinstance(v, Tensor):
            # TODO: Figure out why mypy think `v` can be a Vector3 which didn't allow ellipsis on index
            a, b, c = v[..., 0], v[..., 1], v[..., 2]  # type: ignore[index]
        else:
            a, b, c = v.x, v.y, v.z
        z = zeros_like(a)
        row0 = stack((z, -c, b), -1)
        row1 = stack((c, z, -a), -1)
        row2 = stack((-b, a, z), -1)
        return stack((row0, row1, row2), -2)

    @staticmethod
    def vee(omega: Tensor) -> Tensor:
        r"""Convert elements from lie algebra to vector space. Returns vector of shape :math:`(B,3)`.

        Inverse of :meth:`hat`: reads ``(a, b, c)`` back from the skew-symmetric matrix

        .. math::

            omega = \begin{bmatrix} 0 & -c & b \\
            c & 0 & -a \\
            -b & a & 0\end{bmatrix}

        Args:
            omega: 3x3-matrix representing lie algebra.

        Example:
            >>> v = torch.ones((1,3))
            >>> omega = So3.hat(v)
            >>> So3.vee(omega)
            tensor([[1., 1., 1.]])

        """
        # KORNIA_CHECK_SHAPE(omega, ["B", "3", "3"])  # FIXME: resolve shape bugs. @edgarriba
        a, b, c = omega[..., 2, 1], omega[..., 0, 2], omega[..., 1, 0]
        return stack((a, b, c), -1)

    def matrix(self) -> Tensor:
        r"""Convert the quaternion to a rotation matrix of shape :math:`(B,3,3)`.

        The matrix is of the form:

        .. math::

            \begin{bmatrix} 1-2y^2-2z^2 & 2xy-2zw & 2xz+2yw \\
            2xy+2zw & 1-2x^2-2z^2 & 2yz-2xw \\
            2xz-2yw & 2yz+2xw & 1-2x^2-2y^2\end{bmatrix}

        Example:
            >>> s = So3.identity()
            >>> m = s.matrix()
            >>> m
            tensor([[1., 0., 0.],
                    [0., 1., 0.],
                    [0., 0., 1.]])

        """
        w = self.q.w[..., None]
        x, y, z = self.q.x[..., None], self.q.y[..., None], self.q.z[..., None]
        q0 = 1 - 2 * y**2 - 2 * z**2
        q1 = 2 * x * y - 2 * z * w
        q2 = 2 * x * z + 2 * y * w
        row0 = concatenate((q0, q1, q2), -1)
        q0 = 2 * x * y + 2 * z * w
        q1 = 1 - 2 * x**2 - 2 * z**2
        q2 = 2 * y * z - 2 * x * w
        row1 = concatenate((q0, q1, q2), -1)
        q0 = 2 * x * z - 2 * y * w
        q1 = 2 * y * z + 2 * x * w
        q2 = 1 - 2 * x**2 - 2 * y**2
        row2 = concatenate((q0, q1, q2), -1)
        return stack((row0, row1, row2), -2)

    @classmethod
    def from_matrix(cls, matrix: Tensor) -> So3:
        """Create So3 from a rotation matrix.

        Args:
            matrix: the rotation matrix to convert of shape :math:`(B,3,3)`.

        Example:
            >>> m = torch.eye(3)
            >>> s = So3.from_matrix(m)
            >>> s
            tensor([1., 0., 0., 0.])

        """
        return cls(Quaternion.from_matrix(matrix))

    @classmethod
    def from_wxyz(cls, wxyz: Tensor) -> So3:
        """Create So3 from a tensor representing a quaternion.

        Args:
            wxyz: the quaternion to convert of shape :math:`(B,4)`.

        Example:
            >>> q = torch.tensor([1., 0., 0., 0.])
            >>> s = So3.from_wxyz(q)
            >>> s
            tensor([1., 0., 0., 0.])

        """
        # KORNIA_CHECK_SHAPE(wxyz, ["B", "4"])  # FIXME: resolve shape bugs. @edgarriba
        return cls(Quaternion(wxyz))

    @classmethod
    def identity(
        cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Optional[Dtype] = None
    ) -> So3:
        """Create a So3 group representing an identity rotation.

        Args:
            batch_size: the batch size of the underlying data.
            device: device to place the result on.
            dtype: dtype of the result.

        Example:
            >>> s = So3.identity()
            >>> s
            tensor([1., 0., 0., 0.])

            >>> s = So3.identity(batch_size=2)
            >>> s
            tensor([[1., 0., 0., 0.],
                    [1., 0., 0., 0.]])

        """
        return cls(Quaternion.identity(batch_size, device, dtype))

    def inverse(self) -> So3:
        """Return the inverse transformation (the conjugate quaternion).

        Example:
            >>> s = So3.identity()
            >>> s.inverse()
            tensor([1., -0., -0., -0.])

        """
        return So3(self.q.conj())

    @classmethod
    def random(
        cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Optional[Dtype] = None
    ) -> So3:
        """Create a So3 group representing a random rotation.

        Args:
            batch_size: the batch size of the underlying data.
            device: device to place the result on.
            dtype: dtype of the result.

        Example:
            >>> s = So3.random()
            >>> s = So3.random(batch_size=3)

        """
        return cls(Quaternion.random(batch_size, device, dtype))

    @classmethod
    def rot_x(cls, x: Tensor) -> So3:
        """Construct an x-axis rotation.

        Args:
            x: the x-axis rotation angle, in radians.

        """
        zs = zeros_like(x)
        return cls.exp(stack((x, zs, zs), -1))

    @classmethod
    def rot_y(cls, y: Tensor) -> So3:
        """Construct a y-axis rotation.

        Args:
            y: the y-axis rotation angle, in radians.

        """
        zs = zeros_like(y)
        return cls.exp(stack((zs, y, zs), -1))

    @classmethod
    def rot_z(cls, z: Tensor) -> So3:
        """Construct a z-axis rotation.

        Args:
            z: the z-axis rotation angle, in radians.

        """
        zs = zeros_like(z)
        return cls.exp(stack((zs, zs, z), -1))

    def adjoint(self) -> Tensor:
        """Return the adjoint matrix of shape :math:`(B, 3, 3)`.

        For SO(3) the adjoint equals the rotation matrix itself.

        Example:
            >>> s = So3.identity()
            >>> s.adjoint()
            tensor([[1., 0., 0.],
                    [0., 1., 0.],
                    [0., 0., 1.]])

        """
        return self.matrix()

    @staticmethod
    def _jacobian_terms(vec: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Return ``(skew, a, c)`` shared by the left and right Jacobians.

        ``skew`` is the skew-symmetric matrix of ``vec``, ``a = (1 - cos(t)) / t^2`` and
        ``c = (t - sin(t)) / t^3`` with ``t = |vec|``. Both coefficients are shaped to broadcast
        against ``skew``.
        """
        skew = vector_to_skew_symmetric_matrix(vec)
        theta = vec.norm(dim=-1, keepdim=True)[..., None]
        # Both closed forms are 0/0 at theta == 0 and lose precision to cancellation near it.
        # Below eps**0.25 the first omitted Taylor term (O(theta^4)) is smaller than machine epsilon
        # relative to the leading term, so the two-term expansion is exact to working precision.
        small = theta < torch.finfo(vec.dtype).eps ** 0.25
        # Keep the discarded closed-form branch finite so ``where`` does not leak NaN gradients.
        theta_safe = where(small, torch.ones_like(theta), theta)
        theta_sq = theta * theta
        a = where(small, 0.5 - theta_sq / 24.0, (1 - theta_safe.cos()) / theta_safe**2)
        c = where(small, 1.0 / 6.0 - theta_sq / 120.0, (theta_safe - theta_safe.sin()) / theta_safe**3)
        return skew, a, c

    @staticmethod
    def right_jacobian(vec: Tensor) -> Tensor:
        """Compute the right Jacobian of So3.

        ``Jr = I - (1 - cos(t)) / t^2 * [vec]x + (t - sin(t)) / t^3 * [vec]x^2`` with ``t = |vec|``.
        For small ``t`` the coefficients use their Taylor expansions, so ``vec = 0`` yields the identity.

        Args:
            vec: the input point of shape :math:`(B, 3)`.

        Example:
            >>> vec = torch.tensor([1., 2., 3.])
            >>> So3.right_jacobian(vec)
            tensor([[-0.0687,  0.5556, -0.0141],
                    [-0.2267,  0.1779,  0.6236],
                    [ 0.5074,  0.3629,  0.5890]])

        """
        # KORNIA_CHECK_SHAPE(vec, ["B", "3"])  # FIXME: resolve shape bugs. @edgarriba
        R_skew, a, c = So3._jacobian_terms(vec)
        I = eye(3, device=vec.device, dtype=vec.dtype)  # noqa: E741
        Jr = I - a * R_skew + c * (R_skew @ R_skew)
        return Jr

    @staticmethod
    def Jr(vec: Tensor) -> Tensor:
        """Alias for right jacobian.

        Args:
            vec: the input point of shape :math:`(B, 3)`.

        """
        return So3.right_jacobian(vec)

    @staticmethod
    def left_jacobian(vec: Tensor) -> Tensor:
        """Compute the left Jacobian of So3.

        ``Jl = I + (1 - cos(t)) / t^2 * [vec]x + (t - sin(t)) / t^3 * [vec]x^2`` with ``t = |vec|``.
        For small ``t`` the coefficients use their Taylor expansions, so ``vec = 0`` yields the identity.

        Args:
            vec: the input point of shape :math:`(B, 3)`.

        Example:
            >>> vec = torch.tensor([1., 2., 3.])
            >>> So3.left_jacobian(vec)
            tensor([[-0.0687, -0.2267,  0.5074],
                    [ 0.5556,  0.1779,  0.3629],
                    [-0.0141,  0.6236,  0.5890]])

        """
        # KORNIA_CHECK_SHAPE(vec, ["B", "3"])  # FIXME: resolve shape bugs. @edgarriba
        R_skew, a, c = So3._jacobian_terms(vec)
        I = eye(3, device=vec.device, dtype=vec.dtype)  # noqa: E741
        Jl = I + a * R_skew + c * (R_skew @ R_skew)
        return Jl

    @staticmethod
    def Jl(vec: Tensor) -> Tensor:
        """Alias for left jacobian.

        Args:
            vec: the input point of shape :math:`(B, 3)`.

        """
        return So3.left_jacobian(vec)
|