imgwarp.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Optional
  19. import torch
  20. import torch.nn.functional as F
  21. from kornia.core import Tensor, concatenate, ones, ones_like, stack, tan, tensor, zeros
  22. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE
  23. from kornia.geometry.conversions import (
  24. angle_to_rotation_matrix,
  25. axis_angle_to_rotation_matrix,
  26. convert_affinematrix_to_homography,
  27. convert_affinematrix_to_homography3d,
  28. deg2rad,
  29. normalize_homography,
  30. normalize_homography3d,
  31. normalize_pixel_coordinates,
  32. )
  33. from kornia.geometry.linalg import transform_points
  34. from kornia.utils import create_meshgrid, create_meshgrid3d, eye_like
  35. from kornia.utils.helpers import _torch_inverse_cast, _torch_solve_cast
# Public API of this module: the names exported by ``from <module> import *``.
# Kept in alphabetical order.
__all__ = [
    "get_affine_matrix2d",
    "get_affine_matrix3d",
    "get_perspective_transform",
    "get_perspective_transform3d",
    "get_projective_transform",
    "get_rotation_matrix2d",
    "get_shear_matrix2d",
    "get_shear_matrix3d",
    "get_translation_matrix2d",
    "homography_warp",
    "homography_warp3d",
    "invert_affine_transform",
    "projection_from_Rt",
    "remap",
    "warp_affine",
    "warp_affine3d",
    "warp_grid",
    "warp_grid3d",
    "warp_perspective",
    "warp_perspective3d",
]
  58. def warp_perspective(
  59. src: Tensor,
  60. M: Tensor,
  61. dsize: tuple[int, int],
  62. mode: str = "bilinear",
  63. padding_mode: str = "zeros",
  64. align_corners: bool = True,
  65. fill_value: Optional[Tensor] = None, # needed for jit
  66. ) -> Tensor:
  67. r"""Apply a perspective transformation to an image.
  68. The function warp_perspective transforms the source image using
  69. the specified matrix:
  70. .. math::
  71. \text{dst} (x, y) = \text{src} \left(
  72. \frac{M^{-1}_{11} x + M^{-1}_{12} y + M^{-1}_{13}}{M^{-1}_{31} x + M^{-1}_{32} y + M^{-1}_{33}} ,
  73. \frac{M^{-1}_{21} x + M^{-1}_{22} y + M^{-1}_{23}}{M^{-1}_{31} x + M^{-1}_{32} y + M^{-1}_{33}}
  74. \right )
  75. Args:
  76. src: input image with shape :math:`(B, C, H, W)`.
  77. M: transformation matrix with shape :math:`(B, 3, 3)`.
  78. dsize: size of the output image (height, width).
  79. mode: interpolation mode to calculate output values ``'bilinear'`` | ``'nearest'``.
  80. padding_mode: padding mode for outside grid values ``'zeros'`` | ``'border'`` | ``'reflection'`` | ``'fill'``.
  81. align_corners: interpolation flag.
  82. fill_value: tensor of shape :math:`(3)` that fills the padding area. Only supported for RGB.
  83. Returns:
  84. the warped input image :math:`(B, C, H, W)`.
  85. Example:
  86. >>> img = torch.rand(1, 4, 5, 6)
  87. >>> H = torch.eye(3)[None]
  88. >>> out = warp_perspective(img, H, (4, 2), align_corners=True)
  89. >>> print(out.shape)
  90. torch.Size([1, 4, 4, 2])
  91. .. note::
  92. This function is often used in conjunction with :func:`get_perspective_transform`.
  93. .. note::
  94. See a working example `here <https://kornia.github.io/tutorials/nbs/warp_perspective.html>`_.
  95. """
  96. if not isinstance(src, Tensor):
  97. raise TypeError(f"Input src type is not a Tensor. Got {type(src)}")
  98. if not isinstance(M, Tensor):
  99. raise TypeError(f"Input M type is not a Tensor. Got {type(M)}")
  100. if not len(src.shape) == 4:
  101. raise ValueError(f"Input src must be a BxCxHxW tensor. Got {src.shape}")
  102. if not (len(M.shape) == 3 and M.shape[-2:] == (3, 3)):
  103. raise ValueError(f"Input M must be a Bx3x3 tensor. Got {M.shape}")
  104. # fill padding is only supported for 3 channels because we can't set fill_value default
  105. # to None as this gives jit issues.
  106. if fill_value is None:
  107. fill_value = zeros(3)
  108. if padding_mode == "fill" and fill_value.shape != torch.Size([3]):
  109. raise ValueError(f"Padding_tensor only supported for 3 channels. Got {fill_value.shape}")
  110. B, _, H, W = src.size()
  111. h_out, w_out = dsize
  112. # we normalize the 3x3 transformation matrix and convert to 3x4
  113. dst_norm_trans_src_norm: Tensor = normalize_homography(M, (H, W), (h_out, w_out)) # Bx3x3
  114. src_norm_trans_dst_norm = _torch_inverse_cast(dst_norm_trans_src_norm) # Bx3x3
  115. # this piece of code substitutes F.affine_grid since it does not support 3x3
  116. grid = (
  117. create_meshgrid(h_out, w_out, normalized_coordinates=True, device=src.device)
  118. .to(src.dtype)
  119. .expand(B, h_out, w_out, 2)
  120. )
  121. grid = transform_points(src_norm_trans_dst_norm[:, None, None], grid)
  122. if padding_mode == "fill":
  123. return _fill_and_warp(src, grid, align_corners=align_corners, mode=mode, fill_value=fill_value)
  124. return F.grid_sample(src, grid, align_corners=align_corners, mode=mode, padding_mode=padding_mode)
  125. def warp_affine(
  126. src: Tensor,
  127. M: Tensor,
  128. dsize: tuple[int, int],
  129. mode: str = "bilinear",
  130. padding_mode: str = "zeros",
  131. align_corners: bool = True,
  132. fill_value: Optional[Tensor] = None, # needed for jit
  133. ) -> Tensor:
  134. r"""Apply an affine transformation to a tensor.
  135. .. image:: _static/img/warp_affine.png
  136. The function warp_affine transforms the source tensor using
  137. the specified matrix:
  138. .. math::
  139. \text{dst}(x, y) = \text{src} \left( M_{11} x + M_{12} y + M_{13} ,
  140. M_{21} x + M_{22} y + M_{23} \right )
  141. Args:
  142. src: input tensor of shape :math:`(B, C, H, W)`.
  143. M: affine transformation of shape :math:`(B, 2, 3)`.
  144. dsize: size of the output image (height, width).
  145. mode: interpolation mode to calculate output values ``'bilinear'`` | ``'nearest'``.
  146. padding_mode: padding mode for outside grid values ``'zeros'`` | ``'border'`` | ``'reflection'`` | ``'fill'``.
  147. align_corners : mode for grid_generation.
  148. fill_value: tensor of shape :math:`(3)` that fills the padding area. Only supported for RGB.
  149. Returns:
  150. the warped tensor with shape :math:`(B, C, H, W)`.
  151. .. note::
  152. This function is often used in conjunction with :func:`get_rotation_matrix2d`,
  153. :func:`get_shear_matrix2d`, :func:`get_affine_matrix2d`, :func:`invert_affine_transform`.
  154. .. note::
  155. See a working example `here <https://kornia.github.io/tutorials/nbs/rotate_affine.html>`__.
  156. Example:
  157. >>> img = torch.rand(1, 4, 5, 6)
  158. >>> A = torch.eye(2, 3)[None]
  159. >>> out = warp_affine(img, A, (4, 2), align_corners=True)
  160. >>> print(out.shape)
  161. torch.Size([1, 4, 4, 2])
  162. """
  163. if not isinstance(src, Tensor):
  164. raise TypeError(f"Input src type is not a Tensor. Got {type(src)}")
  165. if not isinstance(M, Tensor):
  166. raise TypeError(f"Input M type is not a Tensor. Got {type(M)}")
  167. if not len(src.shape) == 4:
  168. raise ValueError(f"Input src must be a BxCxHxW tensor. Got {src.shape}")
  169. if not (len(M.shape) == 3 or M.shape[-2:] == (2, 3)):
  170. raise ValueError(f"Input M must be a Bx2x3 tensor. Got {M.shape}")
  171. B, C, H, W = src.size()
  172. # we generate a 3x3 transformation matrix from 2x3 affine
  173. M_3x3: Tensor = convert_affinematrix_to_homography(M)
  174. dst_norm_trans_src_norm: Tensor = normalize_homography(M_3x3, (H, W), dsize)
  175. # src_norm_trans_dst_norm = torch.inverse(dst_norm_trans_src_norm)
  176. src_norm_trans_dst_norm = _torch_inverse_cast(dst_norm_trans_src_norm)
  177. grid = F.affine_grid(src_norm_trans_dst_norm[:, :2, :], [B, C, dsize[0], dsize[1]], align_corners=align_corners)
  178. if padding_mode == "fill":
  179. if fill_value is None:
  180. fill_value = zeros(3)
  181. return _fill_and_warp(src, grid, align_corners=align_corners, mode=mode, fill_value=fill_value)
  182. return F.grid_sample(src, grid, align_corners=align_corners, mode=mode, padding_mode=padding_mode)
  183. def _fill_and_warp(src: Tensor, grid: Tensor, mode: str, align_corners: bool, fill_value: Tensor) -> Tensor:
  184. r"""Warp a mask of ones, then multiple with fill_value and add to default warp.
  185. Args:
  186. src: input tensor of shape :math:`(B, 3, H, W)`.
  187. grid: grid tensor from `transform_points`.
  188. mode: interpolation mode to calculate output values ``'bilinear'`` | ``'nearest'``.
  189. align_corners: interpolation flag.
  190. fill_value: tensor of shape :math:`(3)` that fills the padding area. Only supported for RGB.
  191. Returns:
  192. the warped and filled tensor with shape :math:`(B, 3, H, W)`.
  193. """
  194. ones_mask = ones_like(src)
  195. fill_value = fill_value.to(ones_mask)[None, :, None, None] # cast and add dimensions for broadcasting
  196. inv_ones_mask = 1 - F.grid_sample(ones_mask, grid, align_corners=align_corners, mode=mode, padding_mode="zeros")
  197. inv_color_mask = inv_ones_mask * fill_value
  198. return F.grid_sample(src, grid, align_corners=align_corners, mode=mode, padding_mode="zeros") + inv_color_mask
  199. def warp_grid(grid: Tensor, src_homo_dst: Tensor) -> Tensor:
  200. r"""Compute the grid to warp the coordinates grid by the homography/ies.
  201. Args:
  202. grid: Unwrapped grid of the shape :math:`(1, H, W, 2)`.
  203. src_homo_dst: Homography or homographies (stacked) to
  204. transform all points in the grid. Shape of the homography
  205. has to be :math:`(1, 3, 3)` or :math:`(N, 1, 3, 3)`.
  206. Returns:
  207. the transformed grid of shape :math:`(N, H, W, 2)`.
  208. """
  209. batch_size: int = src_homo_dst.size(0)
  210. _, height, width, _ = grid.size()
  211. # expand grid to match the input batch size
  212. grid = grid.expand(batch_size, -1, -1, -1) # NxHxWx2
  213. if len(src_homo_dst.shape) == 3: # local homography case
  214. src_homo_dst = src_homo_dst.view(batch_size, 1, 3, 3) # Nx1x3x3
  215. # perform the actual grid transformation,
  216. # the grid is copied to input device and casted to the same type
  217. flow: Tensor = transform_points(src_homo_dst, grid.to(src_homo_dst)) # NxHxWx2
  218. return flow.view(batch_size, height, width, 2) # NxHxWx2
  219. def warp_grid3d(grid: Tensor, src_homo_dst: Tensor) -> Tensor:
  220. r"""Compute the grid to warp the coordinates grid by the homography/ies.
  221. Args:
  222. grid: Unwrapped grid of the shape :math:`(1, D, H, W, 3)`.
  223. src_homo_dst: Homography or homographies (stacked) to
  224. transform all points in the grid. Shape of the homography
  225. has to be :math:`(1, 4, 4)` or :math:`(N, 1, 4, 4)`.
  226. Returns:
  227. the transformed grid of shape :math:`(N, H, W, 3)`.
  228. """
  229. batch_size: int = src_homo_dst.size(0)
  230. _, depth, height, width, _ = grid.size()
  231. # expand grid to match the input batch size
  232. grid = grid.expand(batch_size, -1, -1, -1, -1) # NxDxHxWx3
  233. if len(src_homo_dst.shape) == 3: # local homography case
  234. src_homo_dst = src_homo_dst.view(batch_size, 1, 4, 4) # Nx1x3x3
  235. # perform the actual grid transformation,
  236. # the grid is copied to input device and casted to the same type
  237. flow: Tensor = transform_points(src_homo_dst, grid.to(src_homo_dst)) # NxDxHxWx3
  238. return flow.view(batch_size, depth, height, width, 3) # NxDxHxWx3
  239. # TODO: move to kornia.geometry.projective
  240. # TODO: create the nn.Module -- TBD what inputs/outputs etc
  241. # class PerspectiveTransform(nn.Module):
  242. # def __init__(self) -> None:
  243. # super().__init__()
  244. def get_perspective_transform(points_src: Tensor, points_dst: Tensor) -> Tensor:
  245. r"""Calculate a perspective transform from four pairs of the corresponding points.
  246. The algorithm is a vanilla implementation of the Direct Linear transform (DLT).
  247. See more: https://www.cs.cmu.edu/~16385/s17/Slides/10.2_2D_Alignment__DLT.pdf
  248. The function calculates the matrix of a perspective transform that maps from
  249. the source to destination points:
  250. .. math::
  251. \begin{bmatrix}
  252. x^{'} \\
  253. y^{'} \\
  254. 1 \\
  255. \end{bmatrix}
  256. =
  257. \begin{bmatrix}
  258. h_1 & h_2 & h_3 \\
  259. h_4 & h_5 & h_6 \\
  260. h_7 & h_8 & h_9 \\
  261. \end{bmatrix}
  262. \cdot
  263. \begin{bmatrix}
  264. x \\
  265. y \\
  266. 1 \\
  267. \end{bmatrix}
  268. Args:
  269. points_src: coordinates of quadrangle vertices in the source image with shape :math:`(B, 4, 2)`.
  270. points_dst: coordinates of the corresponding quadrangle vertices in
  271. the destination image with shape :math:`(B, 4, 2)`.
  272. Returns:
  273. the perspective transformation with shape :math:`(B, 3, 3)`.
  274. .. note::
  275. This function is often used in conjunction with :func:`warp_perspective`.
  276. Example:
  277. >>> x1 = torch.tensor([[[0., 0.], [1., 0.], [1., 1.], [0., 1.]]])
  278. >>> x2 = torch.tensor([[[1., 0.], [0., 0.], [0., 1.], [1., 1.]]])
  279. >>> x2_trans_x1 = get_perspective_transform(x1, x2)
  280. """
  281. KORNIA_CHECK_SHAPE(points_src, ["B", "4", "2"])
  282. KORNIA_CHECK_SHAPE(points_dst, ["B", "4", "2"])
  283. KORNIA_CHECK(points_src.shape == points_dst.shape, "Source data shape must match Destination data shape.")
  284. KORNIA_CHECK(points_src.dtype == points_dst.dtype, "Source data type must match Destination data type.")
  285. # we build matrix A by using only 4 point correspondence. The linear
  286. # system is solved with the least square method, so here
  287. # we could even pass more correspondence
  288. # create the lhs tensor with shape # Bx8x8
  289. B: int = points_src.shape[0] # batch_size
  290. A = torch.empty(B, 8, 8, device=points_src.device, dtype=points_src.dtype)
  291. # we need to perform in batch
  292. _zeros = zeros(B, device=points_src.device, dtype=points_src.dtype)
  293. _ones = ones(B, device=points_src.device, dtype=points_src.dtype)
  294. for i in range(4):
  295. x1, y1 = points_src[..., i, 0], points_src[..., i, 1] # Bx4
  296. x2, y2 = points_dst[..., i, 0], points_dst[..., i, 1] # Bx4
  297. A[:, 2 * i] = stack([x1, y1, _ones, _zeros, _zeros, _zeros, -x1 * x2, -y1 * x2], -1)
  298. A[:, 2 * i + 1] = stack([_zeros, _zeros, _zeros, x1, y1, _ones, -x1 * y2, -y1 * y2], -1)
  299. # the rhs tensor
  300. b = points_dst.view(-1, 8, 1)
  301. # solve the system Ax = b
  302. X: Tensor = _torch_solve_cast(A, b)
  303. # create variable to return the Bx3x3 transform
  304. M = torch.empty(B, 9, device=points_src.device, dtype=points_src.dtype)
  305. M[..., :8] = X[..., 0] # Bx8
  306. M[..., -1].fill_(1)
  307. return M.view(-1, 3, 3) # Bx3x3
  308. # TODO: move to kornia.geometry.affine
  309. def get_rotation_matrix2d(center: Tensor, angle: Tensor, scale: Tensor) -> Tensor:
  310. r"""Calculate an affine matrix of 2D rotation.
  311. The function calculates the following matrix:
  312. .. math::
  313. \begin{bmatrix}
  314. \alpha & \beta & (1 - \alpha) \cdot \text{x}
  315. - \beta \cdot \text{y} \\
  316. -\beta & \alpha & \beta \cdot \text{x}
  317. + (1 - \alpha) \cdot \text{y}
  318. \end{bmatrix}
  319. where
  320. .. math::
  321. \alpha = \text{scale} \cdot cos(\text{angle}) \\
  322. \beta = \text{scale} \cdot sin(\text{angle})
  323. The transformation maps the rotation center to itself
  324. If this is not the target, adjust the shift.
  325. Args:
  326. center: center of the rotation in the source image with shape :math:`(B, 2)`.
  327. angle: rotation angle in degrees. Positive values mean
  328. counter-clockwise rotation (the coordinate origin is assumed to
  329. be the top-left corner) with shape :math:`(B)`.
  330. scale: scale factor for x, y scaling with shape :math:`(B, 2)`.
  331. Returns:
  332. the affine matrix of 2D rotation with shape :math:`(B, 2, 3)`.
  333. Example:
  334. >>> center = zeros(1, 2)
  335. >>> scale = torch.ones((1, 2))
  336. >>> angle = 45. * torch.ones(1)
  337. >>> get_rotation_matrix2d(center, angle, scale)
  338. tensor([[[ 0.7071, 0.7071, 0.0000],
  339. [-0.7071, 0.7071, 0.0000]]])
  340. .. note::
  341. This function is often used in conjunction with :func:`warp_affine`.
  342. """
  343. if not isinstance(center, Tensor):
  344. raise TypeError(f"Input center type is not a Tensor. Got {type(center)}")
  345. if not isinstance(angle, Tensor):
  346. raise TypeError(f"Input angle type is not a Tensor. Got {type(angle)}")
  347. if not isinstance(scale, Tensor):
  348. raise TypeError(f"Input scale type is not a Tensor. Got {type(scale)}")
  349. if not (len(center.shape) == 2 and center.shape[1] == 2):
  350. raise ValueError(f"Input center must be a Bx2 tensor. Got {center.shape}")
  351. if not len(angle.shape) == 1:
  352. raise ValueError(f"Input angle must be a B tensor. Got {angle.shape}")
  353. if not (len(scale.shape) == 2 and scale.shape[1] == 2):
  354. raise ValueError(f"Input scale must be a Bx2 tensor. Got {scale.shape}")
  355. if not (center.shape[0] == angle.shape[0] == scale.shape[0]):
  356. raise ValueError(
  357. f"Inputs must have same batch size dimension. Got center {center.shape}, angle {angle.shape} and scale "
  358. f"{scale.shape}"
  359. )
  360. if not (center.device == angle.device == scale.device) or not (center.dtype == angle.dtype == scale.dtype):
  361. raise ValueError(
  362. f"Inputs must have same device Got center ({center.device}, {center.dtype}), angle ({angle.device}, "
  363. f"{angle.dtype}) and scale ({scale.device}, {scale.dtype})"
  364. )
  365. shift_m = eye_like(3, center)
  366. shift_m[:, :2, 2] = center
  367. shift_m_inv = eye_like(3, center)
  368. shift_m_inv[:, :2, 2] = -center
  369. scale_m = eye_like(3, center)
  370. scale_m[:, 0, 0] *= scale[:, 0]
  371. scale_m[:, 1, 1] *= scale[:, 1]
  372. rotat_m = eye_like(3, center)
  373. rotat_m[:, :2, :2] = angle_to_rotation_matrix(angle)
  374. affine_m = shift_m @ rotat_m @ scale_m @ shift_m_inv
  375. return affine_m[:, :2, :] # Bx2x3
  376. def remap(
  377. image: Tensor,
  378. map_x: Tensor,
  379. map_y: Tensor,
  380. mode: str = "bilinear",
  381. padding_mode: str = "zeros",
  382. align_corners: Optional[bool] = None,
  383. normalized_coordinates: bool = False,
  384. ) -> Tensor:
  385. r"""Apply a generic geometrical transformation to an image tensor.
  386. .. image:: _static/img/remap.png
  387. The function remap transforms the source tensor using the specified map:
  388. .. math::
  389. \text{dst}(x, y) = \text{src}(map_x(x, y), map_y(x, y))
  390. Args:
  391. image: the tensor to remap with shape (B, C, H, W).
  392. Where C is the number of channels.
  393. map_x: the flow in the x-direction in pixel coordinates.
  394. The tensor must be in the shape of (B, H, W).
  395. map_y: the flow in the y-direction in pixel coordinates.
  396. The tensor must be in the shape of (B, H, W).
  397. mode: interpolation mode to calculate output values
  398. ``'bilinear'`` | ``'nearest'``.
  399. padding_mode: padding mode for outside grid values
  400. ``'zeros'`` | ``'border'`` | ``'reflection'``.
  401. align_corners: mode for grid_generation.
  402. normalized_coordinates: whether the input coordinates are
  403. normalized in the range of [-1, 1].
  404. Returns:
  405. the warped tensor with same shape as the input grid maps.
  406. Example:
  407. >>> import torch
  408. >>> from kornia.utils import create_meshgrid
  409. >>> grid = create_meshgrid(2, 2, False) # 1x2x2x2
  410. >>> grid += 1 # apply offset in both directions
  411. >>> input = torch.ones(1, 1, 2, 2)
  412. >>> remap(input, grid[..., 0], grid[..., 1], align_corners=True) # 1x1x2x2
  413. tensor([[[[1., 0.],
  414. [0., 0.]]]])
  415. .. note::
  416. This function is often used in conjunction with :func:`kornia.utils.create_meshgrid`.
  417. """
  418. KORNIA_CHECK_SHAPE(image, ["B", "C", "H", "W"])
  419. KORNIA_CHECK_SHAPE(map_x, ["B", "H", "W"])
  420. KORNIA_CHECK_SHAPE(map_y, ["B", "H", "W"])
  421. batch_size, _, height, width = image.shape
  422. # grid_sample need the grid between -1/1
  423. map_xy: Tensor = stack([map_x, map_y], -1)
  424. # normalize coordinates if not already normalized
  425. if not normalized_coordinates:
  426. map_xy = normalize_pixel_coordinates(map_xy, height, width)
  427. # simulate broadcasting since grid_sample does not support it
  428. map_xy = map_xy.expand(batch_size, -1, -1, -1)
  429. # warp the image tensor and return
  430. return F.grid_sample(image, map_xy, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
  431. def invert_affine_transform(matrix: Tensor) -> Tensor:
  432. r"""Invert an affine transformation.
  433. The function computes an inverse affine transformation represented by
  434. 2x3 matrix:
  435. .. math::
  436. \begin{bmatrix}
  437. a_{11} & a_{12} & b_{1} \\
  438. a_{21} & a_{22} & b_{2} \\
  439. \end{bmatrix}
  440. The result is also a 2x3 matrix of the same type as M.
  441. Args:
  442. matrix: original affine transform. The tensor must be
  443. in the shape of :math:`(B, 2, 3)`.
  444. Return:
  445. the reverse affine transform with shape :math:`(B, 2, 3)`.
  446. .. note::
  447. This function is often used in conjunction with :func:`warp_affine`.
  448. """
  449. if not isinstance(matrix, Tensor):
  450. raise TypeError(f"Input matrix type is not a Tensor. Got {type(matrix)}")
  451. if not (len(matrix.shape) == 3 and matrix.shape[-2:] == (2, 3)):
  452. raise ValueError(f"Input matrix must be a Bx2x3 tensor. Got {matrix.shape}")
  453. matrix_tmp: Tensor = convert_affinematrix_to_homography(matrix)
  454. matrix_inv: Tensor = _torch_inverse_cast(matrix_tmp)
  455. return matrix_inv[..., :2, :3]
  456. def get_affine_matrix2d(
  457. translations: Tensor,
  458. center: Tensor,
  459. scale: Tensor,
  460. angle: Tensor,
  461. sx: Optional[Tensor] = None,
  462. sy: Optional[Tensor] = None,
  463. ) -> Tensor:
  464. r"""Compose affine matrix from the components.
  465. Args:
  466. translations: tensor containing the translation vector with shape :math:`(B, 2)`.
  467. center: tensor containing the center vector with shape :math:`(B, 2)`.
  468. scale: tensor containing the scale factor with shape :math:`(B, 2)`.
  469. angle: tensor of angles in degrees :math:`(B)`.
  470. sx: tensor containing the shear factor in the x-direction with shape :math:`(B)`.
  471. sy: tensor containing the shear factor in the y-direction with shape :math:`(B)`.
  472. Returns:
  473. the affine transformation matrix :math:`(B, 3, 3)`.
  474. .. note::
  475. This function is often used in conjunction with :func:`warp_affine`, :func:`warp_perspective`.
  476. """
  477. transform: Tensor = get_rotation_matrix2d(center, -angle, scale)
  478. transform[..., 2] += translations # tx/ty
  479. # pad transform to get Bx3x3
  480. transform_h = convert_affinematrix_to_homography(transform)
  481. if any(s is not None for s in [sx, sy]):
  482. shear_mat = get_shear_matrix2d(center, sx, sy)
  483. transform_h = transform_h @ shear_mat
  484. return transform_h
  485. def get_translation_matrix2d(translations: Tensor) -> Tensor:
  486. r"""Compose translation matrix from the components.
  487. Args:
  488. translations: tensor containing the translation vector with shape :math:`(B, 2)`.
  489. Returns:
  490. the affine transformation matrix :math:`(B, 3, 3)`.
  491. .. note::
  492. This function is often used in conjunction with :func:`warp_affine`, :func:`warp_perspective`.
  493. """
  494. transform: Tensor = eye_like(3, translations)[:, :2, :]
  495. transform[..., 2] += translations # tx/ty
  496. # pad transform to get Bx3x3
  497. transform_h = convert_affinematrix_to_homography(transform)
  498. return transform_h
  499. def get_shear_matrix2d(center: Tensor, sx: Optional[Tensor] = None, sy: Optional[Tensor] = None) -> Tensor:
  500. r"""Compose shear matrix Bx4x4 from the components.
  501. Note: Ordered shearing, shear x-axis then y-axis.
  502. .. math::
  503. \begin{bmatrix}
  504. 1 & b \\
  505. a & ab + 1 \\
  506. \end{bmatrix}
  507. Args:
  508. center: shearing center coordinates of (x, y).
  509. sx: shearing angle along x axis in radiants.
  510. sy: shearing angle along y axis in radiants
  511. Returns:
  512. params to be passed to the affine transformation with shape :math:`(B, 3, 3)`.
  513. Examples:
  514. >>> rng = torch.manual_seed(0)
  515. >>> sx = torch.randn(1)
  516. >>> sx
  517. tensor([1.5410])
  518. >>> center = torch.tensor([[0., 0.]]) # Bx2
  519. >>> get_shear_matrix2d(center, sx=sx)
  520. tensor([[[ 1.0000, -33.5468, 0.0000],
  521. [ -0.0000, 1.0000, 0.0000],
  522. [ 0.0000, 0.0000, 1.0000]]])
  523. .. note::
  524. This function is often used in conjunction with :func:`warp_affine`, :func:`warp_perspective`.
  525. """
  526. sx = tensor([0.0]).repeat(center.size(0)) if sx is None else sx
  527. sy = tensor([0.0]).repeat(center.size(0)) if sy is None else sy
  528. x, y = torch.split(center, 1, dim=-1)
  529. x, y = x.view(-1), y.view(-1)
  530. sx_tan = tan(sx)
  531. sy_tan = tan(sy)
  532. ones = ones_like(sx)
  533. shear_mat = stack(
  534. [ones, -sx_tan, sx_tan * y, -sy_tan, ones + sx_tan * sy_tan, sy_tan * (x - sx_tan * y)], dim=-1
  535. ).view(-1, 2, 3)
  536. shear_mat = convert_affinematrix_to_homography(shear_mat)
  537. return shear_mat
  538. def get_affine_matrix3d(
  539. translations: Tensor,
  540. center: Tensor,
  541. scale: Tensor,
  542. angles: Tensor,
  543. sxy: Optional[Tensor] = None,
  544. sxz: Optional[Tensor] = None,
  545. syx: Optional[Tensor] = None,
  546. syz: Optional[Tensor] = None,
  547. szx: Optional[Tensor] = None,
  548. szy: Optional[Tensor] = None,
  549. ) -> Tensor:
  550. r"""Compose 3d affine matrix from the components.
  551. Args:
  552. translations: tensor containing the translation vector (dx,dy,dz) with shape :math:`(B, 3)`.
  553. center: tensor containing the center vector (x,y,z) with shape :math:`(B, 3)`.
  554. scale: tensor containing the scale factor with shape :math:`(B)`.
  555. angles: axis angle vector containing the rotation angles in degrees in the form
  556. of (rx, ry, rz) with shape :math:`(B, 3)`. Internally it calls Rodrigues to compute
  557. the rotation matrix from axis-angle.
  558. sxy: tensor containing the shear factor in the xy-direction with shape :math:`(B)`.
  559. sxz: tensor containing the shear factor in the xz-direction with shape :math:`(B)`.
  560. syx: tensor containing the shear factor in the yx-direction with shape :math:`(B)`.
  561. syz: tensor containing the shear factor in the yz-direction with shape :math:`(B)`.
  562. szx: tensor containing the shear factor in the zx-direction with shape :math:`(B)`.
  563. szy: tensor containing the shear factor in the zy-direction with shape :math:`(B)`.
  564. Returns:
  565. the 3d affine transformation matrix :math:`(B, 3, 3)`.
  566. .. note::
  567. This function is often used in conjunction with :func:`warp_perspective`.
  568. """
  569. transform: Tensor = get_projective_transform(center, -angles, scale)
  570. transform[..., 3] += translations # tx/ty/tz
  571. # pad transform to get Bx3x3
  572. transform_h = convert_affinematrix_to_homography3d(transform)
  573. if any(s is not None for s in [sxy, sxz, syx, syz, szx, szy]):
  574. shear_mat = get_shear_matrix3d(center, sxy, sxz, syx, syz, szx, szy)
  575. transform_h = transform_h @ shear_mat
  576. return transform_h
  577. def get_shear_matrix3d(
  578. center: Tensor,
  579. sxy: Optional[Tensor] = None,
  580. sxz: Optional[Tensor] = None,
  581. syx: Optional[Tensor] = None,
  582. syz: Optional[Tensor] = None,
  583. szx: Optional[Tensor] = None,
  584. szy: Optional[Tensor] = None,
  585. ) -> Tensor:
  586. r"""Compose shear matrix Bx4x4 from the components.
  587. Note: Ordered shearing, shear x-axis then y-axis then z-axis.
  588. .. math::
  589. \begin{bmatrix}
  590. 1 & o & r & oy + rz \\
  591. m & p & s & mx + py + sz -y \\
  592. n & q & t & nx + qy + tz -z \\
  593. 0 & 0 & 0 & 1 \\
  594. \end{bmatrix}
  595. Where:
  596. m = S_{xy}
  597. n = S_{xz}
  598. o = S_{yx}
  599. p = S_{xy}S_{yx} + 1
  600. q = S_{xz}S_{yx} + S_{yz}
  601. r = S_{zx} + S_{yx}S_{zy}
  602. s = S_{xy}S_{zx} + (S_{xy}S_{yx} + 1)S_{zy}
  603. t = S_{xz}S_{zx} + (S_{xz}S_{yx} + S_{yz})S_{zy} + 1
  604. Params:
  605. center: shearing center coordinates of (x, y, z).
  606. sxy: shearing angle along x axis, towards y plane in radiants.
  607. sxz: shearing angle along x axis, towards z plane in radiants.
  608. syx: shearing angle along y axis, towards x plane in radiants.
  609. syz: shearing angle along y axis, towards z plane in radiants.
  610. szx: shearing angle along z axis, towards x plane in radiants.
  611. szy: shearing angle along z axis, towards y plane in radiants.
  612. Returns:
  613. params to be passed to the affine transformation.
  614. Examples:
  615. >>> rng = torch.manual_seed(0)
  616. >>> sxy, sxz, syx, syz = torch.randn(4, 1)
  617. >>> sxy, sxz, syx, syz
  618. (tensor([1.5410]), tensor([-0.2934]), tensor([-2.1788]), tensor([0.5684]))
  619. >>> center = torch.tensor([[0., 0., 0.]]) # Bx3
  620. >>> get_shear_matrix3d(center, sxy=sxy, sxz=sxz, syx=syx, syz=syz)
  621. tensor([[[ 1.0000, -1.4369, 0.0000, 0.0000],
  622. [-33.5468, 49.2039, 0.0000, 0.0000],
  623. [ 0.3022, -1.0729, 1.0000, 0.0000],
  624. [ 0.0000, 0.0000, 0.0000, 1.0000]]])
  625. .. note::
  626. This function is often used in conjunction with :func:`warp_perspective3d`.
  627. """
  628. sxy = tensor([0.0]).repeat(center.size(0)) if sxy is None else sxy
  629. sxz = tensor([0.0]).repeat(center.size(0)) if sxz is None else sxz
  630. syx = tensor([0.0]).repeat(center.size(0)) if syx is None else syx
  631. syz = tensor([0.0]).repeat(center.size(0)) if syz is None else syz
  632. szx = tensor([0.0]).repeat(center.size(0)) if szx is None else szx
  633. szy = tensor([0.0]).repeat(center.size(0)) if szy is None else szy
  634. x, y, z = torch.split(center, 1, dim=-1)
  635. x, y, z = x.view(-1), y.view(-1), z.view(-1)
  636. # Prepare parameters
  637. sxy_tan = tan(sxy)
  638. sxz_tan = tan(sxz)
  639. syx_tan = tan(syx)
  640. syz_tan = tan(syz)
  641. szx_tan = tan(szx)
  642. szy_tan = tan(szy)
  643. # compute translation matrix
  644. m00, m10, m20, m01, m11, m21, m02, m12, m22 = _compute_shear_matrix_3d(
  645. sxy_tan, sxz_tan, syx_tan, syz_tan, szx_tan, szy_tan
  646. )
  647. m03 = m01 * y + m02 * z
  648. m13 = m10 * x + m11 * y + m12 * z - y
  649. m23 = m20 * x + m21 * y + m22 * z - z
  650. # shear matrix is implemented with negative values
  651. sxy_tan, sxz_tan, syx_tan, syz_tan, szx_tan, szy_tan = -sxy_tan, -sxz_tan, -syx_tan, -syz_tan, -szx_tan, -szy_tan
  652. m00, m10, m20, m01, m11, m21, m02, m12, m22 = _compute_shear_matrix_3d(
  653. sxy_tan, sxz_tan, syx_tan, syz_tan, szx_tan, szy_tan
  654. )
  655. shear_mat = stack([m00, m01, m02, m03, m10, m11, m12, m13, m20, m21, m22, m23], -1).view(-1, 3, 4)
  656. shear_mat = convert_affinematrix_to_homography3d(shear_mat)
  657. return shear_mat
  658. def _compute_shear_matrix_3d(
  659. sxy_tan: Tensor, sxz_tan: Tensor, syx_tan: Tensor, syz_tan: Tensor, szx_tan: Tensor, szy_tan: Tensor
  660. ) -> tuple[Tensor, ...]:
  661. ones = ones_like(sxy_tan)
  662. m00, m10, m20 = ones, sxy_tan, sxz_tan
  663. m01, m11, m21 = syx_tan, sxy_tan * syx_tan + ones, sxz_tan * syx_tan + syz_tan
  664. m02 = syx_tan * szy_tan + szx_tan
  665. m12 = sxy_tan * szx_tan + szy_tan * m11
  666. m22 = sxz_tan * szx_tan + szy_tan * m21 + ones
  667. return m00, m10, m20, m01, m11, m21, m02, m12, m22
  668. def warp_affine3d(
  669. src: Tensor,
  670. M: Tensor,
  671. dsize: tuple[int, int, int],
  672. flags: str = "bilinear",
  673. padding_mode: str = "zeros",
  674. align_corners: bool = True,
  675. ) -> Tensor:
  676. r"""Apply a projective transformation a to 3d tensor.
  677. .. warning::
  678. This API signature it is experimental and might suffer some changes in the future.
  679. Args:
  680. src : input tensor of shape :math:`(B, C, D, H, W)`.
  681. M: projective transformation matrix of shape :math:`(B, 3, 4)`.
  682. dsize: size of the output image (depth, height, width).
  683. flags: interpolation mode to calculate output values
  684. ``'bilinear'`` | ``'nearest'``.
  685. padding_mode: padding mode for outside grid values
  686. ``'zeros'`` | ``'border'`` | ``'reflection'``.
  687. align_corners : mode for grid_generation.
  688. Returns:
  689. Tensor: the warped 3d tensor with shape :math:`(B, C, D, H, W)`.
  690. .. note::
  691. This function is often used in conjunction with :func:`get_perspective_transform3d`.
  692. """
  693. if len(src.shape) != 5:
  694. raise AssertionError(src.shape)
  695. if not (len(M.shape) == 3 and M.shape[-2:] == (3, 4)):
  696. raise AssertionError(M.shape)
  697. if len(dsize) != 3:
  698. raise AssertionError(dsize)
  699. B, C, D, H, W = src.size()
  700. size_src: tuple[int, int, int] = (D, H, W)
  701. size_out: tuple[int, int, int] = dsize
  702. M_4x4 = convert_affinematrix_to_homography3d(M) # Bx4x4
  703. # we need to normalize the transformation since grid sample needs -1/1 coordinates
  704. dst_norm_trans_src_norm: Tensor = normalize_homography3d(M_4x4, size_src, size_out) # Bx4x4
  705. src_norm_trans_dst_norm = _torch_inverse_cast(dst_norm_trans_src_norm)
  706. P_norm: Tensor = src_norm_trans_dst_norm[:, :3] # Bx3x4
  707. # compute meshgrid and apply to input
  708. dsize_out: list[int] = [B, C, *list(size_out)]
  709. grid = F.affine_grid(P_norm, dsize_out, align_corners=align_corners)
  710. return F.grid_sample(src, grid, align_corners=align_corners, mode=flags, padding_mode=padding_mode)
  711. def projection_from_Rt(rmat: Tensor, tvec: Tensor) -> Tensor:
  712. r"""Compute the projection matrix from Rotation and translation.
  713. .. warning::
  714. This API signature it is experimental and might suffer some changes in the future.
  715. Concatenates the batch of rotations and translations such that :math:`P = [R | t]`.
  716. Args:
  717. rmat: the rotation matrix with shape :math:`(*, 3, 3)`.
  718. tvec: the translation vector with shape :math:`(*, 3, 1)`.
  719. Returns:
  720. the projection matrix with shape :math:`(*, 3, 4)`.
  721. """
  722. if not (len(rmat.shape) >= 2 and rmat.shape[-2:] == (3, 3)):
  723. raise AssertionError(rmat.shape)
  724. if not (len(tvec.shape) >= 2 and tvec.shape[-2:] == (3, 1)):
  725. raise AssertionError(tvec.shape)
  726. return concatenate([rmat, tvec], -1) # Bx3x4
  727. def get_projective_transform(center: Tensor, angles: Tensor, scales: Tensor) -> Tensor:
  728. r"""Calculate the projection matrix for a 3D rotation.
  729. .. warning::
  730. This API signature it is experimental and might suffer some changes in the future.
  731. The function computes the projection matrix given the center and angles per axis.
  732. Args:
  733. center: center of the rotation (x,y,z) in the source with shape :math:`(B, 3)`.
  734. angles: axis angle vector containing the rotation angles in degrees in the form
  735. of (rx, ry, rz) with shape :math:`(B, 3)`. Internally it calls Rodrigues to compute
  736. the rotation matrix from axis-angle.
  737. scales: scale factor for x-y-z-directions with shape :math:`(B, 3)`.
  738. Returns:
  739. the projection matrix of 3D rotation with shape :math:`(B, 3, 4)`.
  740. .. note::
  741. This function is often used in conjunction with :func:`warp_affine3d`.
  742. """
  743. if not (len(center.shape) == 2 and center.shape[-1] == 3):
  744. raise AssertionError(center.shape)
  745. if not (len(angles.shape) == 2 and angles.shape[-1] == 3):
  746. raise AssertionError(angles.shape)
  747. if center.device != angles.device:
  748. raise AssertionError(center.device, angles.device)
  749. if center.dtype != angles.dtype:
  750. raise AssertionError(center.dtype, angles.dtype)
  751. # create rotation matrix
  752. axis_angle_rad: Tensor = deg2rad(angles)
  753. rmat: Tensor = axis_angle_to_rotation_matrix(axis_angle_rad) # Bx3x3
  754. scaling_matrix: Tensor = eye_like(3, rmat)
  755. scaling_matrix = scaling_matrix * scales.unsqueeze(dim=1)
  756. rmat = rmat @ scaling_matrix.to(rmat)
  757. # define matrix to move forth and back to origin
  758. from_origin_mat = eye_like(4, rmat, shared_memory=False) # Bx4x4
  759. from_origin_mat[..., :3, -1] += center
  760. to_origin_mat = from_origin_mat.clone()
  761. to_origin_mat = _torch_inverse_cast(from_origin_mat)
  762. # append translation with zeros
  763. proj_mat = projection_from_Rt(rmat, torch.zeros_like(center)[..., None]) # Bx3x4
  764. # chain 4x4 transforms
  765. proj_mat = convert_affinematrix_to_homography3d(proj_mat) # Bx4x4
  766. proj_mat = from_origin_mat @ proj_mat @ to_origin_mat
  767. return proj_mat[..., :3, :] # Bx3x4
  768. def get_perspective_transform3d(src: Tensor, dst: Tensor) -> Tensor:
  769. r"""Calculate a 3d perspective transform from four pairs of the corresponding points.
  770. The function calculates the matrix of a perspective transform so that:
  771. .. math::
  772. \begin{bmatrix}
  773. t_{i}x_{i}^{'} \\
  774. t_{i}y_{i}^{'} \\
  775. t_{i}z_{i}^{'} \\
  776. t_{i} \\
  777. \end{bmatrix}
  778. =
  779. \textbf{map_matrix} \cdot
  780. \begin{bmatrix}
  781. x_{i} \\
  782. y_{i} \\
  783. z_{i} \\
  784. 1 \\
  785. \end{bmatrix}
  786. where
  787. .. math::
  788. dst(i) = (x_{i}^{'},y_{i}^{'},z_{i}^{'}), src(i) = (x_{i}, y_{i}, z_{i}), i = 0,1,2,5,7
  789. Concrete math is as below:
  790. .. math::
  791. \[ u_i =\frac{c_{00} * x_i + c_{01} * y_i + c_{02} * z_i + c_{03}}
  792. {c_{30} * x_i + c_{31} * y_i + c_{32} * z_i + c_{33}} \]
  793. \[ v_i =\frac{c_{10} * x_i + c_{11} * y_i + c_{12} * z_i + c_{13}}
  794. {c_{30} * x_i + c_{31} * y_i + c_{32} * z_i + c_{33}} \]
  795. \[ w_i =\frac{c_{20} * x_i + c_{21} * y_i + c_{22} * z_i + c_{23}}
  796. {c_{30} * x_i + c_{31} * y_i + c_{32} * z_i + c_{33}} \]
  797. .. math::
  798. \begin{pmatrix}
  799. x_0 & y_0 & z_0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_0*u_0 & -y_0*u_0 & -z_0 * u_0 \\
  800. x_1 & y_1 & z_1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_1*u_1 & -y_1*u_1 & -z_1 * u_1 \\
  801. x_2 & y_2 & z_2 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_2*u_2 & -y_2*u_2 & -z_2 * u_2 \\
  802. x_5 & y_5 & z_5 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_5*u_5 & -y_5*u_5 & -z_5 * u_5 \\
  803. x_7 & y_7 & z_7 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_7*u_7 & -y_7*u_7 & -z_7 * u_7 \\
  804. 0 & 0 & 0 & 0 & x_0 & y_0 & z_0 & 1 & 0 & 0 & 0 & 0 & -x_0*v_0 & -y_0*v_0 & -z_0 * v_0 \\
  805. 0 & 0 & 0 & 0 & x_1 & y_1 & z_1 & 1 & 0 & 0 & 0 & 0 & -x_1*v_1 & -y_1*v_1 & -z_1 * v_1 \\
  806. 0 & 0 & 0 & 0 & x_2 & y_2 & z_2 & 1 & 0 & 0 & 0 & 0 & -x_2*v_2 & -y_2*v_2 & -z_2 * v_2 \\
  807. 0 & 0 & 0 & 0 & x_5 & y_5 & z_5 & 1 & 0 & 0 & 0 & 0 & -x_5*v_5 & -y_5*v_5 & -z_5 * v_5 \\
  808. 0 & 0 & 0 & 0 & x_7 & y_7 & z_7 & 1 & 0 & 0 & 0 & 0 & -x_7*v_7 & -y_7*v_7 & -z_7 * v_7 \\
  809. 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_0 & y_0 & z_0 & 1 & -x_0*w_0 & -y_0*w_0 & -z_0 * w_0 \\
  810. 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_1 & y_1 & z_1 & 1 & -x_1*w_1 & -y_1*w_1 & -z_1 * w_1 \\
  811. 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_2 & y_2 & z_2 & 1 & -x_2*w_2 & -y_2*w_2 & -z_2 * w_2 \\
  812. 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_5 & y_5 & z_5 & 1 & -x_5*w_5 & -y_5*w_5 & -z_5 * w_5 \\
  813. 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_7 & y_7 & z_7 & 1 & -x_7*w_7 & -y_7*w_7 & -z_7 * w_7 \\
  814. \end{pmatrix}
  815. Args:
  816. src: coordinates of quadrangle vertices in the source image with shape :math:`(B, 8, 3)`.
  817. dst: coordinates of the corresponding quadrangle vertices in
  818. the destination image with shape :math:`(B, 8, 3)`.
  819. Returns:
  820. the perspective transformation with shape :math:`(B, 4, 4)`.
  821. .. note::
  822. This function is often used in conjunction with :func:`warp_perspective3d`.
  823. """
  824. if not isinstance(src, (Tensor)):
  825. raise TypeError(f"Input type is not a Tensor. Got {type(src)}")
  826. if not isinstance(dst, (Tensor)):
  827. raise TypeError(f"Input type is not a Tensor. Got {type(dst)}")
  828. if not src.shape[-2:] == (8, 3):
  829. raise ValueError(f"Inputs must be a Bx8x3 tensor. Got {src.shape}")
  830. if not src.shape == dst.shape:
  831. raise ValueError(f"Inputs must have the same shape. Got {dst.shape}")
  832. if not (src.shape[0] == dst.shape[0]):
  833. raise ValueError(f"Inputs must have same batch size dimension. Expect {src.shape} but got {dst.shape}")
  834. if not (src.device == dst.device and src.dtype == dst.dtype):
  835. raise AssertionError(
  836. f"Expect `src` and `dst` to be in the same device (Got {src.dtype}, {dst.dtype}) "
  837. f"with the same dtype (Got {src.dtype}, {dst.dtype})."
  838. )
  839. # we build matrix A by using only 4 point correspondence. The linear
  840. # system is solved with the least square method, so here
  841. # we could even pass more correspondence
  842. p = []
  843. # 000, 100, 110, 101, 011
  844. for i in [0, 1, 2, 5, 7]:
  845. p.append(_build_perspective_param3d(src[:, i], dst[:, i], "x"))
  846. p.append(_build_perspective_param3d(src[:, i], dst[:, i], "y"))
  847. p.append(_build_perspective_param3d(src[:, i], dst[:, i], "z"))
  848. # A is Bx15x15
  849. A = stack(p, 1)
  850. # b is a Bx15x1
  851. b = stack(
  852. [
  853. dst[:, 0:1, 0],
  854. dst[:, 0:1, 1],
  855. dst[:, 0:1, 2],
  856. dst[:, 1:2, 0],
  857. dst[:, 1:2, 1],
  858. dst[:, 1:2, 2],
  859. dst[:, 2:3, 0],
  860. dst[:, 2:3, 1],
  861. dst[:, 2:3, 2],
  862. # dst[:, 3:4, 0], dst[:, 3:4, 1], dst[:, 3:4, 2],
  863. # dst[:, 4:5, 0], dst[:, 4:5, 1], dst[:, 4:5, 2],
  864. dst[:, 5:6, 0],
  865. dst[:, 5:6, 1],
  866. dst[:, 5:6, 2],
  867. # dst[:, 6:7, 0], dst[:, 6:7, 1], dst[:, 6:7, 2],
  868. dst[:, 7:8, 0],
  869. dst[:, 7:8, 1],
  870. dst[:, 7:8, 2],
  871. ],
  872. 1,
  873. )
  874. # solve the system Ax = b
  875. X: Tensor = _torch_solve_cast(A, b)
  876. # create variable to return
  877. batch_size: int = src.shape[0]
  878. M = torch.empty(batch_size, 16, device=src.device, dtype=src.dtype)
  879. M[..., :15] = X[..., 0]
  880. M[..., -1].fill_(1)
  881. return M.view(-1, 4, 4) # Bx4x4
  882. def _build_perspective_param3d(p: Tensor, q: Tensor, axis: str) -> Tensor:
  883. ones = torch.ones_like(p)[..., 0:1]
  884. zeros = torch.zeros_like(p)[..., 0:1]
  885. if axis == "x":
  886. return concatenate(
  887. [
  888. p[:, 0:1],
  889. p[:, 1:2],
  890. p[:, 2:3],
  891. ones,
  892. zeros,
  893. zeros,
  894. zeros,
  895. zeros,
  896. zeros,
  897. zeros,
  898. zeros,
  899. zeros,
  900. -p[:, 0:1] * q[:, 0:1],
  901. -p[:, 1:2] * q[:, 0:1],
  902. -p[:, 2:3] * q[:, 0:1],
  903. ],
  904. 1,
  905. )
  906. if axis == "y":
  907. return concatenate(
  908. [
  909. zeros,
  910. zeros,
  911. zeros,
  912. zeros,
  913. p[:, 0:1],
  914. p[:, 1:2],
  915. p[:, 2:3],
  916. ones,
  917. zeros,
  918. zeros,
  919. zeros,
  920. zeros,
  921. -p[:, 0:1] * q[:, 1:2],
  922. -p[:, 1:2] * q[:, 1:2],
  923. -p[:, 2:3] * q[:, 1:2],
  924. ],
  925. 1,
  926. )
  927. if axis == "z":
  928. return concatenate(
  929. [
  930. zeros,
  931. zeros,
  932. zeros,
  933. zeros,
  934. zeros,
  935. zeros,
  936. zeros,
  937. zeros,
  938. p[:, 0:1],
  939. p[:, 1:2],
  940. p[:, 2:3],
  941. ones,
  942. -p[:, 0:1] * q[:, 2:3],
  943. -p[:, 1:2] * q[:, 2:3],
  944. -p[:, 2:3] * q[:, 2:3],
  945. ],
  946. 1,
  947. )
  948. raise NotImplementedError(f"perspective params for axis `{axis}` is not implemented.")
  949. def warp_perspective3d(
  950. src: Tensor,
  951. M: Tensor,
  952. dsize: tuple[int, int, int],
  953. flags: str = "bilinear",
  954. border_mode: str = "zeros",
  955. align_corners: bool = False,
  956. ) -> Tensor:
  957. r"""Apply a perspective transformation to an image.
  958. The function warp_perspective transforms the source image using
  959. the specified matrix:
  960. .. math::
  961. \text{dst} (x, y) = \text{src} \left(
  962. \frac{M_{11} x + M_{12} y + M_{13}}{M_{31} x + M_{32} y + M_{33}} ,
  963. \frac{M_{21} x + M_{22} y + M_{23}}{M_{31} x + M_{32} y + M_{33}}
  964. \right )
  965. Args:
  966. src: input image with shape :math:`(B, C, D, H, W)`.
  967. M: transformation matrix with shape :math:`(B, 4, 4)`.
  968. dsize: size of the output image (height, width).
  969. flags: interpolation mode to calculate output values
  970. ``'bilinear'`` | ``'nearest'``.
  971. border_mode: padding mode for outside grid values
  972. ``'zeros'`` | ``'border'`` | ``'reflection'``.
  973. align_corners: interpolation flag.
  974. Returns:
  975. the warped input image :math:`(B, C, D, H, W)`.
  976. .. note::
  977. This function is often used in conjunction with :func:`get_perspective_transform3d`.
  978. """
  979. if not isinstance(src, Tensor):
  980. raise TypeError(f"Input src type is not a Tensor. Got {type(src)}")
  981. if not isinstance(M, Tensor):
  982. raise TypeError(f"Input M type is not a Tensor. Got {type(M)}")
  983. if not len(src.shape) == 5:
  984. raise ValueError(f"Input src must be a BxCxDxHxW tensor. Got {src.shape}")
  985. if not (len(M.shape) == 3 or M.shape[-2:] == (4, 4)):
  986. raise ValueError(f"Input M must be a Bx4x4 tensor. Got {M.shape}")
  987. # launches the warper
  988. d, h, w = src.shape[-3:]
  989. return _transform_warp_impl3d(src, M, (d, h, w), dsize, flags, border_mode, align_corners)
  990. def homography_warp(
  991. patch_src: Tensor,
  992. src_homo_dst: Tensor,
  993. dsize: tuple[int, int],
  994. mode: str = "bilinear",
  995. padding_mode: str = "zeros",
  996. align_corners: bool = False,
  997. normalized_coordinates: bool = True,
  998. normalized_homography: bool = True,
  999. ) -> Tensor:
  1000. r"""Warp image patches or tensors by normalized 2D homographies.
  1001. See :class:`~kornia.geometry.warp.HomographyWarper` for details.
  1002. Args:
  1003. patch_src: The image or tensor to warp. Should be from source of shape :math:`(N, C, H, W)`.
  1004. src_homo_dst: The homography or stack of homographies from destination to source of shape :math:`(N, 3, 3)`.
  1005. dsize:
  1006. if homography normalized: The height and width of the image to warp.
  1007. if homography not normalized: size of the output image (height, width).
  1008. mode: interpolation mode to calculate output values ``'bilinear'`` | ``'nearest'``.
  1009. padding_mode: padding mode for outside grid values ``'zeros'`` | ``'border'`` | ``'reflection'``.
  1010. align_corners: interpolation flag.
  1011. normalized_coordinates: Whether the homography assumes [-1, 1] normalized coordinates or not.
  1012. normalized_homography: show is homography normalized.
  1013. Return:
  1014. Patch sampled at locations from source to destination.
  1015. Example:
  1016. >>> input = torch.rand(1, 3, 32, 32)
  1017. >>> homography = torch.eye(3).view(1, 3, 3)
  1018. >>> output = homography_warp(input, homography, (32, 32))
  1019. Example:
  1020. >>> img = torch.rand(1, 4, 5, 6)
  1021. >>> H = torch.eye(3)[None]
  1022. >>> out = homography_warp(img, H, (4, 2), align_corners=True, normalized_homography=False)
  1023. >>> print(out.shape)
  1024. torch.Size([1, 4, 4, 2])
  1025. """
  1026. if not src_homo_dst.device == patch_src.device:
  1027. raise TypeError(
  1028. f"Patch and homography must be on the same device. Got patch.device: {patch_src.device} "
  1029. f"src_H_dst.device: {src_homo_dst.device}."
  1030. )
  1031. if normalized_homography:
  1032. height, width = dsize
  1033. grid = create_meshgrid(
  1034. height, width, normalized_coordinates=normalized_coordinates, device=patch_src.device, dtype=patch_src.dtype
  1035. )
  1036. warped_grid = warp_grid(grid, src_homo_dst)
  1037. return F.grid_sample(patch_src, warped_grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
  1038. return warp_perspective(
  1039. patch_src, src_homo_dst, dsize, mode="bilinear", padding_mode=padding_mode, align_corners=True
  1040. )
  1041. def _transform_warp_impl3d(
  1042. src: Tensor,
  1043. dst_pix_trans_src_pix: Tensor,
  1044. dsize_src: tuple[int, int, int],
  1045. dsize_dst: tuple[int, int, int],
  1046. grid_mode: str,
  1047. padding_mode: str,
  1048. align_corners: bool,
  1049. ) -> Tensor:
  1050. """Compute the transform in normalized coordinates and perform the warping."""
  1051. dst_norm_trans_src_norm: Tensor = normalize_homography3d(dst_pix_trans_src_pix, dsize_src, dsize_dst)
  1052. src_norm_trans_dst_norm = torch.linalg.inv(dst_norm_trans_src_norm)
  1053. return homography_warp3d(src, src_norm_trans_dst_norm, dsize_dst, grid_mode, padding_mode, align_corners, True)
  1054. def homography_warp3d(
  1055. patch_src: Tensor,
  1056. src_homo_dst: Tensor,
  1057. dsize: tuple[int, int, int],
  1058. mode: str = "bilinear",
  1059. padding_mode: str = "zeros",
  1060. align_corners: bool = False,
  1061. normalized_coordinates: bool = True,
  1062. ) -> Tensor:
  1063. r"""Warp image patches or tensors by normalized 3D homographies.
  1064. Args:
  1065. patch_src: The image or tensor to warp. Should be from source of shape :math:`(N, C, D, H, W)`.
  1066. src_homo_dst: The homography or stack of homographies from destination to source of shape
  1067. :math:`(N, 4, 4)`.
  1068. dsize: The height and width of the image to warp.
  1069. mode: interpolation mode to calculate output values ``'bilinear'`` | ``'nearest'``.
  1070. padding_mode: padding mode for outside grid values ``'zeros'`` | ``'border'`` | ``'reflection'``.
  1071. align_corners: interpolation flag.
  1072. normalized_coordinates: Whether the homography assumes [-1, 1] normalized coordinates or not.
  1073. Return:
  1074. Patch sampled at locations from source to destination.
  1075. Example:
  1076. >>> input = torch.rand(1, 3, 32, 32)
  1077. >>> homography = torch.eye(3).view(1, 3, 3)
  1078. >>> output = homography_warp(input, homography, (32, 32))
  1079. """
  1080. if not src_homo_dst.device == patch_src.device:
  1081. raise TypeError(
  1082. f"Patch and homography must be on the same device. Got patch.device: {patch_src.device} "
  1083. f"src_H_dst.device: {src_homo_dst.device}."
  1084. )
  1085. depth, height, width = dsize
  1086. grid = create_meshgrid3d(
  1087. depth, height, width, normalized_coordinates=normalized_coordinates, device=patch_src.device
  1088. )
  1089. warped_grid = warp_grid3d(grid, src_homo_dst)
  1090. return F.grid_sample(patch_src, warped_grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners)