|
|
@@ -17,8 +17,8 @@ def local_correlation(
|
|
|
# If flow is None, assume feature0 and feature1 are aligned
|
|
|
coords = torch.meshgrid(
|
|
|
(
|
|
|
- torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"),
|
|
|
- torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"),
|
|
|
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=feature0.device),
|
|
|
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=feature0.device),
|
|
|
))
|
|
|
coords = torch.stack((coords[1], coords[0]), dim=-1)[
|
|
|
None
|
|
|
@@ -27,8 +27,8 @@ def local_correlation(
|
|
|
coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
|
|
|
local_window = torch.meshgrid(
|
|
|
(
|
|
|
- torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device="cuda"),
|
|
|
- torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device="cuda"),
|
|
|
+ torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=feature0.device),
|
|
|
+ torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=feature0.device),
|
|
|
))
|
|
|
local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
|
|
|
None
|
|
|
@@ -41,4 +41,4 @@ def local_correlation(
|
|
|
)
|
|
|
window_feature = window_feature.reshape(c,h,w,(2*r+1)**2)
|
|
|
corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1)
|
|
|
- return corr
|
|
|
+ return corr
|