소스 검색

Update matcher.py

qkqhd222 2 년 전
부모
커밋
0263a813cb
1개의 변경된 파일5개의 추가작업 그리고 5개의 파일을 삭제
  1. 5 5
      roma/models/matcher.py

+ 5 - 5
roma/models/matcher.py

@@ -111,8 +111,8 @@ class ConvRefiner(nn.Module):
             if self.has_displacement_emb:
                 im_A_coords = torch.meshgrid(
                 (
-                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
-                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
+                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
                 )
                 )
                 im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
@@ -657,8 +657,8 @@ class RegressionMatcher(nn.Module):
             # Create im_A meshgrid
             im_A_coords = torch.meshgrid(
                 (
-                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
-                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
+                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
                 )
             )
             im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
@@ -714,4 +714,4 @@ class RegressionMatcher(nn.Module):
         if save_path is not None:
             from roma.utils import tensor_to_pil
             tensor_to_pil(vis_im, unnormalize=False).save(save_path)
-        return vis_im
+        return vis_im