distort.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Optional
  18. import torch
  19. from kornia.core import cos, ones_like, sin, stack, zeros_like
  20. # Based on https://github.com/opencv/opencv/blob/master/modules/calib3d/src/distortion_model.hpp#L75
  21. def tilt_projection(taux: torch.Tensor, tauy: torch.Tensor, return_inverse: bool = False) -> torch.Tensor:
  22. r"""Estimate the tilt projection matrix or the inverse tilt projection matrix.
  23. Args:
  24. taux: Rotation angle in radians around the :math:`x`-axis with shape :math:`(*, 1)`.
  25. tauy: Rotation angle in radians around the :math:`y`-axis with shape :math:`(*, 1)`.
  26. return_inverse: False to obtain the tilt projection matrix. True for the inverse matrix.
  27. Returns:
  28. torch.Tensor: Inverse tilt projection matrix with shape :math:`(*, 3, 3)`.
  29. """
  30. if taux.shape != tauy.shape:
  31. raise ValueError(f"Shape of taux {taux.shape} and tauy {tauy.shape} do not match.")
  32. ndim: int = taux.dim()
  33. taux = taux.reshape(-1)
  34. tauy = tauy.reshape(-1)
  35. cTx = cos(taux)
  36. sTx = sin(taux)
  37. cTy = cos(tauy)
  38. sTy = sin(tauy)
  39. zero = zeros_like(cTx)
  40. one = ones_like(cTx)
  41. Rx = stack([one, zero, zero, zero, cTx, sTx, zero, -sTx, cTx], -1).reshape(-1, 3, 3)
  42. Ry = stack([cTy, zero, -sTy, zero, one, zero, sTy, zero, cTy], -1).reshape(-1, 3, 3)
  43. R = Ry @ Rx
  44. if return_inverse:
  45. invR22 = 1 / R[..., 2, 2]
  46. invPz = stack(
  47. [invR22, zero, R[..., 0, 2] * invR22, zero, invR22, R[..., 1, 2] * invR22, zero, zero, one], -1
  48. ).reshape(-1, 3, 3)
  49. inv_tilt = R.transpose(-1, -2) @ invPz
  50. if ndim == 0:
  51. inv_tilt = torch.squeeze(inv_tilt)
  52. return inv_tilt
  53. Pz = stack([R[..., 2, 2], zero, -R[..., 0, 2], zero, R[..., 2, 2], -R[..., 1, 2], zero, zero, one], -1).reshape(
  54. -1, 3, 3
  55. )
  56. tilt = Pz @ R.transpose(-1, -2)
  57. if ndim == 0:
  58. tilt = torch.squeeze(tilt)
  59. return tilt
  60. def distort_points(
  61. points: torch.Tensor, K: torch.Tensor, dist: torch.Tensor, new_K: Optional[torch.Tensor] = None
  62. ) -> torch.Tensor:
  63. r"""Distortion of a set of 2D points based on the lens distortion model.
  64. Radial :math:`(k_1, k_2, k_3, k_4, k_4, k_6)`,
  65. tangential :math:`(p_1, p_2)`, thin prism :math:`(s_1, s_2, s_3, s_4)`, and tilt :math:`(\tau_x, \tau_y)`
  66. distortion models are considered in this function.
  67. Args:
  68. points: Input image points with shape :math:`(*, N, 2)`.
  69. K: Intrinsic camera matrix with shape :math:`(*, 3, 3)`.
  70. dist: Distortion coefficients
  71. :math:`(k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])`. This is
  72. a vector with 4, 5, 8, 12 or 14 elements with shape :math:`(*, n)`.
  73. new_K: Intrinsic camera matrix of the distorted image. By default, it is the same as K but you may additionally
  74. scale and shift the result by using a different matrix. Shape: :math:`(*, 3, 3)`. Default: None.
  75. Returns:
  76. Undistorted 2D points with shape :math:`(*, N, 2)`.
  77. Example:
  78. >>> points = torch.rand(1, 1, 2)
  79. >>> K = torch.eye(3)[None]
  80. >>> dist_coeff = torch.rand(1, 4)
  81. >>> points_dist = distort_points(points, K, dist_coeff)
  82. """
  83. if points.dim() < 2 and points.shape[-1] != 2:
  84. raise ValueError(f"points shape is invalid. Got {points.shape}.")
  85. if K.shape[-2:] != (3, 3):
  86. raise ValueError(f"K matrix shape is invalid. Got {K.shape}.")
  87. if new_K is None:
  88. new_K = K
  89. elif new_K.shape[-2:] != (3, 3):
  90. raise ValueError(f"new_K matrix shape is invalid. Got {new_K.shape}.")
  91. if dist.shape[-1] not in [4, 5, 8, 12, 14]:
  92. raise ValueError(f"Invalid number of distortion coefficients. Got {dist.shape[-1]}")
  93. # Adding zeros to obtain vector with 14 coeffs.
  94. if dist.shape[-1] < 14:
  95. dist = torch.nn.functional.pad(dist, [0, 14 - dist.shape[-1]])
  96. # Convert 2D points from pixels to normalized camera coordinates
  97. new_cx: torch.Tensor = new_K[..., 0:1, 2] # princial point in x (Bx1)
  98. new_cy: torch.Tensor = new_K[..., 1:2, 2] # princial point in y (Bx1)
  99. new_fx: torch.Tensor = new_K[..., 0:1, 0] # focal in x (Bx1)
  100. new_fy: torch.Tensor = new_K[..., 1:2, 1] # focal in y (Bx1)
  101. # This is equivalent to K^-1 [u,v,1]^T
  102. x: torch.Tensor = (points[..., 0] - new_cx) / new_fx # (BxN - Bx1)/Bx1 -> BxN or (N,)
  103. y: torch.Tensor = (points[..., 1] - new_cy) / new_fy # (BxN - Bx1)/Bx1 -> BxN or (N,)
  104. # Distort points
  105. r2 = x * x + y * y
  106. rad_poly = (1 + dist[..., 0:1] * r2 + dist[..., 1:2] * r2 * r2 + dist[..., 4:5] * r2**3) / (
  107. 1 + dist[..., 5:6] * r2 + dist[..., 6:7] * r2 * r2 + dist[..., 7:8] * r2**3
  108. )
  109. xd = (
  110. x * rad_poly
  111. + 2 * dist[..., 2:3] * x * y
  112. + dist[..., 3:4] * (r2 + 2 * x * x)
  113. + dist[..., 8:9] * r2
  114. + dist[..., 9:10] * r2 * r2
  115. )
  116. yd = (
  117. y * rad_poly
  118. + dist[..., 2:3] * (r2 + 2 * y * y)
  119. + 2 * dist[..., 3:4] * x * y
  120. + dist[..., 10:11] * r2
  121. + dist[..., 11:12] * r2 * r2
  122. )
  123. # Compensate for tilt distortion
  124. if torch.any(dist[..., 12] != 0) or torch.any(dist[..., 13] != 0):
  125. tilt = tilt_projection(dist[..., 12], dist[..., 13])
  126. # Transposed untilt points (instead of [x,y,1]^T, we obtain [x,y,1])
  127. points_untilt = stack([xd, yd, ones_like(xd)], -1) @ tilt.transpose(-2, -1)
  128. xd = points_untilt[..., 0] / points_untilt[..., 2]
  129. yd = points_untilt[..., 1] / points_untilt[..., 2]
  130. # Convert points from normalized camera coordinates to pixel coordinates
  131. cx: torch.Tensor = K[..., 0:1, 2] # princial point in x (Bx1)
  132. cy: torch.Tensor = K[..., 1:2, 2] # princial point in y (Bx1)
  133. fx: torch.Tensor = K[..., 0:1, 0] # focal in x (Bx1)
  134. fy: torch.Tensor = K[..., 1:2, 1] # focal in y (Bx1)
  135. x = fx * xd + cx
  136. y = fy * yd + cy
  137. return stack([x, y], -1)