Sfoglia il codice sorgente

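Fix torch.meshgrid warnings about the missing `indexing` argument by passing indexing='ij' explicitly at every call site (intended to leave behavior unchanged); also compute certainty_loss before its first use in robust_loss.py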

Johan Edstedt 1 anno fa
parent
commit
cdea84f2fa

+ 2 - 2
romatch/losses/robust_loss.py

@@ -45,7 +45,7 @@ class RobustLosses(nn.Module):
             B, C, H, W = scale_gm_cls.shape
             device = x2.device
             cls_res = round(math.sqrt(C))
-            G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
+            G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)], indexing='ij')
             G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
             GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
         cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction  = 'none')[prob > 0.99]
@@ -69,9 +69,9 @@ class RobustLosses(nn.Module):
             G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale
             GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices
         cls_loss = F.cross_entropy(delta_cls, GT, reduction  = 'none')[prob > 0.99]
+        certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
         if not torch.any(cls_loss):
             cls_loss = (certainty_loss * 0.0)  # Prevent issues where prob is 0 everywhere
-        certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
         losses = {
             f"delta_certainty_loss_{scale}": certainty_loss.mean(),
             f"delta_cls_loss_{scale}": cls_loss.mean(),

+ 13 - 9
romatch/models/matcher.py

@@ -112,7 +112,7 @@ class ConvRefiner(nn.Module):
                 (
                     torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
                     torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
-                )
+                ), indexing='ij'
                 )
                 im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
                 im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
@@ -196,14 +196,14 @@ class GP(nn.Module):
         cov = F.pad(cov, 4 * (K // 2,))  # pad v_q
         delta = torch.stack(
             torch.meshgrid(
-                torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
-            ),
+                torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1),
+                indexing = 'ij'),
             dim=-1,
         )
         positions = torch.stack(
             torch.meshgrid(
-                torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
-            ),
+                torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2),
+                indexing = 'ij'),
             dim=-1,
         )
         neighbours = positions[:, :, None, None, :] + delta[None, :, :]
@@ -235,7 +235,8 @@ class GP(nn.Module):
             (
                 torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
                 torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
-            )
+            ),
+            indexing = 'ij'
         )
 
         coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
@@ -304,7 +305,8 @@ class Decoder(nn.Module):
             (
                 torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
                 torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
-            )
+            ),
+            indexing = 'ij'
         )
         coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
             None
@@ -317,7 +319,8 @@ class Decoder(nn.Module):
             (
                 torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
                 torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
-            )
+            ),
+            indexing = 'ij'
         )
 
         coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
@@ -675,7 +678,8 @@ class RegressionMatcher(nn.Module):
                 (
                     torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
                     torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
-                )
+                ),
+                indexing = 'ij'
             )
             im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
             im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)

+ 6 - 2
romatch/utils/local_correlation.py

@@ -19,7 +19,9 @@ def local_correlation(
                 (
                     torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=feature0.device),
                     torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=feature0.device),
-                ))
+                ),
+                indexing = 'ij'
+                )
         coords = torch.stack((coords[1], coords[0]), dim=-1)[
             None
         ].expand(B, h, w, 2)
@@ -29,7 +31,9 @@ def local_correlation(
                 (
                     torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=feature0.device),
                     torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=feature0.device),
-                ))
+                ),
+                indexing = 'ij'
+                )
     local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
             None
         ].expand(1, 2*r+1, 2*r+1, 2).reshape(1, (2*r+1)**2, 2)

+ 12 - 4
romatch/utils/utils.py

@@ -286,7 +286,10 @@ def cls_to_flow(cls, deterministic_sampling = True):
     B,C,H,W = cls.shape
     device = cls.device
     res = round(math.sqrt(C))
-    G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)])
+    G = torch.meshgrid(
+        *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)],
+        indexing = 'ij'
+        )
     G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
     if deterministic_sampling:
         sampled_cls = cls.max(dim=1).indices
@@ -300,7 +303,10 @@ def cls_to_flow_refine(cls):
     B,C,H,W = cls.shape
     device = cls.device
     res = round(math.sqrt(C))
-    G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)])
+    G = torch.meshgrid(
+        *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)],
+        indexing = 'ij'
+        )
     G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
     cls = cls.softmax(dim=1)
     mode = cls.max(dim=1).indices
@@ -326,7 +332,8 @@ def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bili
                     -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
                 )
                 for n in (B, H, W)
-            ]
+            ],
+            indexing = 'ij'
         )
         x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
         mask, x2 = warp_kpts(
@@ -619,7 +626,8 @@ def get_grid(b, h, w, device):
         *[
             torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)
             for n in (b, h, w)
-        ]
+        ],
+        indexing = 'ij'
     )
     grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2)
     return grid