linalg.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. import torch
  19. from kornia.core import Tensor
  20. from kornia.core.check import KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
  21. from kornia.geometry.conversions import convert_points_from_homogeneous, convert_points_to_homogeneous
# Names exported by ``from <module> import *``. Every entry is defined in this
# module, including ``squared_norm``, which is bound as an alias at the end of
# the file.
__all__ = [
    "batched_dot_product",
    "batched_squared_norm",
    "compose_transformations",
    "euclidean_distance",
    "inverse_transformation",
    "point_line_distance",
    "relative_transformation",
    "squared_norm",
    "transform_points",
]
  33. def compose_transformations(trans_01: Tensor, trans_12: Tensor) -> Tensor:
  34. r"""Compose two homogeneous transformations.
  35. .. math::
  36. T_0^{2} = \begin{bmatrix} R_0^1 R_1^{2} & R_0^{1} t_1^{2} + t_0^{1} \
  37. \\mathbf{0} & 1\end{bmatrix}
  38. Args:
  39. trans_01: tensor with the homogeneous transformation from
  40. a reference frame 1 respect to a frame 0. The tensor has must have a
  41. shape of :math:`(N, 4, 4)` or :math:`(4, 4)`.
  42. trans_12: tensor with the homogeneous transformation from
  43. a reference frame 2 respect to a frame 1. The tensor has must have a
  44. shape of :math:`(N, 4, 4)` or :math:`(4, 4)`.
  45. Returns:
  46. the transformation between the two frames with shape :math:`(N, 4, 4)` or :math:`(4, 4)`.
  47. Example::
  48. >>> trans_01 = torch.eye(4) # 4x4
  49. >>> trans_12 = torch.eye(4) # 4x4
  50. >>> trans_02 = compose_transformations(trans_01, trans_12) # 4x4
  51. """
  52. KORNIA_CHECK_IS_TENSOR(trans_01)
  53. KORNIA_CHECK_IS_TENSOR(trans_12)
  54. if not ((trans_01.dim() in (2, 3)) and (trans_01.shape[-2:] == (4, 4))):
  55. raise ValueError(f"Input trans_01 must be a of the shape Nx4x4 or 4x4. Got {trans_01.shape}")
  56. if not ((trans_12.dim() in (2, 3)) and (trans_12.shape[-2:] == (4, 4))):
  57. raise ValueError(f"Input trans_12 must be a of the shape Nx4x4 or 4x4. Got {trans_12.shape}")
  58. if trans_01.dim() != trans_12.dim():
  59. raise ValueError(f"Input number of dims must match. Got {trans_01.dim()} and {trans_12.dim()}")
  60. # unpack input data
  61. rmat_01 = trans_01[..., :3, :3]
  62. rmat_12 = trans_12[..., :3, :3]
  63. tvec_01 = trans_01[..., :3, 3:]
  64. tvec_12 = trans_12[..., :3, 3:]
  65. # compute the actual transforms composition
  66. rmat_02 = torch.matmul(rmat_01, rmat_12)
  67. tvec_02 = torch.matmul(rmat_01, tvec_12) + tvec_01
  68. trans_02 = trans_01.new_zeros(trans_01.shape)
  69. trans_02[..., :3, :3] = rmat_02
  70. trans_02[..., :3, 3:] = tvec_02
  71. trans_02[..., 3, 3] = 1.0
  72. return trans_02
  73. def inverse_transformation(trans_12: Tensor) -> Tensor:
  74. r"""Invert a 4x4 homogeneous transformation.
  75. :math:`T_1^{2} = \begin{bmatrix} R_1 & t_1 \\ \mathbf{0} & 1 \end{bmatrix}`
  76. The inverse transformation is computed as follows:
  77. .. math::
  78. T_2^{1} = (T_1^{2})^{-1} = \begin{bmatrix} R_1^T & -R_1^T t_1 \\
  79. \mathbf{0} & 1\end{bmatrix}
  80. Args:
  81. trans_12: transformation tensor of shape :math:`(N, 4, 4)` or :math:`(4, 4)`.
  82. Returns:
  83. tensor with inverted transformations with shape :math:`(N, 4, 4)` or :math:`(4, 4)`.
  84. Example:
  85. >>> trans_12 = torch.rand(1, 4, 4) # Nx4x4
  86. >>> trans_21 = inverse_transformation(trans_12) # Nx4x4
  87. """
  88. KORNIA_CHECK_IS_TENSOR(trans_12)
  89. if not ((trans_12.dim() in (2, 3)) and (trans_12.shape[-2:] == (4, 4))):
  90. raise ValueError(f"Input size must be a Nx4x4 or 4x4. Got {trans_12.shape}")
  91. # unpack input tensor
  92. rmat_12 = trans_12[..., :3, :3] # Nx3x3 or 3x3
  93. tvec_12 = trans_12[..., :3, 3:4] # Nx3x1 or 3x1
  94. # compute the actual inverse
  95. rmat_21 = rmat_12.transpose(-1, -2)
  96. tvec_21 = torch.matmul(-rmat_21, tvec_12)
  97. # pack to output tensor
  98. trans_21 = trans_12.new_zeros(trans_12.shape)
  99. trans_21[..., :3, :3].copy_(rmat_21)
  100. trans_21[..., :3, 3:4].copy_(tvec_21)
  101. trans_21[..., 3, 3] = 1.0
  102. return trans_21
  103. def relative_transformation(trans_01: Tensor, trans_02: Tensor) -> Tensor:
  104. r"""Compute the relative homogeneous transformation from a reference transformation.
  105. :math:`T_1^{0} = \begin{bmatrix} R_1 & t_1 \\ \mathbf{0} & 1 \end{bmatrix}` to destination :math:`T_2^{0} =
  106. \begin{bmatrix} R_2 & t_2 \\ \mathbf{0} & 1 \end{bmatrix}`.
  107. The relative transformation is computed as follows:
  108. .. math::
  109. T_1^{2} = (T_0^{1})^{-1} \cdot T_0^{2}
  110. Args:
  111. trans_01: reference transformation tensor of shape :math:`(N, 4, 4)` or :math:`(4, 4)`.
  112. trans_02: destination transformation tensor of shape :math:`(N, 4, 4)` or :math:`(4, 4)`.
  113. Returns:
  114. the relative transformation between the transformations with shape :math:`(N, 4, 4)` or :math:`(4, 4)`.
  115. Example::
  116. >>> trans_01 = torch.eye(4) # 4x4
  117. >>> trans_02 = torch.eye(4) # 4x4
  118. >>> trans_12 = relative_transformation(trans_01, trans_02) # 4x4
  119. """
  120. KORNIA_CHECK_IS_TENSOR(trans_01)
  121. KORNIA_CHECK_IS_TENSOR(trans_02)
  122. if not ((trans_01.dim() in (2, 3)) and (trans_01.shape[-2:] == (4, 4))):
  123. raise ValueError(f"Input must be a of the shape Nx4x4 or 4x4. Got {trans_01.shape}")
  124. if not ((trans_02.dim() in (2, 3)) and (trans_02.shape[-2:] == (4, 4))):
  125. raise ValueError(f"Input must be a of the shape Nx4x4 or 4x4. Got {trans_02.shape}")
  126. if not trans_01.dim() == trans_02.dim():
  127. raise ValueError(f"Input number of dims must match. Got {trans_01.dim()} and {trans_02.dim()}")
  128. rmat_01 = trans_01[..., :3, :3]
  129. tvec_01 = trans_01[..., :3, 3:4]
  130. rmat_02 = trans_02[..., :3, :3]
  131. tvec_02 = trans_02[..., :3, 3:4]
  132. rmat_10 = rmat_01.transpose(-1, -2)
  133. rmat_12 = torch.matmul(rmat_10, rmat_02)
  134. tvec_12 = torch.matmul(rmat_10, tvec_02 - tvec_01)
  135. trans_12 = torch.zeros_like(trans_01)
  136. trans_12[..., :3, :3] = rmat_12
  137. trans_12[..., :3, 3:4] = tvec_12
  138. trans_12[..., 3, 3] = 1.0
  139. return trans_12
  140. def transform_points(trans_01: Tensor, points_1: Tensor) -> Tensor:
  141. r"""Apply transformations to a set of points.
  142. Args:
  143. trans_01: tensor for transformations of shape
  144. :math:`(B, D+1, D+1)`.
  145. points_1: tensor of points of shape :math:`(B, N, D)`.
  146. Returns:
  147. a tensor of N-dimensional points.
  148. Shape:
  149. - Output: :math:`(B, N, D)`
  150. Examples:
  151. >>> points_1 = torch.rand(2, 4, 3) # BxNx3
  152. >>> trans_01 = torch.eye(4).view(1, 4, 4) # Bx4x4
  153. >>> points_0 = transform_points(trans_01, points_1) # BxNx3
  154. """
  155. KORNIA_CHECK_IS_TENSOR(trans_01)
  156. KORNIA_CHECK_IS_TENSOR(points_1)
  157. if not trans_01.shape[0] == points_1.shape[0] and trans_01.shape[0] != 1:
  158. raise ValueError(
  159. f"Input batch size must be the same for both tensors or 1. Got {trans_01.shape} and {points_1.shape}"
  160. )
  161. if not trans_01.shape[-1] == (points_1.shape[-1] + 1):
  162. raise ValueError(f"Last input dimensions must differ by one unit Got{trans_01} and {points_1}")
  163. # We reshape to BxNxD in case we get more dimensions, e.g., MxBxNxD
  164. shape_inp = list(points_1.shape)
  165. points_1 = points_1.reshape(-1, points_1.shape[-2], points_1.shape[-1])
  166. trans_01 = trans_01.reshape(-1, trans_01.shape[-2], trans_01.shape[-1])
  167. # We expand trans_01 to match the dimensions needed for bmm. repeats input division is cast
  168. # to integer so onnx doesn't record the value as a tensor and get a device mismatch
  169. trans_01 = torch.repeat_interleave(trans_01, repeats=int(points_1.shape[0] // trans_01.shape[0]), dim=0)
  170. # to homogeneous
  171. points_1_h = convert_points_to_homogeneous(points_1) # BxNxD+1
  172. # transform coordinates
  173. points_0_h = torch.bmm(points_1_h, trans_01.permute(0, 2, 1))
  174. points_0_h = torch.squeeze(points_0_h, dim=-1)
  175. # to euclidean
  176. points_0 = convert_points_from_homogeneous(points_0_h) # BxNxD
  177. # reshape to the input shape
  178. shape_inp[-2] = points_0.shape[-2]
  179. shape_inp[-1] = points_0.shape[-1]
  180. points_0 = points_0.reshape(shape_inp)
  181. return points_0
  182. def point_line_distance(point: Tensor, line: Tensor, eps: float = 1e-9) -> Tensor:
  183. r"""Return the distance from points to lines.
  184. Args:
  185. point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`.
  186. line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`.
  187. eps: Small constant for safe sqrt.
  188. Returns:
  189. the computed distance with shape :math:`(*, N)`.
  190. """
  191. KORNIA_CHECK_IS_TENSOR(point)
  192. KORNIA_CHECK_IS_TENSOR(line)
  193. if point.shape[-1] not in (2, 3):
  194. raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}")
  195. if line.shape[-1] != 3:
  196. raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}")
  197. # Using in-place operations to improve performance
  198. numerator = line[..., 0] * point[..., 0]
  199. numerator += line[..., 1] * point[..., 1]
  200. numerator += line[..., 2]
  201. numerator.abs_()
  202. # Avoid computing norm multiple times by saving its value
  203. denom_norm = (line[..., 0].square() + line[..., 1].square()).sqrt()
  204. return numerator / (denom_norm + eps)
  205. def batched_dot_product(x: Tensor, y: Tensor, keepdim: bool = False) -> Tensor:
  206. """Return a batched version of .dot()."""
  207. KORNIA_CHECK_SHAPE(x, ["*", "N"])
  208. KORNIA_CHECK_SHAPE(y, ["*", "N"])
  209. return (x * y).sum(-1, keepdim)
  210. def batched_squared_norm(x: Tensor, keepdim: bool = False) -> Tensor:
  211. """Return the squared norm of a vector."""
  212. return batched_dot_product(x, x, keepdim)
  213. def euclidean_distance(x: Tensor, y: Tensor, keepdim: bool = False, eps: float = 1e-6) -> Tensor:
  214. """Compute the Euclidean distance between two set of n-dimensional points.
  215. More: https://en.wikipedia.org/wiki/Euclidean_distance
  216. Args:
  217. x: first set of points of shape :math:`(*, N)`.
  218. y: second set of points of shape :math:`(*, N)`.
  219. keepdim: whether to keep the dimension after reduction.
  220. eps: small value to have numerical stability.
  221. """
  222. KORNIA_CHECK_SHAPE(x, ["*", "N"])
  223. KORNIA_CHECK_SHAPE(y, ["*", "N"])
  224. return (x - y).pow(2).sum(dim=-1, keepdim=keepdim).add_(eps).sqrt_()
# aliases
# ``squared_norm`` is a second module-level name bound to the same function
# object as ``batched_squared_norm``; it is also listed in ``__all__``.
squared_norm = batched_squared_norm
  227. # TODO:
  228. # - project_points: from opencv