| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- # kornia.geometry.line module inspired by Eigen::geometry::ParametrizedLine
- # https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/ParametrizedLine.h
- from typing import Iterator, Optional, Tuple, Union
- import torch
- from kornia.core import Module, Parameter, Tensor, normalize, where
- from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
- from kornia.geometry.linalg import batched_dot_product
- from kornia.geometry.plane import Hyperplane
- from kornia.utils.helpers import _torch_svd_cast
# Public names exported by this module.
__all__ = [
    "ParametrizedLine",
    "fit_line",
]
class ParametrizedLine(Module):
    """Class that describes a parametrized line.

    A parametrized line is defined by an origin point :math:`o` and a unit
    direction vector :math:`d` such that the line corresponds to the set

    .. math::

        l(t) = o + t * d
    """

    def __init__(self, origin: Tensor, direction: Tensor) -> None:
        """Initialize a parametrized line of direction and origin.

        Args:
            origin: any point on the line of any dimension.
            direction: the normalized vector direction of any dimension.

        Example:
            >>> o = torch.tensor([0.0, 0.0])
            >>> d = torch.tensor([1.0, 1.0])
            >>> l = ParametrizedLine(o, d)
        """
        super().__init__()
        # Both are stored as parameters so they are registered on the module.
        self._origin = Parameter(origin)
        self._direction = Parameter(direction)

    def __str__(self) -> str:
        return f"Origin: {self.origin}\nDirection: {self.direction}"

    def __repr__(self) -> str:
        return str(self)

    def __getitem__(self, idx: int) -> Tensor:
        """Return the origin for ``idx == 0`` and the direction for ``idx == 1``.

        Negative indices follow the usual sequence convention (``-1`` is the direction,
        ``-2`` the origin). Any other index raises :class:`IndexError` instead of silently
        returning the direction.
        """
        return (self.origin, self.direction)[idx]

    def __iter__(self) -> Iterator[Tensor]:
        yield from (self.origin, self.direction)

    @property
    def origin(self) -> Tensor:
        """Return the line origin point."""
        return self._origin

    @property
    def direction(self) -> Tensor:
        """Return the line direction vector."""
        return self._direction

    def dim(self) -> int:
        """Return the dimension in which the line holds."""
        return self.direction.shape[-1]

    @classmethod
    def through(cls, p0: Tensor, p1: Tensor) -> "ParametrizedLine":
        """Construct a parametrized line going from a point :math:`p0` to :math:`p1`.

        The direction is the L2-normalized vector ``p1 - p0`` (normalized over the last axis).

        Args:
            p0: tensor with first point :math:`(B, D)` where `D` is the point dimension.
            p1: tensor with second point :math:`(B, D)` where `D` is the point dimension.

        Example:
            >>> p0 = torch.tensor([0.0, 0.0])
            >>> p1 = torch.tensor([1.0, 1.0])
            >>> l = ParametrizedLine.through(p0, p1)
        """
        # ``cls`` (not a hard-coded class name) so subclasses get instances of themselves.
        return cls(p0, normalize(p1 - p0, p=2, dim=-1))

    def point_at(self, t: Union[float, Tensor]) -> Tensor:
        """Get the point at :math:`t` along this line.

        Args:
            t: step along the line.

        Return:
            tensor with the point.

        Example:
            >>> p0 = torch.tensor([0.0, 0.0])
            >>> p1 = torch.tensor([1.0, 1.0])
            >>> l = ParametrizedLine.through(p0, p1)
            >>> p2 = l.point_at(0.1)
        """
        return self.origin + self.direction * t

    def projection(self, point: Tensor) -> Tensor:
        """Return the projection of a point onto the line.

        Assumes ``direction`` has unit length.

        Args:
            point: the point to be projected, with the coordinates on the last axis.
        """
        # Dot product over the last axis, kept as a size-1 axis so it broadcasts against
        # the direction; ``@`` would be a matrix product for batched lines.
        t = torch.sum((point - self.origin) * self.direction, dim=-1, keepdim=True)
        return self.origin + t * self.direction

    def squared_distance(self, point: Tensor) -> Tensor:
        """Return the squared distance of a point to its projection onto the line.

        Assumes ``direction`` has unit length. The result is clamped at zero so that
        floating-point cancellation for points lying on the line cannot yield a negative
        value (which would make :meth:`distance` return NaN).

        Args:
            point: the point to calculate the distance onto the line.
        """
        diff = point - self.origin
        along = torch.sum(diff * self.direction, dim=-1)
        sq_norm = torch.sum(diff * diff, dim=-1)
        return (sq_norm - along * along).clamp(min=0.0)

    def distance(self, point: Tensor) -> Tensor:
        """Return the distance of a point to its projection onto the line.

        Args:
            point: the point to calculate the distance to the line.
        """
        return self.squared_distance(point).sqrt()

    # TODO(edgar) implement the following:
    # - intersection
    # - intersection_parameter
    # - intersection_point

    # TODO: add tests, and possibly return a mask
    def intersect(self, plane: Hyperplane, eps: float = 1e-6) -> Tuple[Tensor, Tensor]:
        """Return the intersection point between the line and a given plane.

        Args:
            plane: the plane to compute the intersection point.
            eps: lines whose direction has an absolute dot product with the plane normal
                below this value are treated as parallel to the plane.

        Return:
            - the lambda value used to compute the look at point.
            - the intersected point.

            Both are NaN for lines treated as parallel to the plane.
        """
        # NOTE(review): ``.data`` detaches these tensors from autograd (and presumably
        # unwraps the plane normal if it is a wrapper type) — kept as in the original.
        dot_prod = batched_dot_product(plane.normal.data, self.direction.data)
        dot_prod_mask = dot_prod.abs() >= eps
        # ``where`` evaluates both branches: divide by 1 in the parallel lanes so the
        # discarded branch cannot produce inf/NaN values or gradients.
        safe_dot_prod = where(dot_prod_mask, dot_prod, torch.ones_like(dot_prod))
        res_lambda = where(
            dot_prod_mask,
            -(plane.offset + batched_dot_product(plane.normal.data, self.origin.data)) / safe_dot_prod,
            # NaN marks "no intersection" (``empty_like`` would return uninitialized memory).
            torch.full_like(dot_prod, float("nan")),
        )
        # The dot products reduce the coordinate axis; add it back so lambda broadcasts
        # against the (..., D) direction for batched lines.
        res_point = self.point_at(res_lambda[..., None])
        return res_lambda, res_point
def _fit_line_ols_2d(points: Tensor) -> ParametrizedLine:
    """Fit a line to batched 2D points in closed form (orthogonal / total least squares).

    The line passes through the centroid of the points and its direction is the principal
    eigenvector of the 2x2 scatter matrix, i.e. it minimizes squared *orthogonal*
    distances, like the SVD-based fit for other dimensions. Regressing y on x would only
    minimize vertical residuals and would not be rotation-invariant. The function name is
    kept for backward compatibility.

    Args:
        points: tensor of shape :math:`(B, N, 2)`.

    Return:
        A :class:`ParametrizedLine` with origin and unit direction of shape :math:`(B, 2)`.
        The direction is oriented with a non-negative x component. If the scatter matrix
        is exactly zero (all points coincide), it falls back to ``[0, 1]``.
    """
    x = points[..., 0]  # (B, N)
    y = points[..., 1]  # (B, N)
    x_mean = x.mean(dim=-1, keepdim=True)  # (B, 1)
    y_mean = y.mean(dim=-1, keepdim=True)  # (B, 1)
    dx = x - x_mean
    dy = y - y_mean

    # Entries of the symmetric scatter matrix [[sxx, sxy], [sxy, syy]], each (B, 1).
    sxx = (dx * dx).sum(dim=-1, keepdim=True)
    sxy = (dx * dy).sum(dim=-1, keepdim=True)
    syy = (dy * dy).sum(dim=-1, keepdim=True)

    # Largest eigenvalue of a symmetric 2x2 matrix in closed form.
    lam = 0.5 * (sxx + syy) + torch.hypot(0.5 * (sxx - syy), sxy)

    # Two algebraically equivalent eigenvectors; keep the longer one, which is the one
    # that does not suffer from cancellation (e.g. for near-vertical lines).
    cand_a = torch.cat([lam - syy, sxy], dim=-1)  # (B, 2)
    cand_b = torch.cat([sxy, lam - sxx], dim=-1)  # (B, 2)
    norm_a = cand_a.norm(dim=-1, keepdim=True)
    norm_b = cand_b.norm(dim=-1, keepdim=True)
    direction = torch.where(norm_a >= norm_b, cand_a, cand_b)
    norm = torch.maximum(norm_a, norm_b)

    # Orient the direction with x >= 0 (and y > 0 when vertical), as the slope form did.
    dir_x = direction[..., :1]
    dir_y = direction[..., 1:]
    flip = (dir_x < 0) | ((dir_x == 0) & (dir_y < 0))
    direction = torch.where(flip, -direction, direction)

    # Degenerate scatter (coincident points, or NaN input): fall back to [0, 1].
    valid = norm > 0
    fallback = torch.tensor([0.0, 1.0], device=points.device, dtype=direction.dtype).expand_as(direction)
    direction = torch.where(valid, direction / torch.where(valid, norm, torch.ones_like(norm)), fallback)

    origin = torch.cat([x_mean, y_mean], dim=-1)  # (B, 2)
    return ParametrizedLine(origin, direction)
