瀏覽代碼

Merge pull request #10 from qkqhd222/main

Edit setup.py to be similar to setup.py of DKM repo
Johan Edstedt 2 年之前
父節點
當前提交
69cefb130b
共有 3 個文件被更改,包括 12 次插入12 次删除
  1. 5 5
      roma/models/matcher.py
  2. 5 5
      roma/utils/local_correlation.py
  3. 2 2
      setup.py

+ 5 - 5
roma/models/matcher.py

@@ -111,8 +111,8 @@ class ConvRefiner(nn.Module):
             if self.has_displacement_emb:
                 im_A_coords = torch.meshgrid(
                 (
-                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
-                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
+                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
                 )
                 )
                 im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
@@ -675,8 +675,8 @@ class RegressionMatcher(nn.Module):
             # Create im_A meshgrid
             im_A_coords = torch.meshgrid(
                 (
-                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
-                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
+                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
                 )
             )
             im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
@@ -732,4 +732,4 @@ class RegressionMatcher(nn.Module):
         if save_path is not None:
             from roma.utils import tensor_to_pil
             tensor_to_pil(vis_im, unnormalize=False).save(save_path)
-        return vis_im
+        return vis_im

+ 5 - 5
roma/utils/local_correlation.py

@@ -17,8 +17,8 @@ def local_correlation(
         # If flow is None, assume feature0 and feature1 are aligned
         coords = torch.meshgrid(
                 (
-                    torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"),
-                    torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"),
+                    torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=feature0.device),
+                    torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=feature0.device),
                 ))
         coords = torch.stack((coords[1], coords[0]), dim=-1)[
             None
@@ -27,8 +27,8 @@ def local_correlation(
         coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
     local_window = torch.meshgrid(
                 (
-                    torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device="cuda"),
-                    torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device="cuda"),
+                    torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=feature0.device),
+                    torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=feature0.device),
                 ))
     local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
             None
@@ -41,4 +41,4 @@ def local_correlation(
             )
             window_feature = window_feature.reshape(c,h,w,(2*r+1)**2)
         corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1)
-    return corr
+    return corr

+ 2 - 2
setup.py

@@ -1,8 +1,8 @@
-from setuptools import setup
+from setuptools import setup, find_packages
 
 setup(
     name="roma",
-    packages=["roma"],
+    packages=find_packages(include=("roma*",)),
     version="0.0.1",
     author="Johan Edstedt",
     install_requires=open("requirements.txt", "r").read().split("\n"),