Prechádzať zdrojové kódy

fix: line order in RobustLoss (#44)

lnex 1 rok pred
rodič
commit
2d869bb243
1 zmenil súbory, kde vykonal 2 pridanie a 2 odobranie
  1. 2 2
      roma/losses/robust_loss.py

+ 2 - 2
roma/losses/robust_loss.py

@@ -49,10 +49,10 @@ class RobustLosses(nn.Module):
             G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
             GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
         cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction  = 'none')[prob > 0.99]
+        certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
         if not torch.any(cls_loss):
             cls_loss = (certainty_loss * 0.0)  # Prevent issues where prob is 0 everywhere
-
-        certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
+            
         losses = {
             f"gm_certainty_loss_{scale}": certainty_loss.mean(),
             f"gm_cls_loss_{scale}": cls_loss.mean(),