Răsfoiți Sursa

fix so roma_outdoor/indoor uses balanced sampling by default

Johan Edstedt 2 ani în urmă
părinte
comite
d9b2f017ca
2 a modificat fișierele cu 3 adăugiri și 2 ștergeri
  1. 1 1
      roma/models/matcher.py
  2. 2 1
      roma/models/model_zoo/roma_models.py

+ 1 - 1
roma/models/matcher.py

@@ -424,7 +424,7 @@ class RegressionMatcher(nn.Module):
         decoder,
         decoder,
         h=448,
         h=448,
         w=448,
         w=448,
-        sample_mode = "threshold",
+        sample_mode = "threshold_balanced",
         upsample_preds = False,
         upsample_preds = False,
         symmetric = False,
         symmetric = False,
         name = None,
         name = None,

+ 2 - 1
roma/models/model_zoo/roma_models.py

@@ -151,7 +151,8 @@ def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_w
     h,w = resolution
     h,w = resolution
     symmetric = True
     symmetric = True
     attenuate_cert = True
     attenuate_cert = True
+    sample_mode = "threshold_balanced"
     matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds, 
     matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds, 
-                                symmetric = symmetric, attenuate_cert=attenuate_cert, **kwargs).to(device)
+                                symmetric = symmetric, attenuate_cert = attenuate_cert, sample_mode = sample_mode, **kwargs).to(device)
     matcher.load_state_dict(weights)
     matcher.load_state_dict(weights)
     return matcher
     return matcher