|
|
@@ -151,7 +151,8 @@ def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_w
|
|
|
h,w = resolution
|
|
|
symmetric = True
|
|
|
attenuate_cert = True
|
|
|
+ sample_mode = "threshold_balanced"
|
|
|
matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds,
|
|
|
- symmetric = symmetric, attenuate_cert=attenuate_cert, **kwargs).to(device)
|
|
|
+ symmetric = symmetric, attenuate_cert = attenuate_cert, sample_mode = sample_mode, **kwargs).to(device)
|
|
|
matcher.load_state_dict(weights)
|
|
|
return matcher
|