|
|
@@ -112,7 +112,7 @@ class ConvRefiner(nn.Module):
|
|
|
(
|
|
|
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
|
|
|
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
|
|
|
- )
|
|
|
+ ), indexing='ij'
|
|
|
)
|
|
|
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
|
|
im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
|
|
|
@@ -196,14 +196,14 @@ class GP(nn.Module):
|
|
|
cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
|
|
|
delta = torch.stack(
|
|
|
torch.meshgrid(
|
|
|
- torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
|
|
|
- ),
|
|
|
+ torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1),
|
|
|
+ indexing = 'ij'),
|
|
|
dim=-1,
|
|
|
)
|
|
|
positions = torch.stack(
|
|
|
torch.meshgrid(
|
|
|
- torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
|
|
|
- ),
|
|
|
+ torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2),
|
|
|
+ indexing = 'ij'),
|
|
|
dim=-1,
|
|
|
)
|
|
|
neighbours = positions[:, :, None, None, :] + delta[None, :, :]
|
|
|
@@ -235,7 +235,8 @@ class GP(nn.Module):
|
|
|
(
|
|
|
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
|
|
|
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
|
|
|
- )
|
|
|
+ ),
|
|
|
+ indexing = 'ij'
|
|
|
)
|
|
|
|
|
|
coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
|
|
|
@@ -304,7 +305,8 @@ class Decoder(nn.Module):
|
|
|
(
|
|
|
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
|
|
|
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
|
|
|
- )
|
|
|
+ ),
|
|
|
+ indexing = 'ij'
|
|
|
)
|
|
|
coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
|
|
|
None
|
|
|
@@ -317,7 +319,8 @@ class Decoder(nn.Module):
|
|
|
(
|
|
|
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
|
|
|
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
|
|
|
- )
|
|
|
+ ),
|
|
|
+ indexing = 'ij'
|
|
|
)
|
|
|
|
|
|
coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
|
|
|
@@ -675,7 +678,8 @@ class RegressionMatcher(nn.Module):
|
|
|
(
|
|
|
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
|
|
|
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
|
|
|
- )
|
|
|
+ ),
|
|
|
+ indexing = 'ij'
|
|
|
)
|
|
|
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
|
|
im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
|