|
|
@@ -111,8 +111,8 @@ class ConvRefiner(nn.Module):
|
|
|
if self.has_displacement_emb:
|
|
|
im_A_coords = torch.meshgrid(
|
|
|
(
|
|
|
- torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
|
|
|
- torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
|
|
|
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
|
|
|
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
|
|
|
)
|
|
|
)
|
|
|
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
|
|
@@ -657,8 +657,8 @@ class RegressionMatcher(nn.Module):
|
|
|
# Create im_A meshgrid
|
|
|
im_A_coords = torch.meshgrid(
|
|
|
(
|
|
|
- torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
|
|
|
- torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
|
|
|
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
|
|
|
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
|
|
|
)
|
|
|
)
|
|
|
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
|
|
@@ -714,4 +714,4 @@ class RegressionMatcher(nn.Module):
|
|
|
if save_path is not None:
|
|
|
from roma.utils import tensor_to_pil
|
|
|
tensor_to_pil(vis_im, unnormalize=False).save(save_path)
|
|
|
- return vis_im
|
|
|
+ return vis_im
|