| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- from typing import Literal
- import torch
- import torch.nn.functional as F
- import local_corr
- def local_corr_wrapper(
- feature0: torch.Tensor,
- feature1: torch.Tensor,
- coords: torch.Tensor,
- local_window: torch.Tensor,
- B,
- K,
- c,
- r,
- h,
- w,
- device,
- padding_mode="zeros",
- sample_mode: Literal["bilinear", "nearest"] = "bilinear",
- dtype=torch.float32,
- ):
- assert padding_mode == "zeros"
- warp = (coords[..., None, :] + local_window[:, None, None]).reshape(B, h * w, K, 2)
- corr = (
- local_corr.local_corr(
- feature0.reshape(B, c, h * w).permute(0, 2, 1).float() / (c**0.5),
- feature1.permute(0, 2, 3, 1).clone().detach().float(),
- warp.clone().detach(),
- mode=sample_mode,
- normalized_coords=True,
- )
- .permute(0, 2, 1)
- .reshape(B, K, h, w)
- )
- return corr, warp
- def shitty_native_torch_local_corr(
- feature0,
- feature1,
- warp,
- local_window,
- B,
- K,
- c,
- r,
- h,
- w,
- device,
- padding_mode="zeros",
- sample_mode="bilinear",
- dtype=torch.float32,
- ):
- corr = torch.empty((B, K, h, w), device=device, dtype=dtype)
- for _ in range(B):
- with torch.no_grad():
- local_window_coords = (
- warp[_, :, :, None] + local_window[:, None, None]
- ).reshape(1, h, w * K, 2)
- window_feature = F.grid_sample(
- feature1[_ : _ + 1],
- local_window_coords,
- padding_mode=padding_mode,
- align_corners=False,
- mode=sample_mode, #
- )
- window_feature = window_feature.reshape(c, h, w, K)
- corr[_] = (
- (feature0[_, ..., None] / (c**0.5) * window_feature)
- .sum(dim=0)
- .permute(2, 0, 1)
- )
- return corr, None
- def local_correlation(
- feature0: torch.Tensor, # (B x C x H x W)
- feature1: torch.Tensor, # (B x C x H x W)
- local_radius: int,
- warp: torch.Tensor, # (B x 2 x H x W)
- *,
- use_custom_corr: bool,
- padding_mode="zeros",
- sample_mode: Literal["bilinear", "nearest"] = "bilinear",
- ):
- r = local_radius
- K = (2 * r + 1) ** 2
- B, c, h, w = feature0.size()
- warp = warp.permute(0, 2, 3, 1)
- device = feature0.device
- dtype = feature0.dtype
- local_window = torch.meshgrid(
- [
- torch.linspace(
- -2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device=device
- ),
- torch.linspace(
- -2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device=device
- ),
- ],
- indexing="ij",
- )
- local_window = (
- torch.stack((local_window[1], local_window[0]), dim=-1)[None]
- .expand(1, 2 * r + 1, 2 * r + 1, 2)
- .reshape(1, K, 2)
- )
- if not use_custom_corr:
- corr, corr_coords = shitty_native_torch_local_corr(
- feature0,
- feature1,
- warp,
- local_window,
- B,
- K,
- c,
- r,
- h,
- w,
- device,
- padding_mode,
- sample_mode,
- dtype,
- )
- else:
- corr, corr_coords = local_corr_wrapper(
- feature0,
- feature1,
- warp,
- local_window,
- B,
- K,
- c,
- r,
- h,
- w,
- device,
- padding_mode,
- sample_mode,
- dtype,
- )
- return corr
|