|
|
@@ -153,7 +153,7 @@ class XFeatModel(nn.Module):
|
|
|
indexing = "xy"),
|
|
|
dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
|
|
|
cv = corr_volume
|
|
|
- best_match = cv.reshape(B,H1*W1,H0,W0).argmax(dim=1) # B, HW, H, W
|
|
|
+ best_match = cv.reshape(B,H1*W1,H0,W0).amax(dim=1) # B, HW, H, W
|
|
|
P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1)
|
|
|
pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
|
|
|
pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2)
|