def _fit_line_weighted_ols_2d(points: Tensor, weights: Tensor) -> ParametrizedLine:
    """Fit a weighted line to batched 2D points in closed form (weighted total least squares).

    The line passes through the weighted centroid and its direction is the principal
    eigenvector of the weighted 2x2 scatter matrix, so it minimizes the weighted sum of
    squared *orthogonal* distances rather than vertical residuals. The function name is
    kept for backward compatibility.

    Args:
        points: tensor of shape :math:`(B, N, 2)`.
        weights: per-point weights of shape :math:`(B, N)`.

    Return:
        A :class:`ParametrizedLine` with origin and unit direction of shape :math:`(B, 2)`.
        The direction is oriented with a non-negative x component. If the weighted scatter
        is exactly zero or non-finite, it falls back to ``[0, 1]``. A zero total weight
        yields a NaN origin.
    """
    x = points[..., 0]  # (B, N)
    y = points[..., 1]  # (B, N)
    w_sum = weights.sum(dim=-1, keepdim=True)  # (B, 1)
    x_mean = (weights * x).sum(dim=-1, keepdim=True) / w_sum  # (B, 1)
    y_mean = (weights * y).sum(dim=-1, keepdim=True) / w_sum  # (B, 1)
    dx = x - x_mean
    dy = y - y_mean

    # Entries of the weighted scatter matrix [[sxx, sxy], [sxy, syy]], each (B, 1).
    sxx = (weights * dx * dx).sum(dim=-1, keepdim=True)
    sxy = (weights * dx * dy).sum(dim=-1, keepdim=True)
    syy = (weights * dy * dy).sum(dim=-1, keepdim=True)

    # Largest eigenvalue of a symmetric 2x2 matrix in closed form.
    lam = 0.5 * (sxx + syy) + torch.hypot(0.5 * (sxx - syy), sxy)

    # Two algebraically equivalent eigenvectors; keep the longer (better-conditioned) one.
    cand_a = torch.cat([lam - syy, sxy], dim=-1)  # (B, 2)
    cand_b = torch.cat([sxy, lam - sxx], dim=-1)  # (B, 2)
    norm_a = cand_a.norm(dim=-1, keepdim=True)
    norm_b = cand_b.norm(dim=-1, keepdim=True)
    direction = torch.where(norm_a >= norm_b, cand_a, cand_b)
    norm = torch.maximum(norm_a, norm_b)

    # Orient the direction with x >= 0 (and y > 0 when vertical).
    dir_x = direction[..., :1]
    dir_y = direction[..., 1:]
    flip = (dir_x < 0) | ((dir_x == 0) & (dir_y < 0))
    direction = torch.where(flip, -direction, direction)

    # Degenerate or non-finite scatter: fall back to [0, 1] (no in-place mutation).
    valid = norm > 0
    fallback = torch.tensor([0.0, 1.0], device=points.device, dtype=direction.dtype).expand_as(direction)
    direction = torch.where(valid, direction / torch.where(valid, norm, torch.ones_like(norm)), fallback)

    origin = torch.cat([x_mean, y_mean], dim=-1)  # (B, 2)
    return ParametrizedLine(origin, direction)
