line.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. # kornia.geometry.line module inspired by Eigen::geometry::ParametrizedLine
  18. # https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/ParametrizedLine.h
  19. from typing import Iterator, Optional, Tuple, Union
  20. import torch
  21. from kornia.core import Module, Parameter, Tensor, normalize, where
  22. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
  23. from kornia.geometry.linalg import batched_dot_product
  24. from kornia.geometry.plane import Hyperplane
  25. from kornia.utils.helpers import _torch_svd_cast
  26. __all__ = ["ParametrizedLine", "fit_line"]
  27. class ParametrizedLine(Module):
  28. """Class that describes a parametrize line.
  29. A parametrized line is defined by an origin point :math:`o` and a unit
  30. direction vector :math:`d` such that the line corresponds to the set
  31. .. math::
  32. l(t) = o + t * d
  33. """
  34. def __init__(self, origin: Tensor, direction: Tensor) -> None:
  35. """Initialize a parametrized line of direction and origin.
  36. Args:
  37. origin: any point on the line of any dimension.
  38. direction: the normalized vector direction of any dimension.
  39. Example:
  40. >>> o = torch.tensor([0.0, 0.0])
  41. >>> d = torch.tensor([1.0, 1.0])
  42. >>> l = ParametrizedLine(o, d)
  43. """
  44. super().__init__()
  45. self._origin = Parameter(origin)
  46. self._direction = Parameter(direction)
  47. def __str__(self) -> str:
  48. return f"Origin: {self.origin}\nDirection: {self.direction}"
  49. def __repr__(self) -> str:
  50. return str(self)
  51. def __getitem__(self, idx: int) -> Tensor:
  52. return self.origin if idx == 0 else self.direction
  53. def __iter__(self) -> Iterator[Tensor]:
  54. yield from (self.origin, self.direction)
  55. @property
  56. def origin(self) -> Tensor:
  57. """Return the line origin point."""
  58. return self._origin
  59. @property
  60. def direction(self) -> Tensor:
  61. """Return the line direction vector."""
  62. return self._direction
  63. def dim(self) -> int:
  64. """Return the dimension in which the line holds."""
  65. return self.direction.shape[-1]
  66. @classmethod
  67. def through(cls, p0: Tensor, p1: Tensor) -> "ParametrizedLine":
  68. """Construct a parametrized line going from a point :math:`p0` to :math:`p1`.
  69. Args:
  70. p0: tensor with first point :math:`(B, D)` where `D` is the point dimension.
  71. p1: tensor with second point :math:`(B, D)` where `D` is the point dimension.
  72. Example:
  73. >>> p0 = torch.tensor([0.0, 0.0])
  74. >>> p1 = torch.tensor([1.0, 1.0])
  75. >>> l = ParametrizedLine.through(p0, p1)
  76. """
  77. return ParametrizedLine(p0, normalize((p1 - p0), p=2, dim=-1))
  78. def point_at(self, t: Union[float, Tensor]) -> Tensor:
  79. """Get the point at :math:`t` along this line.
  80. Args:
  81. t: step along the line.
  82. Return:
  83. tensor with the point.
  84. Example:
  85. >>> p0 = torch.tensor([0.0, 0.0])
  86. >>> p1 = torch.tensor([1.0, 1.0])
  87. >>> l = ParametrizedLine.through(p0, p1)
  88. >>> p2 = l.point_at(0.1)
  89. """
  90. return self.origin + self.direction * t
  91. def projection(self, point: Tensor) -> Tensor:
  92. """Return the projection of a point onto the line.
  93. Args:
  94. point: the point to be projected.
  95. """
  96. return self.origin + (self.direction @ (point - self.origin)) * self.direction
  97. def squared_distance(self, point: Tensor) -> Tensor:
  98. """Return the squared distance of a point to its projection onte the line.
  99. Args:
  100. point: the point to calculate the distance onto the line.
  101. """
  102. d = point - self.origin
  103. proj = torch.sum(d * self.direction, dim=-1)
  104. sq_norm_d = torch.sum(d * d, dim=-1)
  105. return sq_norm_d - proj * proj
  106. def distance(self, point: Tensor) -> Tensor:
  107. """Return the distance of a point to its projections onto the line.
  108. Args:
  109. point: the point to calculate the distance into the line.
  110. """
  111. return self.squared_distance(point).sqrt()
  112. # TODO(edgar) implement the following:
  113. # - intersection
  114. # - intersection_parameter
  115. # - intersection_point
  116. # TODO: add tests, and possibly return a mask
  117. def intersect(self, plane: Hyperplane, eps: float = 1e-6) -> Tuple[Tensor, Tensor]:
  118. """Return the intersection point between the line and a given plane.
  119. Args:
  120. plane: the plane to compute the intersection point.
  121. eps: epsilon for numerical stability.
  122. Return:
  123. - the lambda value used to compute the look at point.
  124. - the intersected point.
  125. """
  126. dot_prod = batched_dot_product(plane.normal.data, self.direction.data)
  127. dot_prod_mask = dot_prod.abs() >= eps
  128. # TODO: add check for dot product
  129. res_lambda = where(
  130. dot_prod_mask,
  131. -(plane.offset + batched_dot_product(plane.normal.data, self.origin.data)) / dot_prod,
  132. torch.empty_like(dot_prod),
  133. )
  134. res_point = self.point_at(res_lambda)
  135. return res_lambda, res_point
  136. def _fit_line_ols_2d(points: Tensor) -> ParametrizedLine:
  137. x = points[..., 0]
  138. y = points[..., 1]
  139. x_mean = x.mean(dim=-1, keepdim=True)
  140. y_mean = y.mean(dim=-1, keepdim=True)
  141. dx = x - x_mean
  142. dy = y - y_mean
  143. denom = (dx * dx).sum(dim=-1, keepdim=True) # (B, 1)
  144. slope = torch.where(denom > 1e-8, (dx * dy).sum(dim=-1, keepdim=True) / denom, torch.zeros_like(denom))
  145. # For vertical lines, fallback to [0,1] direction
  146. direction = torch.where(
  147. denom > 1e-8,
  148. torch.cat([torch.ones_like(slope), slope], dim=-1),
  149. torch.tensor([0.0, 1.0], device=points.device).expand(points.shape[0], 2),
  150. )
  151. direction = direction / direction.norm(dim=-1, keepdim=True)
  152. origin = torch.cat([x_mean, y_mean], dim=-1)
  153. return ParametrizedLine(origin, direction)
  154. def _fit_line_weighted_ols_2d(points: Tensor, weights: Tensor) -> ParametrizedLine:
  155. x = points[..., 0] # (B, N)
  156. y = points[..., 1] # (B, N)
  157. w_sum = weights.sum(dim=-1, keepdim=True) # (B, 1)
  158. x_mean = (weights * x).sum(dim=-1, keepdim=True) / w_sum # (B, 1)
  159. y_mean = (weights * y).sum(dim=-1, keepdim=True) / w_sum # (B, 1)
  160. dx = x - x_mean # (B, N)
  161. dy = y - y_mean # (B, N)
  162. weighted_dx2 = weights * dx * dx
  163. weighted_dxdy = weights * dx * dy
  164. denom = weighted_dx2.sum(dim=-1, keepdim=True) # (B, 1)
  165. slope = weighted_dxdy.sum(dim=-1, keepdim=True) / denom # (B, 1)
  166. # Replace NaNs or infs from division by zero
  167. slope = torch.where(torch.isfinite(slope), slope, torch.zeros_like(slope))
  168. # direction = normalize([1, slope]) or [0,1] if vertical
  169. is_vertical = denom <= 1e-8
  170. direction = torch.cat([torch.ones_like(slope), slope], dim=-1) # (B, 2)
  171. replacement = torch.tensor([0.0, 1.0], device=points.device, dtype=points.dtype)
  172. direction[is_vertical.squeeze(-1)] = replacement
  173. direction = direction / direction.norm(dim=-1, keepdim=True)
  174. origin = torch.cat([x_mean, y_mean], dim=-1)
  175. return ParametrizedLine(origin, direction)
  176. def fit_line(points: Tensor, weights: Optional[Tensor] = None) -> ParametrizedLine:
  177. """Fit a line from a set of points.
  178. Args:
  179. points: tensor containing a batch of sets of n-dimensional points. The expected
  180. shape of the tensor is :math:`(B, N, D)`.
  181. weights: weights to use to solve the equations system. The expected
  182. shape of the tensor is :math:`(B, N)`.
  183. Return:
  184. A tensor containing the direction of the fitted line of shape :math:`(B, D)`.
  185. Example:
  186. >>> points = torch.rand(2, 10, 3)
  187. >>> weights = torch.ones(2, 10)
  188. >>> line = fit_line(points, weights)
  189. >>> line.direction.shape
  190. torch.Size([2, 3])
  191. """
  192. KORNIA_CHECK_IS_TENSOR(points, "points must be a tensor")
  193. KORNIA_CHECK_SHAPE(points, ["B", "N", "D"])
  194. _B, _N, D = points.shape
  195. # Fast path: use OLS for unweighted 2D case
  196. if D == 2:
  197. if weights is not None:
  198. KORNIA_CHECK_IS_TENSOR(weights, "weights must be a tensor")
  199. KORNIA_CHECK_SHAPE(weights, ["B", "N"])
  200. KORNIA_CHECK(points.shape[0] == weights.shape[0])
  201. return _fit_line_weighted_ols_2d(points, weights)
  202. else:
  203. return _fit_line_ols_2d(points)
  204. mean = points.mean(-2, True)
  205. A = points - mean
  206. if weights is not None:
  207. KORNIA_CHECK_IS_TENSOR(weights, "weights must be a tensor")
  208. KORNIA_CHECK_SHAPE(weights, ["B", "N"])
  209. KORNIA_CHECK(points.shape[0] == weights.shape[0])
  210. A = A.transpose(-2, -1) @ torch.diag_embed(weights) @ A
  211. else:
  212. A = A.transpose(-2, -1) @ A
  213. # NOTE: not optimal for 2d points, but for now works for other dimensions
  214. _, _, V = _torch_svd_cast(A)
  215. V = V.transpose(-2, -1)
  216. # the first left eigenvector is the direction on the fitted line
  217. direction = V[..., 0, :] # BxD
  218. origin = mean[..., 0, :] # BxD
  219. return ParametrizedLine(origin, direction)