laf.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import math
  18. from typing import List, Optional, Tuple, Union
  19. import torch
  20. import torch.nn.functional as F
  21. from kornia.core import Tensor, concatenate, cos, sin, stack, tensor, zeros
  22. from kornia.core.check import KORNIA_CHECK_LAF, KORNIA_CHECK_SHAPE
  23. from kornia.geometry.conversions import angle_to_rotation_matrix, convert_points_from_homogeneous, rad2deg
  24. from kornia.geometry.linalg import transform_points
  25. from kornia.geometry.transform import pyrdown
  26. def get_laf_scale(LAF: Tensor) -> Tensor:
  27. """Return a scale of the LAFs.
  28. Args:
  29. LAF: :math:`(B, N, 2, 3)`
  30. Returns:
  31. scale :math:`(B, N, 1, 1)`
  32. Example:
  33. >>> input = torch.ones(1, 5, 2, 3) # BxNx2x3
  34. >>> output = get_laf_scale(input) # BxNx1x1
  35. """
  36. KORNIA_CHECK_LAF(LAF)
  37. eps = 1e-10
  38. out = LAF[..., 0:1, 0:1] * LAF[..., 1:2, 1:2] - LAF[..., 1:2, 0:1] * LAF[..., 0:1, 1:2] + eps
  39. return out.abs().sqrt()
  40. def get_laf_center(LAF: Tensor) -> Tensor:
  41. """Return a center (keypoint) of the LAFs.
  42. The convention is that center of 5-pixel image (coordinates from 0 to 4) is 2, and not 2.5.
  43. Args:
  44. LAF: :math:`(B, N, 2, 3)`
  45. Returns:
  46. xy :math:`(B, N, 2)`
  47. Example:
  48. >>> input = torch.ones(1, 5, 2, 3) # BxNx2x3
  49. >>> output = get_laf_center(input) # BxNx2
  50. """
  51. KORNIA_CHECK_LAF(LAF)
  52. out = LAF[..., 2]
  53. return out
  54. def get_laf_orientation(LAF: Tensor) -> Tensor:
  55. """Return orientation of the LAFs, in degrees.
  56. Args:
  57. LAF: :math:`(B, N, 2, 3)`
  58. Returns:
  59. angle in degrees :math:`(B, N, 1)`
  60. Example:
  61. >>> input = torch.ones(1, 5, 2, 3) # BxNx2x3
  62. >>> output = get_laf_orientation(input) # BxNx1
  63. """
  64. KORNIA_CHECK_LAF(LAF)
  65. angle_rad = torch.atan2(LAF[..., 0, 1], LAF[..., 0, 0])
  66. return rad2deg(angle_rad).unsqueeze(-1)
  67. def rotate_laf(LAF: Tensor, angles_degrees: Tensor) -> Tensor:
  68. """Apply additional rotation to the LAFs.
  69. Compared to `set_laf_orientation`, the resulting rotation is original LAF orientation plus angles_degrees.
  70. Args:
  71. LAF: :math:`(B, N, 2, 3)`
  72. angles_degrees: :math:`(B, N, 1)` in degrees.
  73. Returns:
  74. LAF oriented with angles :math:`(B, N, 2, 3)`
  75. """
  76. KORNIA_CHECK_LAF(LAF)
  77. B, N = LAF.shape[:2]
  78. rotmat = angle_to_rotation_matrix(angles_degrees).view(B * N, 2, 2)
  79. out_laf = LAF.clone()
  80. out_laf[:, :, :2, :2] = torch.bmm(LAF[:, :, :2, :2].reshape(B * N, 2, 2), rotmat).reshape(B, N, 2, 2)
  81. return out_laf
  82. def set_laf_orientation(LAF: Tensor, angles_degrees: Tensor) -> Tensor:
  83. """Change the orientation of the LAFs.
  84. Args:
  85. LAF: :math:`(B, N, 2, 3)`
  86. angles_degrees: :math:`(B, N, 1)` in degrees.
  87. Returns:
  88. LAF oriented with angles :math:`(B, N, 2, 3)`
  89. """
  90. KORNIA_CHECK_LAF(LAF)
  91. _B, _N = LAF.shape[:2]
  92. ori = get_laf_orientation(LAF).reshape_as(angles_degrees)
  93. return rotate_laf(LAF, angles_degrees - ori)
  94. def laf_from_center_scale_ori(xy: Tensor, scale: Optional[Tensor] = None, ori: Optional[Tensor] = None) -> Tensor:
  95. """Create a LAF from keypoint center, scale and orientation.
  96. Useful to create kornia LAFs from OpenCV keypoints.
  97. Args:
  98. xy: :math:`(B, N, 2)`.
  99. scale: :math:`(B, N, 1, 1)`. If not provided, scale = 1.0 is assumed
  100. ori: angle in degrees :math:`(B, N, 1)`. If not provided orientation = 0 is assumed
  101. Returns:
  102. LAF :math:`(B, N, 2, 3)`
  103. """
  104. KORNIA_CHECK_SHAPE(xy, ["B", "N", "2"])
  105. device = xy.device
  106. dtype = xy.dtype
  107. B, N = xy.shape[:2]
  108. if scale is None:
  109. scale = torch.ones(B, N, 1, 1, device=device, dtype=dtype)
  110. if ori is None:
  111. ori = zeros(B, N, 1, device=device, dtype=dtype)
  112. KORNIA_CHECK_SHAPE(scale, ["B", "N", "1", "1"])
  113. KORNIA_CHECK_SHAPE(ori, ["B", "N", "1"])
  114. unscaled_laf = concatenate([angle_to_rotation_matrix(ori.squeeze(-1)), xy.unsqueeze(-1)], dim=-1)
  115. laf = scale_laf(unscaled_laf, scale)
  116. return laf
  117. def scale_laf(laf: Tensor, scale_coef: Union[float, Tensor]) -> Tensor:
  118. """Multiplies region part of LAF ([:, :, :2, :2]) by a scale_coefficient.
  119. So the center, shape and orientation of the local feature stays the same, but the region area changes.
  120. Args:
  121. laf: :math:`(B, N, 2, 3)`
  122. scale_coef: broadcastable tensor or float.
  123. Returns:
  124. LAF :math:`(B, N, 2, 3)`
  125. Example:
  126. >>> input = torch.ones(1, 5, 2, 3) # BxNx2x3
  127. >>> scale = 0.5
  128. >>> output = scale_laf(input, scale) # BxNx2x3
  129. """
  130. if not isinstance(scale_coef, (float, Tensor)):
  131. raise TypeError(f"scale_coef should be float or Tensor. Got {type(scale_coef)}")
  132. KORNIA_CHECK_LAF(laf)
  133. centerless_laf = laf[:, :, :2, :2]
  134. return concatenate([scale_coef * centerless_laf, laf[:, :, :, 2:]], dim=3)
  135. def make_upright(laf: Tensor, eps: float = 1e-9) -> Tensor:
  136. """Rectify the affine matrix, so that it becomes upright.
  137. Args:
  138. laf: :math:`(B, N, 2, 3)`
  139. eps: for safe division.
  140. Returns:
  141. laf: :math:`(B, N, 2, 3)`
  142. Example:
  143. >>> input = torch.ones(1, 5, 2, 3) # BxNx2x3
  144. >>> output = make_upright(input) # BxNx2x3
  145. """
  146. KORNIA_CHECK_LAF(laf)
  147. det = get_laf_scale(laf)
  148. scale = det
  149. # The function is equivalent to doing 2x2 SVD and resetting rotation
  150. # matrix to an identity: U, S, V = svd(LAF); LAF_upright = U * S.
  151. b2a2 = torch.sqrt(laf[..., 0:1, 1:2] ** 2 + laf[..., 0:1, 0:1] ** 2) + eps
  152. laf1_ell = concatenate([(b2a2 / det).contiguous(), torch.zeros_like(det)], dim=3)
  153. laf2_ell = concatenate(
  154. [
  155. ((laf[..., 1:2, 1:2] * laf[..., 0:1, 1:2] + laf[..., 1:2, 0:1] * laf[..., 0:1, 0:1]) / (b2a2 * det)),
  156. (det / b2a2).contiguous(),
  157. ],
  158. dim=3,
  159. )
  160. laf_unit_scale = concatenate([concatenate([laf1_ell, laf2_ell], dim=2), laf[..., :, 2:3]], dim=3)
  161. return scale_laf(laf_unit_scale, scale)
  162. def ellipse_to_laf(ells: Tensor) -> Tensor:
  163. """Convert ellipse regions to LAF format.
  164. Ellipse (a, b, c) and upright covariance matrix [a11 a12; 0 a22] are connected
  165. by inverse matrix square root: A = invsqrt([a b; b c]).
  166. See also https://github.com/vlfeat/vlfeat/blob/master/toolbox/sift/vl_frame2oell.m
  167. Args:
  168. ells: tensor :math:`(B, N, 5)` of ellipses in Oxford format [x y a b c].
  169. Returns:
  170. LAF :math:`(B, N, 2, 3)`
  171. Example:
  172. >>> input = torch.ones(1, 10, 5) # BxNx5
  173. >>> output = ellipse_to_laf(input) # BxNx2x3
  174. """
  175. KORNIA_CHECK_SHAPE(ells, ["B", "N", "5"])
  176. B, N, _ = ells.shape
  177. # Previous implementation was incorrectly using Cholesky decomp as matrix sqrt
  178. # ell_shape = concatenate([concatenate([ells[..., 2:3], ells[..., 3:4]], dim=2).unsqueeze(2),
  179. # concatenate([ells[..., 3:4], ells[..., 4:5]], dim=2).unsqueeze(2)], dim=2).view(-1, 2, 2)
  180. # out = torch.matrix_power(torch.cholesky(ell_shape, False), -1).view(B, N, 2, 2)
  181. # We will calculate 2x2 matrix square root via special case formula
  182. # https://en.wikipedia.org/wiki/Square_root_of_a_matrix
  183. # "The Cholesky factorization provides another particular example of square root
  184. # which should not be confused with the unique non-negative square root."
  185. # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
  186. # M = (A 0; C D)
  187. # R = (sqrt(A) 0; C / (sqrt(A)+sqrt(D)) sqrt(D))
  188. a11 = ells[..., 2:3].abs().sqrt()
  189. a12 = torch.zeros_like(a11)
  190. a22 = ells[..., 4:5].abs().sqrt()
  191. a21 = ells[..., 3:4] / (a11 + a22).clamp(1e-9)
  192. A = stack([a11, a12, a21, a22], dim=-1).view(B, N, 2, 2).inverse()
  193. out = concatenate([A, ells[..., :2].view(B, N, 2, 1)], dim=3)
  194. return out
  195. def laf_to_boundary_points(LAF: Tensor, n_pts: int = 50) -> Tensor:
  196. """Convert LAFs to boundary points of the regions + center.
  197. Used for local features visualization, see visualize_laf function.
  198. Args:
  199. LAF: :math:`(B, N, 2, 3)`
  200. n_pts: number of points to output.
  201. Returns:
  202. tensor of boundary points LAF: :math:`(B, N, n_pts, 2)`
  203. """
  204. KORNIA_CHECK_LAF(LAF)
  205. B, N, _, _ = LAF.size()
  206. pts = concatenate(
  207. [
  208. sin(torch.linspace(0, 2 * math.pi, n_pts - 1)).unsqueeze(-1),
  209. cos(torch.linspace(0, 2 * math.pi, n_pts - 1)).unsqueeze(-1),
  210. torch.ones(n_pts - 1, 1),
  211. ],
  212. dim=1,
  213. )
  214. # Add origin to draw also the orientation
  215. pts = concatenate([tensor([0.0, 0.0, 1.0]).view(1, 3), pts], dim=0).unsqueeze(0).expand(B * N, n_pts, 3)
  216. pts = pts.to(LAF.device).to(LAF.dtype)
  217. aux = tensor([0.0, 0.0, 1.0]).view(1, 1, 3).expand(B * N, 1, 3)
  218. HLAF = concatenate([LAF.view(-1, 2, 3), aux.to(LAF.device).to(LAF.dtype)], dim=1)
  219. pts_h = torch.bmm(HLAF, pts.permute(0, 2, 1)).permute(0, 2, 1)
  220. return convert_points_from_homogeneous(pts_h.view(B, N, n_pts, 3))
  221. def get_laf_pts_to_draw(LAF: Tensor, img_idx: int = 0) -> Tuple[List[int], List[int]]:
  222. """Return list for drawing LAFs (local features).
  223. Args:
  224. LAF: :math:`(B, N, 2, 3)`
  225. img_idx: which points to output.
  226. Returns:
  227. List of boundary points x, y`
  228. Examples:
  229. x, y = get_laf_pts_to_draw(LAF, img_idx)
  230. plt.figure()
  231. plt.imshow(kornia.utils.tensor_to_image(img[img_idx]))
  232. plt.plot(x, y, 'r')
  233. plt.show()
  234. """
  235. # TODO: Refactor doctest
  236. KORNIA_CHECK_LAF(LAF)
  237. pts = laf_to_boundary_points(LAF[img_idx : img_idx + 1])[0]
  238. pts_np = pts.detach().permute(1, 0, 2).cpu()
  239. return (pts_np[..., 0].tolist(), pts_np[..., 1].tolist())
  240. def denormalize_laf(LAF: Tensor, images: Tensor) -> Tensor:
  241. """De-normalize LAFs from scale to image scale.
  242. The convention is that center of 5-pixel image (coordinates from 0 to 4) is 2, and not 2.5.
  243. B,N,H,W = images.size()
  244. MIN_SIZE = min(H - 1, W -1)
  245. [a11 a21 x]
  246. [a21 a22 y]
  247. becomes
  248. [a11*MIN_SIZE a21*MIN_SIZE x*(W-1)]
  249. [a21*MIN_SIZE a22*MIN_SIZE y*(W-1)]
  250. Args:
  251. LAF: :math:`(B, N, 2, 3)`
  252. images: :math:`(B, CH, H, W)`
  253. Returns:
  254. the denormalized LAF: :math:`(B, N, 2, 3)`, scale in pixels
  255. """
  256. KORNIA_CHECK_LAF(LAF)
  257. _, _, h, w = images.size()
  258. wf = float(w - 1)
  259. hf = float(h - 1)
  260. min_size = min(hf, wf)
  261. coef = torch.ones(1, 1, 2, 3, dtype=LAF.dtype, device=LAF.device) * min_size
  262. coef[0, 0, 0, 2] = wf
  263. coef[0, 0, 1, 2] = hf
  264. return coef.expand_as(LAF) * LAF
  265. def normalize_laf(LAF: Tensor, images: Tensor) -> Tensor:
  266. """Normalize LAFs to [0,1] scale from pixel scale.
  267. See below:
  268. B,N,H,W = images.size()
  269. MIN_SIZE = min(H - 1, W -1)
  270. [a11 a21 x]
  271. [a21 a22 y]
  272. becomes:
  273. [a11/MIN_SIZE a21/MIN_SIZE x/(W-1)]
  274. [a21/MIN_SIZE a22/MIN_SIZE y/(H-1)]
  275. Args:
  276. LAF: :math:`(B, N, 2, 3)`
  277. images: :math:`(B, CH, H, W)`
  278. Returns:
  279. the denormalized LAF: :math:`(B, N, 2, 3)`, scale in image percentage (0, 1)
  280. """
  281. KORNIA_CHECK_LAF(LAF)
  282. _, _, h, w = images.size()
  283. wf = float(w - 1)
  284. hf = float(h - 1)
  285. min_size = min(hf, wf)
  286. coef = torch.ones(1, 1, 2, 3, dtype=LAF.dtype, device=LAF.device) / min_size
  287. coef[0, 0, 0, 2] = 1.0 / wf
  288. coef[0, 0, 1, 2] = 1.0 / hf
  289. return coef.expand_as(LAF) * LAF
  290. def generate_patch_grid_from_normalized_LAF(img: Tensor, LAF: Tensor, PS: int = 32) -> Tensor:
  291. """Generate affine grid.
  292. Args:
  293. img: image tensor of shape :math:`(B, CH, H, W)`.
  294. LAF: laf with shape :math:`(B, N, 2, 3)`.
  295. PS: patch size to be extracted.
  296. Returns:
  297. grid :math:`(B*N, PS, PS, 2)`
  298. """
  299. KORNIA_CHECK_LAF(LAF)
  300. B, N, _, _ = LAF.size()
  301. _, ch, h, w = img.size()
  302. # norm, then renorm is needed for allowing detection on one resolution
  303. # and extraction at arbitrary other
  304. LAF_renorm = denormalize_laf(LAF, img)
  305. grid = F.affine_grid(LAF_renorm.view(B * N, 2, 3), [B * N, ch, PS, PS], align_corners=False)
  306. grid[..., :, 0] = 2.0 * grid[..., :, 0].clone() / float(w - 1) - 1.0
  307. grid[..., :, 1] = 2.0 * grid[..., :, 1].clone() / float(h - 1) - 1.0
  308. return grid
  309. def extract_patches_simple(
  310. img: Tensor, laf: Tensor, PS: int = 32, normalize_lafs_before_extraction: bool = True
  311. ) -> Tensor:
  312. """Extract patches defined by LAFs from image tensor.
  313. No smoothing applied, huge aliasing (better use extract_patches_from_pyramid).
  314. Args:
  315. img: images, LAFs are detected in :math:`(B, CH, H, W)`.
  316. laf: :math:`(B, N, 2, 3)`.
  317. PS: patch size.
  318. normalize_lafs_before_extraction: if True, lafs are normalized to image size.
  319. Returns:
  320. patches with shape :math:`(B, N, CH, PS,PS)`.
  321. """
  322. KORNIA_CHECK_LAF(laf)
  323. if normalize_lafs_before_extraction:
  324. nlaf = normalize_laf(laf, img)
  325. else:
  326. nlaf = laf
  327. _, ch, h, w = img.size()
  328. B, N, _, _ = laf.size()
  329. out = []
  330. # for loop temporarily, to be refactored
  331. for i in range(B):
  332. grid = generate_patch_grid_from_normalized_LAF(img[i : i + 1], nlaf[i : i + 1], PS).to(img.device)
  333. out.append(
  334. F.grid_sample(
  335. img[i : i + 1].expand(grid.size(0), ch, h, w), grid, padding_mode="border", align_corners=False
  336. )
  337. )
  338. return concatenate(out, dim=0).view(B, N, ch, PS, PS)
  339. def extract_patches_from_pyramid(
  340. img: Tensor, laf: Tensor, PS: int = 32, normalize_lafs_before_extraction: bool = True
  341. ) -> Tensor:
  342. """Extract patches defined by LAFs from image tensor.
  343. Patches are extracted from appropriate pyramid level.
  344. Args:
  345. img: images, LAFs are detected in :math:`(B, CH, H, W)`.
  346. laf: :math:`(B, N, 2, 3)`.
  347. PS: patch size.
  348. normalize_lafs_before_extraction: if True, lafs are normalized to image size.
  349. Returns:
  350. patches with shape :math:`(B, N, CH, PS,PS)`.
  351. """
  352. KORNIA_CHECK_LAF(laf)
  353. if normalize_lafs_before_extraction:
  354. nlaf = normalize_laf(laf, img)
  355. else:
  356. nlaf = laf
  357. B, N, _, _ = laf.size()
  358. _, ch, h, w = img.size()
  359. scale = 2.0 * get_laf_scale(denormalize_laf(nlaf, img)) / float(PS)
  360. max_level = min(img.size(2), img.size(3)) // PS
  361. pyr_idx = scale.log2().clamp(min=0.0, max=max(0, max_level - 1)).long()
  362. cur_img = img
  363. cur_pyr_level = 0
  364. out = torch.zeros(B, N, ch, PS, PS).to(nlaf.dtype).to(nlaf.device)
  365. we_are_in_business = True
  366. while we_are_in_business:
  367. _, ch, h, w = cur_img.size()
  368. # for loop temporarily, to be refactored
  369. for i in range(B):
  370. scale_mask = (pyr_idx[i] == cur_pyr_level).squeeze()
  371. if (scale_mask.float().sum().item()) == 0:
  372. continue
  373. scale_mask = (scale_mask > 0).view(-1)
  374. grid = generate_patch_grid_from_normalized_LAF(cur_img[i : i + 1], nlaf[i : i + 1, scale_mask, :, :], PS)
  375. patches = F.grid_sample(
  376. cur_img[i : i + 1].expand(grid.shape[0], ch, h, w), grid, padding_mode="border", align_corners=False
  377. )
  378. out[i].masked_scatter_(scale_mask.view(-1, 1, 1, 1), patches.to(nlaf.dtype))
  379. we_are_in_business = min(cur_img.size(2), cur_img.size(3)) >= PS
  380. if not we_are_in_business:
  381. break
  382. cur_img = pyrdown(cur_img)
  383. cur_pyr_level += 1
  384. return out
  385. def laf_is_inside_image(laf: Tensor, images: Tensor, border: int = 0) -> Tensor:
  386. """Check if the LAF is touching or partly outside the image boundary.
  387. Returns the mask of LAFs, which are fully inside the image, i.e. valid.
  388. Args:
  389. laf: :math:`(B, N, 2, 3)`.
  390. images: images, lafs are detected in :math:`(B, CH, H, W)`.
  391. border: additional border.
  392. Returns:
  393. mask with shape :math:`(B, N)`.
  394. """
  395. KORNIA_CHECK_LAF(laf)
  396. _, _, h, w = images.size()
  397. pts = laf_to_boundary_points(laf, 12)
  398. good_lafs_mask = (
  399. (pts[..., 0] >= border) * (pts[..., 0] <= w - border) * (pts[..., 1] >= border) * (pts[..., 1] <= h - border)
  400. )
  401. good_lafs_mask = good_lafs_mask.min(dim=2)[0]
  402. return good_lafs_mask
  403. def laf_to_three_points(laf: Tensor) -> Tensor:
  404. """Convert local affine frame(LAF) to alternative representation: coordinates of LAF center, LAF-x unit vector,
  405. LAF-y unit vector.
  406. Args:
  407. laf: :math:`(B, N, 2, 3)`.
  408. Returns:
  409. threepts :math:`(B, N, 2, 3)`.
  410. """ # noqa:D205
  411. KORNIA_CHECK_LAF(laf)
  412. three_pts = stack([laf[..., 2] + laf[..., 0], laf[..., 2] + laf[..., 1], laf[..., 2]], dim=-1)
  413. return three_pts
  414. def laf_from_three_points(threepts: Tensor) -> Tensor:
  415. """Convert three points to local affine frame.
  416. Order is (0,0), (0, 1), (1, 0).
  417. Args:
  418. threepts: :math:`(B, N, 2, 3)`.
  419. Returns:
  420. laf :math:`(B, N, 2, 3)`.
  421. """
  422. laf = stack([threepts[..., 0] - threepts[..., 2], threepts[..., 1] - threepts[..., 2], threepts[..., 2]], dim=-1)
  423. return laf
  424. def perspective_transform_lafs(trans_01: Tensor, lafs_1: Tensor) -> Tensor:
  425. r"""Apply perspective transformations to a set of local affine frames (LAFs).
  426. Args:
  427. trans_01: tensor for perspective transformations of shape :math:`(B, 3, 3)`.
  428. lafs_1: tensor of lafs of shape :math:`(B, N, 2, 3)`.
  429. Returns:
  430. tensor of N-dimensional points of shape :math:`(B, N, 2, 3)`.
  431. Examples:
  432. >>> rng = torch.manual_seed(0)
  433. >>> lafs_1 = torch.rand(2, 4, 2, 3) # BxNx2x3
  434. >>> lafs_1
  435. tensor([[[[0.4963, 0.7682, 0.0885],
  436. [0.1320, 0.3074, 0.6341]],
  437. <BLANKLINE>
  438. [[0.4901, 0.8964, 0.4556],
  439. [0.6323, 0.3489, 0.4017]],
  440. <BLANKLINE>
  441. [[0.0223, 0.1689, 0.2939],
  442. [0.5185, 0.6977, 0.8000]],
  443. <BLANKLINE>
  444. [[0.1610, 0.2823, 0.6816],
  445. [0.9152, 0.3971, 0.8742]]],
  446. <BLANKLINE>
  447. <BLANKLINE>
  448. [[[0.4194, 0.5529, 0.9527],
  449. [0.0362, 0.1852, 0.3734]],
  450. <BLANKLINE>
  451. [[0.3051, 0.9320, 0.1759],
  452. [0.2698, 0.1507, 0.0317]],
  453. <BLANKLINE>
  454. [[0.2081, 0.9298, 0.7231],
  455. [0.7423, 0.5263, 0.2437]],
  456. <BLANKLINE>
  457. [[0.5846, 0.0332, 0.1387],
  458. [0.2422, 0.8155, 0.7932]]]])
  459. >>> trans_01 = torch.eye(3).repeat(2, 1, 1) # Bx3x3
  460. >>> trans_01.shape
  461. torch.Size([2, 3, 3])
  462. >>> lafs_0 = perspective_transform_lafs(trans_01, lafs_1) # BxNx2x3
  463. """
  464. KORNIA_CHECK_LAF(lafs_1)
  465. if not torch.is_tensor(trans_01):
  466. raise TypeError("Input type is not a Tensor")
  467. if not trans_01.device == lafs_1.device:
  468. raise TypeError("Tensor must be in the same device")
  469. if not trans_01.shape[0] == lafs_1.shape[0]:
  470. raise ValueError("Input batch size must be the same for both tensors")
  471. if (not (trans_01.shape[-1] == 3)) or (not (trans_01.shape[-2] == 3)):
  472. raise ValueError("Transformation should be homography")
  473. bs, n, _, _ = lafs_1.size()
  474. # First, we convert LAF to points
  475. threepts_1 = laf_to_three_points(lafs_1)
  476. points_1 = threepts_1.permute(0, 1, 3, 2).reshape(bs, n * 3, 2)
  477. # First, transform the points
  478. points_0 = transform_points(trans_01, points_1)
  479. # Back to LAF format
  480. threepts_0 = points_0.view(bs, n, 3, 2).permute(0, 1, 3, 2)
  481. return laf_from_three_points(threepts_0)