def fit_line(points: Tensor, weights: Optional[Tensor] = None) -> ParametrizedLine:
    """Fit a line from a set of points.

    For ``D != 2`` the line passes through the (weighted) centroid of the points, and its
    direction is the principal eigenvector of the (weighted) scatter matrix, computed with
    an SVD. This minimizes the (weighted) sum of squared orthogonal distances. 2D inputs
    are delegated to dedicated closed-form helpers.

    Args:
        points: tensor containing a batch of sets of n-dimensional points. The expected
          shape of the tensor is :math:`(B, N, D)`.
        weights: optional per-point weights used to solve the equations system. The
          expected shape of the tensor is :math:`(B, N)`.

    Return:
        The fitted :class:`ParametrizedLine`, whose origin and direction have shape :math:`(B, D)`.

    Example:
        >>> points = torch.rand(2, 10, 3)
        >>> weights = torch.ones(2, 10)
        >>> line = fit_line(points, weights)
        >>> line.direction.shape
        torch.Size([2, 3])
    """
    KORNIA_CHECK_IS_TENSOR(points, "points must be a tensor")
    KORNIA_CHECK_SHAPE(points, ["B", "N", "D"])
    D = points.shape[-1]

    # Validate the weights once, for every code path.
    if weights is not None:
        KORNIA_CHECK_IS_TENSOR(weights, "weights must be a tensor")
        KORNIA_CHECK_SHAPE(weights, ["B", "N"])
        KORNIA_CHECK(
            weights.shape[:2] == points.shape[:2],
            "weights must have shape (B, N) matching points of shape (B, N, D)",
        )

    # Fast path: closed-form fit for 2D points (no SVD needed).
    if D == 2:
        if weights is None:
            return _fit_line_ols_2d(points)
        return _fit_line_weighted_ols_2d(points, weights)

    if weights is None:
        mean = points.mean(dim=-2, keepdim=True)  # (B, 1, D)
        centered = points - mean
        scatter = centered.transpose(-2, -1) @ centered  # (B, D, D)
    else:
        w = weights[..., None]  # (B, N, 1)
        w_sum = w.sum(dim=-2, keepdim=True)  # (B, 1, 1)
        # The weighted least-squares line passes through the weighted centroid. Batches
        # whose weights sum to zero keep the plain mean instead of dividing by zero.
        has_weight = w_sum > 0
        weighted_mean = (w * points).sum(dim=-2, keepdim=True) / torch.where(
            has_weight, w_sum, torch.ones_like(w_sum)
        )
        mean = torch.where(has_weight, weighted_mean, points.mean(dim=-2, keepdim=True))  # (B, 1, D)
        centered = points - mean
        # Same as centered^T @ diag(weights) @ centered, without building a (B, N, N) matrix.
        scatter = (centered * w).transpose(-2, -1) @ centered  # (B, D, D)

    _, _, V = _torch_svd_cast(scatter)
    V = V.transpose(-2, -1)
    # For the symmetric PSD scatter matrix, the leading singular vector is the eigenvector
    # with the largest eigenvalue, i.e. the direction of largest spread.
    direction = V[..., 0, :]  # (B, D)
    origin = mean[..., 0, :]  # (B, D)
    return ParametrizedLine(origin, direction)
|