瀏覽代碼

fix bug in approx softmax

Johan Edstedt 1 年之前
父節點
當前提交
e6588007d4
共有 1 個文件被更改,包括 1 次插入1 次删除
  1. 1 1
      experiments/train_tiny_roma_v1_outdoor.py

+ 1 - 1
experiments/train_tiny_roma_v1_outdoor.py

@@ -153,7 +153,7 @@ class XFeatModel(nn.Module):
                     indexing = "xy"), 
                 dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
             cv = corr_volume
-            best_match = cv.reshape(B,H1*W1,H0,W0).argmax(dim=1) # B, HW, H, W
+            best_match = cv.reshape(B,H1*W1,H0,W0).amax(dim=1) # B, HW, H, W
             P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1)
             pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
             pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2)