numeric.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. """Module containing numerical functionalities for SfM."""
  18. import torch
  19. from kornia.core import stack, zeros_like
  20. def cross_product_matrix(x: torch.Tensor) -> torch.Tensor:
  21. r"""Return the cross_product_matrix symmetric matrix of a vector.
  22. Args:
  23. x: The input vector to construct the matrix in the shape :math:`(*, 3)`.
  24. Returns:
  25. The constructed cross_product_matrix symmetric matrix with shape :math:`(*, 3, 3)`.
  26. """
  27. if not x.shape[-1] == 3:
  28. raise AssertionError(x.shape)
  29. # get vector compononens
  30. x0 = x[..., 0]
  31. x1 = x[..., 1]
  32. x2 = x[..., 2]
  33. # construct the matrix, reshape to 3x3 and return
  34. zeros = zeros_like(x0)
  35. cross_product_matrix_flat = stack([zeros, -x2, x1, x2, zeros, -x0, -x1, x0, zeros], dim=-1)
  36. shape_ = x.shape[:-1] + (3, 3)
  37. return cross_product_matrix_flat.view(*shape_)
  38. def matrix_cofactor_tensor(matrix: torch.Tensor) -> torch.Tensor:
  39. """Cofactor matrix, refer to the numpy doc.
  40. Args:
  41. matrix: The input matrix in the shape :math:`(*, 3, 3)`.
  42. """
  43. det = torch.det(matrix)
  44. singular_mask = det != 0
  45. if singular_mask.sum() != 0:
  46. # B, 3, 3
  47. cofactor = torch.linalg.inv(matrix[singular_mask]).transpose(-2, -1) * det[:, None, None]
  48. # return cofactor matrix of the given matrix
  49. returned_cofactor = torch.zeros_like(matrix)
  50. returned_cofactor[singular_mask] = cofactor
  51. return returned_cofactor
  52. else:
  53. raise Exception("all singular matrices")