|
|
@@ -49,10 +49,10 @@ class RobustLosses(nn.Module):
|
|
|
G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
|
|
|
GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
|
|
|
cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99]
|
|
|
+ certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
|
|
|
if not torch.any(cls_loss):
|
|
|
cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
|
|
|
-
|
|
|
- certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
|
|
|
+
|
|
|
losses = {
|
|
|
f"gm_certainty_loss_{scale}": certainty_loss.mean(),
|
|
|
f"gm_cls_loss_{scale}": cls_loss.mean(),
|