# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# kornia.geometry.line module inspired by Eigen::geometry::ParametrizedLine
# https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/ParametrizedLine.h
from typing import Iterator, Optional, Tuple, Union

import torch

from kornia.core import Module, Parameter, Tensor, normalize, where
from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
from kornia.geometry.linalg import batched_dot_product
from kornia.geometry.plane import Hyperplane
from kornia.utils.helpers import _torch_svd_cast

__all__ = ["ParametrizedLine", "fit_line"]

# Threshold on the (weighted) sum of squared x-deviations below which the 2D
# least-squares fast paths treat the point set as a vertical line.
_OLS_VERTICAL_EPS = 1e-8


class ParametrizedLine(Module):
    """Class that describes a parametrized line.

    A parametrized line is defined by an origin point :math:`o` and a unit direction vector :math:`d` such that
    the line corresponds to the set

    .. math::

        l(t) = o + t * d
    """

    def __init__(self, origin: Tensor, direction: Tensor) -> None:
        """Initialize a parametrized line of direction and origin.

        Both tensors are wrapped in ``Parameter`` objects as given; no normalization or shape check is done here.

        Args:
            origin: any point on the line of any dimension.
            direction: the normalized vector direction of any dimension.

        Example:
            >>> o = torch.tensor([0.0, 0.0])
            >>> d = torch.tensor([1.0, 1.0])
            >>> l = ParametrizedLine(o, d)

        """
        super().__init__()
        self._origin = Parameter(origin)
        self._direction = Parameter(direction)

    def __str__(self) -> str:
        return f"Origin: {self.origin}\nDirection: {self.direction}"

    def __repr__(self) -> str:
        return str(self)

    def __getitem__(self, idx: int) -> Tensor:
        """Return the origin for ``idx == 0`` and the direction for ``idx == 1``.

        Indexing follows tuple semantics on ``(origin, direction)``: negative indices are supported and any other
        index raises ``IndexError``.
        """
        return (self.origin, self.direction)[idx]

    def __iter__(self) -> Iterator[Tensor]:
        yield from (self.origin, self.direction)

    @property
    def origin(self) -> Tensor:
        """Return the line origin point."""
        return self._origin

    @property
    def direction(self) -> Tensor:
        """Return the line direction vector."""
        return self._direction

    def dim(self) -> int:
        """Return the dimension in which the line holds (size of the last axis of the direction)."""
        return self.direction.shape[-1]

    @classmethod
    def through(cls, p0: Tensor, p1: Tensor) -> "ParametrizedLine":
        """Construct a parametrized line going from a point :math:`p0` to :math:`p1`.

        The origin is ``p0`` and the direction is ``p1 - p0`` normalized along the last axis.

        Args:
            p0: tensor with first point :math:`(B, D)` where `D` is the point dimension.
            p1: tensor with second point :math:`(B, D)` where `D` is the point dimension.

        Example:
            >>> p0 = torch.tensor([0.0, 0.0])
            >>> p1 = torch.tensor([1.0, 1.0])
            >>> l = ParametrizedLine.through(p0, p1)

        """
        return cls(p0, normalize((p1 - p0), p=2, dim=-1))

    def point_at(self, t: Union[float, Tensor]) -> Tensor:
        """Get the point at :math:`t` along this line.

        Args:
            t: step along the line. A tensor is combined with the direction by broadcasting, so for batched lines
                of shape :math:`(B, D)` it should have shape :math:`(B, 1)`.

        Return:
            tensor with the point.

        Example:
            >>> p0 = torch.tensor([0.0, 0.0])
            >>> p1 = torch.tensor([1.0, 1.0])
            >>> l = ParametrizedLine.through(p0, p1)
            >>> p2 = l.point_at(0.1)

        """
        return self.origin + self.direction * t

    def projection(self, point: Tensor) -> Tensor:
        """Return the projection of a point onto the line.

        The dot product is taken over the last axis, so batched lines :math:`(B, D)` and batches of points
        broadcastable against them are supported. Assumes the direction has unit norm.

        Args:
            point: the point to be projected.

        Return:
            the projected point, with the broadcast shape of ``point`` and the line.

        """
        # Step along the line to the foot of the perpendicular; keepdim so it scales each direction vector.
        step = (self.direction * (point - self.origin)).sum(dim=-1, keepdim=True)
        return self.origin + step * self.direction

    def squared_distance(self, point: Tensor) -> Tensor:
        """Return the squared distance of a point to its projection onto the line.

        Assumes the direction has unit norm.

        Args:
            point: the point to calculate the distance onto the line.

        """
        d = point - self.origin
        proj = torch.sum(d * self.direction, dim=-1)
        sq_norm_d = torch.sum(d * d, dim=-1)
        # |d|^2 - (d.u)^2 can come out slightly negative from floating-point cancellation for points on the
        # line; clamp so that ``distance`` never takes the square root of a negative number.
        return (sq_norm_d - proj * proj).clamp_min(0.0)

    def distance(self, point: Tensor) -> Tensor:
        """Return the distance of a point to its projection onto the line.

        Args:
            point: the point to calculate the distance to the line.

        """
        return self.squared_distance(point).sqrt()

    # TODO(edgar) implement the following:
    # - intersection
    # - intersection_parameter
    # - intersection_point

    # TODO: add tests, and possibly return a mask
    def intersect(self, plane: Hyperplane, eps: float = 1e-6) -> Tuple[Tensor, Tensor]:
        """Return the intersection point between the line and a given plane.

        Args:
            plane: the plane to compute the intersection point.
            eps: lines whose direction has an absolute dot product with the plane normal below this value are
                treated as parallel to the plane.

        Return:
            - the lambda value used to compute the look at point; NaN where the line is parallel to the plane.
            - the intersected point; NaN where the line is parallel to the plane.

        """
        # NOTE: ``.data`` detaches the line parameters, so no gradient flows through the lambda computation.
        dot_prod = batched_dot_product(plane.normal.data, self.direction.data)
        dot_prod_mask = dot_prod.abs() >= eps
        # Divide by a safe value in the masked-out entries so the discarded branch of ``where`` holds no inf/NaN.
        safe_dot_prod = where(dot_prod_mask, dot_prod, torch.ones_like(dot_prod))

        # TODO: add check for dot product
        res_lambda = where(
            dot_prod_mask,
            -(plane.offset + batched_dot_product(plane.normal.data, self.origin.data)) / safe_dot_prod,
            # Parallel lines have no (unique) intersection: report NaN instead of uninitialized memory.
            torch.full_like(dot_prod, float("nan")),
        )

        # ``res_lambda`` holds one step per line; add a trailing axis so it scales each direction vector
        # instead of being broadcast across the point dimension.
        res_point = self.point_at(torch.unsqueeze(res_lambda, -1))
        return res_lambda, res_point


