local_correlation.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. from typing import Literal
  2. import torch
  3. import torch.nn.functional as F
  4. import local_corr
  5. def local_corr_wrapper(
  6. feature0: torch.Tensor,
  7. feature1: torch.Tensor,
  8. coords: torch.Tensor,
  9. local_window: torch.Tensor,
  10. B,
  11. K,
  12. c,
  13. r,
  14. h,
  15. w,
  16. device,
  17. padding_mode="zeros",
  18. sample_mode: Literal["bilinear", "nearest"] = "bilinear",
  19. dtype=torch.float32,
  20. ):
  21. assert padding_mode == "zeros"
  22. warp = (coords[..., None, :] + local_window[:, None, None]).reshape(B, h * w, K, 2)
  23. corr = (
  24. local_corr.local_corr(
  25. feature0.reshape(B, c, h * w).permute(0, 2, 1).float() / (c**0.5),
  26. feature1.permute(0, 2, 3, 1).clone().detach().float(),
  27. warp.clone().detach(),
  28. mode=sample_mode,
  29. normalized_coords=True,
  30. )
  31. .permute(0, 2, 1)
  32. .reshape(B, K, h, w)
  33. )
  34. return corr, warp
  35. def shitty_native_torch_local_corr(
  36. feature0,
  37. feature1,
  38. warp,
  39. local_window,
  40. B,
  41. K,
  42. c,
  43. r,
  44. h,
  45. w,
  46. device,
  47. padding_mode="zeros",
  48. sample_mode="bilinear",
  49. dtype=torch.float32,
  50. ):
  51. corr = torch.empty((B, K, h, w), device=device, dtype=dtype)
  52. for _ in range(B):
  53. with torch.no_grad():
  54. local_window_coords = (
  55. warp[_, :, :, None] + local_window[:, None, None]
  56. ).reshape(1, h, w * K, 2)
  57. window_feature = F.grid_sample(
  58. feature1[_ : _ + 1],
  59. local_window_coords,
  60. padding_mode=padding_mode,
  61. align_corners=False,
  62. mode=sample_mode, #
  63. )
  64. window_feature = window_feature.reshape(c, h, w, K)
  65. corr[_] = (
  66. (feature0[_, ..., None] / (c**0.5) * window_feature)
  67. .sum(dim=0)
  68. .permute(2, 0, 1)
  69. )
  70. return corr, None
  71. def local_correlation(
  72. feature0: torch.Tensor, # (B x C x H x W)
  73. feature1: torch.Tensor, # (B x C x H x W)
  74. local_radius: int,
  75. warp: torch.Tensor, # (B x 2 x H x W)
  76. *,
  77. use_custom_corr: bool,
  78. padding_mode="zeros",
  79. sample_mode: Literal["bilinear", "nearest"] = "bilinear",
  80. ):
  81. r = local_radius
  82. K = (2 * r + 1) ** 2
  83. B, c, h, w = feature0.size()
  84. warp = warp.permute(0, 2, 3, 1)
  85. device = feature0.device
  86. dtype = feature0.dtype
  87. local_window = torch.meshgrid(
  88. [
  89. torch.linspace(
  90. -2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device=device
  91. ),
  92. torch.linspace(
  93. -2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device=device
  94. ),
  95. ],
  96. indexing="ij",
  97. )
  98. local_window = (
  99. torch.stack((local_window[1], local_window[0]), dim=-1)[None]
  100. .expand(1, 2 * r + 1, 2 * r + 1, 2)
  101. .reshape(1, K, 2)
  102. )
  103. if not use_custom_corr:
  104. corr, corr_coords = shitty_native_torch_local_corr(
  105. feature0,
  106. feature1,
  107. warp,
  108. local_window,
  109. B,
  110. K,
  111. c,
  112. r,
  113. h,
  114. w,
  115. device,
  116. padding_mode,
  117. sample_mode,
  118. dtype,
  119. )
  120. else:
  121. corr, corr_coords = local_corr_wrapper(
  122. feature0,
  123. feature1,
  124. warp,
  125. local_window,
  126. B,
  127. K,
  128. c,
  129. r,
  130. h,
  131. w,
  132. device,
  133. padding_mode,
  134. sample_mode,
  135. dtype,
  136. )
  137. return corr