undistort.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Optional
  19. import torch
  20. from kornia.core import stack
  21. from kornia.core.check import KORNIA_CHECK_SHAPE
  22. from kornia.geometry.linalg import transform_points
  23. from kornia.geometry.transform import remap
  24. from kornia.utils import create_meshgrid
  25. from .distort import distort_points, tilt_projection
  26. # Based on https://github.com/opencv/opencv/blob/master/modules/calib3d/src/undistort.dispatch.cpp#L384
  27. def undistort_points(
  28. points: torch.Tensor, K: torch.Tensor, dist: torch.Tensor, new_K: Optional[torch.Tensor] = None, num_iters: int = 5
  29. ) -> torch.Tensor:
  30. r"""Compensate for lens distortion a set of 2D image points.
  31. Radial :math:`(k_1, k_2, k_3, k_4, k_5, k_6)`,
  32. tangential :math:`(p_1, p_2)`, thin prism :math:`(s_1, s_2, s_3, s_4)`, and tilt :math:`(\tau_x, \tau_y)`
  33. distortion models are considered in this function.
  34. Args:
  35. points: Input image points with shape :math:`(*, N, 2)`.
  36. K: Intrinsic camera matrix with shape :math:`(*, 3, 3)`.
  37. dist: Distortion coefficients
  38. :math:`(k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])`. This is
  39. a vector with 4, 5, 8, 12 or 14 elements with shape :math:`(*, n)`.
  40. new_K: Intrinsic camera matrix of the distorted image. By default, it is the same as K but you may additionally
  41. scale and shift the result by using a different matrix. Shape: :math:`(*, 3, 3)`. Default: None.
  42. num_iters: Number of undistortion iterations. Default: 5.
  43. Returns:
  44. Undistorted 2D points with shape :math:`(*, N, 2)`.
  45. Example:
  46. >>> _ = torch.manual_seed(0)
  47. >>> x = torch.rand(1, 4, 2)
  48. >>> K = torch.eye(3)[None]
  49. >>> dist = torch.rand(1, 4)
  50. >>> undistort_points(x, K, dist)
  51. tensor([[[-0.1513, -0.1165],
  52. [ 0.0711, 0.1100],
  53. [-0.0697, 0.0228],
  54. [-0.1843, -0.1606]]])
  55. """
  56. KORNIA_CHECK_SHAPE(points, ["*", "N", "2"])
  57. KORNIA_CHECK_SHAPE(K, ["*", "3", "3"])
  58. if points.dim() < 2 and points.shape[-1] != 2:
  59. raise ValueError(f"points shape is invalid. Got {points.shape}.")
  60. if new_K is None:
  61. new_K = K
  62. else:
  63. KORNIA_CHECK_SHAPE(new_K, ["*", "3", "3"])
  64. if dist.shape[-1] not in [4, 5, 8, 12, 14]:
  65. raise ValueError(f"Invalid number of distortion coefficients. Got {dist.shape[-1]}")
  66. # Adding zeros to obtain vector with 14 coeffs.
  67. if dist.shape[-1] < 14:
  68. dist = torch.nn.functional.pad(dist, [0, 14 - dist.shape[-1]])
  69. # Convert 2D points from pixels to normalized camera coordinates
  70. cx: torch.Tensor = K[..., 0:1, 2] # princial point in x (Bx1)
  71. cy: torch.Tensor = K[..., 1:2, 2] # princial point in y (Bx1)
  72. fx: torch.Tensor = K[..., 0:1, 0] # focal in x (Bx1)
  73. fy: torch.Tensor = K[..., 1:2, 1] # focal in y (Bx1)
  74. # This is equivalent to K^-1 [u,v,1]^T
  75. x: torch.Tensor = (points[..., 0] - cx) / fx # (BxN - Bx1)/Bx1 -> BxN
  76. y: torch.Tensor = (points[..., 1] - cy) / fy # (BxN - Bx1)/Bx1 -> BxN
  77. # Compensate for tilt distortion
  78. if torch.any(dist[..., 12] != 0) or torch.any(dist[..., 13] != 0):
  79. inv_tilt = tilt_projection(dist[..., 12], dist[..., 13], True)
  80. # Transposed untilt points (instead of [x,y,1]^T, we obtain [x,y,1])
  81. x, y = transform_points(inv_tilt, stack([x, y], dim=-1)).unbind(-1)
  82. # Iteratively undistort points
  83. x0, y0 = x, y
  84. for _ in range(num_iters):
  85. r2 = x * x + y * y
  86. inv_rad_poly = (1 + dist[..., 5:6] * r2 + dist[..., 6:7] * r2 * r2 + dist[..., 7:8] * r2**3) / (
  87. 1 + dist[..., 0:1] * r2 + dist[..., 1:2] * r2 * r2 + dist[..., 4:5] * r2**3
  88. )
  89. deltaX = (
  90. 2 * dist[..., 2:3] * x * y
  91. + dist[..., 3:4] * (r2 + 2 * x * x)
  92. + dist[..., 8:9] * r2
  93. + dist[..., 9:10] * r2 * r2
  94. )
  95. deltaY = (
  96. dist[..., 2:3] * (r2 + 2 * y * y)
  97. + 2 * dist[..., 3:4] * x * y
  98. + dist[..., 10:11] * r2
  99. + dist[..., 11:12] * r2 * r2
  100. )
  101. x = (x0 - deltaX) * inv_rad_poly
  102. y = (y0 - deltaY) * inv_rad_poly
  103. # Convert points from normalized camera coordinates to pixel coordinates
  104. new_cx: torch.Tensor = new_K[..., 0:1, 2] # princial point in x (Bx1)
  105. new_cy: torch.Tensor = new_K[..., 1:2, 2] # princial point in y (Bx1)
  106. new_fx: torch.Tensor = new_K[..., 0:1, 0] # focal in x (Bx1)
  107. new_fy: torch.Tensor = new_K[..., 1:2, 1] # focal in y (Bx1)
  108. x = new_fx * x + new_cx
  109. y = new_fy * y + new_cy
  110. return stack([x, y], -1)
  111. # Based on https://github.com/opencv/opencv/blob/master/modules/calib3d/src/undistort.dispatch.cpp#L287
  112. def undistort_image(image: torch.Tensor, K: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
  113. r"""Compensate an image for lens distortion.
  114. Radial :math:`(k_1, k_2, k_3, k_4, k_4, k_6)`,
  115. tangential :math:`(p_1, p_2)`, thin prism :math:`(s_1, s_2, s_3, s_4)`, and tilt :math:`(\tau_x, \tau_y)`
  116. distortion models are considered in this function.
  117. Args:
  118. image: Input image with shape :math:`(*, C, H, W)`.
  119. K: Intrinsic camera matrix with shape :math:`(*, 3, 3)`.
  120. dist: Distortion coefficients
  121. :math:`(k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])`. This is
  122. a vector with 4, 5, 8, 12 or 14 elements with shape :math:`(*, n)`.
  123. Returns:
  124. Undistorted image with shape :math:`(*, C, H, W)`.
  125. Example:
  126. >>> img = torch.rand(1, 3, 5, 5)
  127. >>> K = torch.eye(3)[None]
  128. >>> dist_coeff = torch.rand(1, 4)
  129. >>> out = undistort_image(img, K, dist_coeff)
  130. >>> out.shape
  131. torch.Size([1, 3, 5, 5])
  132. """
  133. if len(image.shape) < 3:
  134. raise ValueError(f"Image shape is invalid. Got: {image.shape}.")
  135. if K.shape[-2:] != (3, 3):
  136. raise ValueError(f"K matrix shape is invalid. Got {K.shape}.")
  137. if dist.shape[-1] not in [4, 5, 8, 12, 14]:
  138. raise ValueError(f"Invalid number of distortion coefficients. Got {dist.shape[-1]}.")
  139. if not image.is_floating_point():
  140. raise ValueError(f"Invalid input image data type. Input should be float. Got {image.dtype}.")
  141. if image.shape[:-3] != K.shape[:-2] or image.shape[:-3] != dist.shape[:-1]:
  142. # Input with image shape (1, C, H, W), K shape (3, 3), dist shape (4)
  143. # allowed to avoid a breaking change.
  144. if not all((image.shape[:-3] == (1,), K.shape[:-2] == (), dist.shape[:-1] == ())):
  145. raise ValueError(
  146. "Input shape is invalid. Input batch dimensions should match. "
  147. f"Got {image.shape[:-3]}, {K.shape[:-2]}, {dist.shape[:-1]}."
  148. )
  149. channels, rows, cols = image.shape[-3:]
  150. B = image.numel() // (channels * rows * cols)
  151. # Create point coordinates for each pixel of the image
  152. xy_grid: torch.Tensor = create_meshgrid(rows, cols, False, image.device, image.dtype)
  153. pts = xy_grid.reshape(-1, 2) # (rows*cols)x2 matrix of pixel coordinates
  154. # Distort points and define maps
  155. ptsd: torch.Tensor = distort_points(pts, K, dist) # Bx(rows*cols)x2
  156. mapx: torch.Tensor = ptsd[..., 0].reshape(B, rows, cols) # B x rows x cols, float
  157. mapy: torch.Tensor = ptsd[..., 1].reshape(B, rows, cols) # B x rows x cols, float
  158. # Remap image to undistort
  159. out = remap(image.reshape(B, channels, rows, cols), mapx, mapy, align_corners=True)
  160. return out.view_as(image)