# local_correlation.py
import torch
import torch.nn.functional as F
  3. def local_correlation(
  4. feature0,
  5. feature1,
  6. local_radius,
  7. padding_mode="zeros",
  8. flow = None,
  9. sample_mode = "bilinear",
  10. ):
  11. r = local_radius
  12. K = (2*r+1)**2
  13. B, c, h, w = feature0.size()
  14. corr = torch.empty((B,K,h,w), device = feature0.device, dtype=feature0.dtype)
  15. if flow is None:
  16. # If flow is None, assume feature0 and feature1 are aligned
  17. coords = torch.meshgrid(
  18. (
  19. torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=feature0.device),
  20. torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=feature0.device),
  21. ),
  22. indexing = 'ij'
  23. )
  24. coords = torch.stack((coords[1], coords[0]), dim=-1)[
  25. None
  26. ].expand(B, h, w, 2)
  27. else:
  28. coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
  29. local_window = torch.meshgrid(
  30. (
  31. torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=feature0.device),
  32. torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=feature0.device),
  33. ),
  34. indexing = 'ij'
  35. )
  36. local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
  37. None
  38. ].expand(1, 2*r+1, 2*r+1, 2).reshape(1, (2*r+1)**2, 2)
  39. for _ in range(B):
  40. with torch.no_grad():
  41. local_window_coords = (coords[_,:,:,None]+local_window[:,None,None]).reshape(1,h,w*(2*r+1)**2,2)
  42. window_feature = F.grid_sample(
  43. feature1[_:_+1], local_window_coords, padding_mode=padding_mode, align_corners=False, mode = sample_mode, #
  44. )
  45. window_feature = window_feature.reshape(c,h,w,(2*r+1)**2)
  46. corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1)
  47. return corr