qkqhd222 2 лет назад
Родитель
Сommit
a274747786
1 измененных файлов с 5 добавлено и 5 удалено
  1. 5 5
      roma/utils/local_correlation.py

+ 5 - 5
roma/utils/local_correlation.py

@@ -17,8 +17,8 @@ def local_correlation(
         # If flow is None, assume feature0 and feature1 are aligned
         coords = torch.meshgrid(
                 (
-                    torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"),
-                    torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"),
+                    torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=feature0.device),
+                    torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=feature0.device),
                 ))
         coords = torch.stack((coords[1], coords[0]), dim=-1)[
             None
@@ -27,8 +27,8 @@ def local_correlation(
         coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
     local_window = torch.meshgrid(
                 (
-                    torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device="cuda"),
-                    torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device="cuda"),
+                    torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=feature0.device),
+                    torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=feature0.device),
                 ))
     local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
             None
@@ -41,4 +41,4 @@ def local_correlation(
             )
             window_feature = window_feature.reshape(c,h,w,(2*r+1)**2)
         corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1)
-    return corr
+    return corr