geometry.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Tuple
  18. import torch
  19. from kornia.core import Tensor
  20. @torch.no_grad()
  21. def warp_kpts(
  22. kpts0: Tensor, depth0: Tensor, depth1: Tensor, T_0to1: Tensor, K0: Tensor, K1: Tensor
  23. ) -> Tuple[Tensor, Tensor]:
  24. """Warp kpts0 from I0 to I1 with depth, K and Rt Also check covisibility and depth consistency.
  25. Depth is consistent if relative error < 0.2 (hard-coded).
  26. Args:
  27. kpts0: [N, L, 2] - <x, y>,
  28. depth0: [N, H, W],
  29. depth1: [N, H, W],
  30. T_0to1: [N, 3, 4],
  31. K0: [N, 3, 3],
  32. K1: [N, 3, 3],
  33. Returns:
  34. calculable_mask: [N, L]
  35. warped_keypoints0: [N, L, 2] <x0_hat, y1_hat>
  36. """
  37. kpts0_long = kpts0.round().long()
  38. # Sample depth, get calculable_mask on depth != 0
  39. kpts0_depth = torch.stack(
  40. [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
  41. ) # (N, L)
  42. nonzero_mask = kpts0_depth != 0
  43. # Unproject
  44. kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
  45. kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
  46. # Rigid Transform
  47. w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
  48. w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
  49. # Project
  50. w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
  51. w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
  52. # Covisible Check
  53. h, w = depth1.shape[1:3]
  54. covisible_mask = (
  55. (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w - 1) * (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h - 1)
  56. )
  57. w_kpts0_long = w_kpts0.long()
  58. w_kpts0_long[~covisible_mask, :] = 0
  59. w_kpts0_depth = torch.stack(
  60. [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
  61. ) # (N, L)
  62. consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
  63. valid_mask = nonzero_mask * covisible_mask * consistent_mask
  64. return valid_mask, w_kpts0