def _fit_line_ols_2d(points: Tensor) -> ParametrizedLine:
    """Fit 2D lines by ordinary least-squares regression of ``y`` on ``x``.

    Args:
        points: tensor of shape :math:`(B, N, 2)`.

    Return:
        a line through the centroid of each point set with unit direction ``[1, slope]`` normalized, or
        ``[0, 1]`` when the x-coordinates have (almost) no spread.

    """
    x = points[..., 0]  # (B, N)
    y = points[..., 1]  # (B, N)
    x_mean = x.mean(dim=-1, keepdim=True)  # (B, 1)
    y_mean = y.mean(dim=-1, keepdim=True)  # (B, 1)
    dx = x - x_mean
    dy = y - y_mean

    denom = (dx * dx).sum(dim=-1, keepdim=True)  # (B, 1)
    is_regular = denom > _OLS_VERTICAL_EPS
    slope = torch.where(is_regular, (dx * dy).sum(dim=-1, keepdim=True) / denom, torch.zeros_like(denom))

    # For vertical lines, fall back to the [0, 1] direction. The fallback uses the points' dtype so the result
    # dtype is not promoted (e.g. half -> float).
    vertical = torch.tensor([0.0, 1.0], device=points.device, dtype=points.dtype).expand(points.shape[0], 2)
    direction = torch.where(is_regular, torch.cat([torch.ones_like(slope), slope], dim=-1), vertical)
    direction = direction / direction.norm(dim=-1, keepdim=True)

    origin = torch.cat([x_mean, y_mean], dim=-1)  # (B, 2)
    return ParametrizedLine(origin, direction)


def _fit_line_weighted_ols_2d(points: Tensor, weights: Tensor) -> ParametrizedLine:
    """Fit 2D lines by weighted least-squares regression of ``y`` on ``x``.

    Args:
        points: tensor of shape :math:`(B, N, 2)`.
        weights: tensor of shape :math:`(B, N)`.

    Return:
        a line through the weighted centroid of each point set with unit direction ``[1, slope]`` normalized,
        or ``[0, 1]`` when the weighted x-spread is (almost) zero.

    """
    x = points[..., 0]  # (B, N)
    y = points[..., 1]  # (B, N)

    w_sum = weights.sum(dim=-1, keepdim=True)  # (B, 1)
    x_mean = (weights * x).sum(dim=-1, keepdim=True) / w_sum  # (B, 1)
    y_mean = (weights * y).sum(dim=-1, keepdim=True) / w_sum  # (B, 1)
    dx = x - x_mean  # (B, N)
    dy = y - y_mean  # (B, N)

    denom = (weights * dx * dx).sum(dim=-1, keepdim=True)  # (B, 1)
    slope = (weights * dx * dy).sum(dim=-1, keepdim=True) / denom  # (B, 1)
    # Replace NaNs or infs from division by zero
    slope = torch.where(torch.isfinite(slope), slope, torch.zeros_like(slope))

    # direction = normalize([1, slope]) or [0, 1] if vertical
    is_vertical = denom <= _OLS_VERTICAL_EPS  # (B, 1)
    vertical = torch.tensor([0.0, 1.0], device=points.device, dtype=points.dtype)
    direction = torch.where(is_vertical, vertical, torch.cat([torch.ones_like(slope), slope], dim=-1))  # (B, 2)
    direction = direction / direction.norm(dim=-1, keepdim=True)

    origin = torch.cat([x_mean, y_mean], dim=-1)  # (B, 2)
    return ParametrizedLine(origin, direction)


def fit_line(points: Tensor, weights: Optional[Tensor] = None) -> ParametrizedLine:
    """Fit a line from a set of points.

    For 2D points a closed-form (weighted) least-squares regression of ``y`` on ``x`` is used. For any other
    dimension the direction is the dominant right singular vector of the (weighted) scatter matrix of the points
    centered on their (weighted) centroid.

    Args:
        points: tensor containing a batch of sets of n-dimensional points. The expected shape of the tensor is
            :math:`(B, N, D)`.
        weights: weights to use to solve the equations system. The expected shape of the tensor is
            :math:`(B, N)`. Weights summing to zero give a NaN origin.

    Return:
        A :class:`ParametrizedLine` whose origin is the (weighted) centroid of the points and whose direction is
        the direction of the fitted line, both of shape :math:`(B, D)`.

    Example:
        >>> points = torch.rand(2, 10, 3)
        >>> weights = torch.ones(2, 10)
        >>> line = fit_line(points, weights)
        >>> line.direction.shape
        torch.Size([2, 3])

    """
    KORNIA_CHECK_IS_TENSOR(points, "points must be a tensor")
    KORNIA_CHECK_SHAPE(points, ["B", "N", "D"])

    if weights is not None:
        KORNIA_CHECK_IS_TENSOR(weights, "weights must be a tensor")
        KORNIA_CHECK_SHAPE(weights, ["B", "N"])
        KORNIA_CHECK(
            points.shape[0] == weights.shape[0] and points.shape[1] == weights.shape[1],
            "points and weights must share the batch and point dimensions",
        )

    D = points.shape[-1]

    # Fast path: closed-form least squares for 2D points (weighted or unweighted)
    if D == 2:
        if weights is None:
            return _fit_line_ols_2d(points)
        return _fit_line_weighted_ols_2d(points, weights)

    if weights is None:
        mean = points.mean(dim=-2, keepdim=True)  # (B, 1, D)
        centered = points - mean
        scatter = centered.transpose(-2, -1) @ centered  # (B, D, D)
    else:
        w = weights.unsqueeze(-1)  # (B, N, 1)
        # A weighted fit must be centered on the weighted centroid, as in the 2D weighted path.
        mean = (w * points).sum(dim=-2, keepdim=True) / w.sum(dim=-2, keepdim=True)  # (B, 1, D)
        centered = points - mean
        # Equivalent to centered^T @ diag(weights) @ centered without building an (N, N) matrix.
        scatter = centered.transpose(-2, -1) @ (w * centered)  # (B, D, D)

    _, _, V = _torch_svd_cast(scatter)
    V = V.transpose(-2, -1)

    # the first right singular vector (largest singular value) is the direction of the fitted line
    direction = V[..., 0, :]  # BxD
    origin = mean[..., 0, :]  # BxD

    return ParametrizedLine(origin, direction)