plane.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. # kornia.geometry.plane module inspired by Eigen::geometry::Hyperplane
  18. # https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/Hyperplane.h
  19. from typing import Optional
  20. import torch
  21. from kornia.core import Module, Tensor, stack, where
  22. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE, KORNIA_CHECK_TYPE
  23. from kornia.core.tensor_wrapper import unwrap, wrap # type: ignore[attr-defined]
  24. from kornia.geometry.linalg import batched_dot_product
  25. from kornia.geometry.vector import Scalar, Vector3
  26. from kornia.utils.helpers import _torch_svd_cast
# Public API of this module: the hyperplane type and the least-squares plane fit.
__all__ = ["Hyperplane", "fit_plane"]
  28. def normalized(v: Tensor, eps: float = 1e-6) -> Tensor:
  29. norm_sq = (v * v).sum(dim=-1, keepdim=True) + eps
  30. return v * norm_sq.rsqrt()
  31. class Hyperplane(Module):
  32. def __init__(self, n: Vector3, d: Scalar) -> None:
  33. super().__init__()
  34. KORNIA_CHECK_TYPE(n, Vector3)
  35. KORNIA_CHECK_TYPE(d, Scalar)
  36. # TODO: fix checkers
  37. # KORNIA_CHECK_SHAPE(n, ["B", "*"])
  38. # KORNIA_CHECK_SHAPE(d, ["B"])
  39. self._n = n
  40. self._d = d
  41. def __str__(self) -> str:
  42. return f"Normal: {self.normal}\nOffset: {self.offset}"
  43. def __repr__(self) -> str:
  44. return str(self)
  45. @property
  46. def normal(self) -> Vector3:
  47. return self._n
  48. @property
  49. def offset(self) -> Scalar:
  50. return self._d
  51. def abs_distance(self, p: Vector3) -> Scalar:
  52. return Scalar(self.signed_distance(p).abs())
  53. # https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/Hyperplane.h#L145
  54. # TODO: tests
  55. def signed_distance(self, p: Vector3) -> Scalar:
  56. KORNIA_CHECK(isinstance(p, (Vector3, Tensor)))
  57. return self.normal.dot(p) + self.offset
  58. # https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/Hyperplane.h#L154
  59. # TODO: tests
  60. def projection(self, p: Vector3) -> Vector3:
  61. dist = self.signed_distance(p)
  62. if len(dist.shape) != len(self.normal):
  63. # non batched plane project a batch of points
  64. dist = dist[..., None] # Nx1
  65. # TODO: TypeError: bad operand type for unary -: 'Scalar'
  66. return p - dist.data * self.normal
  67. # TODO: make that Vector can subtract Scalar
  68. # return p - self.signed_distance(p) * self.normal
  69. @classmethod
  70. def from_vector(self, n: Vector3, e: Vector3) -> "Hyperplane":
  71. normal: Vector3 = n
  72. offset = -normal.dot(e)
  73. return Hyperplane(normal, Scalar(offset))
  74. @classmethod
  75. def through(cls, p0: Tensor, p1: Tensor, p2: Optional[Tensor] = None) -> "Hyperplane":
  76. # 2d case
  77. if p2 is None:
  78. # TODO: improve tests
  79. KORNIA_CHECK_SHAPE(p0, ["*", "2"])
  80. KORNIA_CHECK(p0.shape == p1.shape)
  81. # TODO: implement `.unitOrthonormal`
  82. normal2d = normalized(p1 - p0)
  83. offset2d = -batched_dot_product(p0, normal2d)
  84. return Hyperplane(wrap(normal2d, Vector3), wrap(offset2d, Scalar))
  85. # 3d case
  86. KORNIA_CHECK_SHAPE(p0, ["*", "3"])
  87. KORNIA_CHECK(p0.shape == p1.shape)
  88. KORNIA_CHECK(p1.shape == p2.shape)
  89. v0, v1 = (p2 - p0), (p1 - p0)
  90. normal = torch.linalg.cross(v0, v1, dim=-1)
  91. norm = normal.norm(-1)
  92. # https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/Hyperplane.h#L108
  93. def compute_normal_svd(v0: Tensor, v1: Tensor) -> "Vector3":
  94. # NOTE: for reason TensorWrapper does not stack well
  95. m = stack((unwrap(v0), unwrap(v1)), -2) # Bx2x3
  96. _, _, V = _torch_svd_cast(m) # kornia solution lies in the last row
  97. return wrap(V[..., :, -1], Vector3) # Bx3
  98. normal_mask = norm <= v0.norm(-1) * v1.norm(-1) * 1e-6
  99. normal = where(normal_mask, compute_normal_svd(v0, v1).data, normal / (norm + 1e-6))
  100. offset = -batched_dot_product(p0, normal)
  101. return Hyperplane(wrap(normal, Vector3), wrap(offset, Scalar))
  102. # TODO: factor to avoid duplicated from line.py
  103. # https://github.com/strasdat/Sophus/blob/23.04-beta/cpp/sophus/geometry/fit_plane.h
  104. def fit_plane(points: Vector3) -> Hyperplane:
  105. """Fit a plane from a set of points using SVD.
  106. Args:
  107. points: tensor containing a batch of sets of n-dimensional points. The expected
  108. shape of the tensor is :math:`(N, D)`.
  109. Return:
  110. The computed hyperplane object.
  111. """
  112. # TODO: fix to support more type check here
  113. # KORNIA_CHECK_SHAPE(points, ["N", "D"])
  114. if points.shape[-1] != 3:
  115. raise TypeError("vector must be (*, 3)")
  116. mean = points.mean(-2, True)
  117. points_centered = points - mean
  118. # NOTE: not optimal for 2d points, but for now works for other dimensions
  119. _, _, V = _torch_svd_cast(points_centered)
  120. # the first left eigenvector is the direction on the fited line
  121. direction = V[..., :, -1] # BxD
  122. origin = mean[..., 0, :] # BxD
  123. return Hyperplane.from_vector(Vector3(direction), Vector3(origin))