robust_loss.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. from einops.einops import rearrange
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from romatch.utils.utils import get_gt_warp
  6. import wandb
  7. import romatch
  8. import math
  9. class RobustLosses(nn.Module):
  10. def __init__(
  11. self,
  12. robust=False,
  13. center_coords=False,
  14. scale_normalize=False,
  15. ce_weight=0.01,
  16. local_loss=True,
  17. local_dist=4.0,
  18. local_largest_scale=8,
  19. smooth_mask = False,
  20. depth_interpolation_mode = "bilinear",
  21. mask_depth_loss = False,
  22. relative_depth_error_threshold = 0.05,
  23. alpha = 1.,
  24. c = 1e-3,
  25. ):
  26. super().__init__()
  27. self.robust = robust # measured in pixels
  28. self.center_coords = center_coords
  29. self.scale_normalize = scale_normalize
  30. self.ce_weight = ce_weight
  31. self.local_loss = local_loss
  32. self.local_dist = local_dist
  33. self.local_largest_scale = local_largest_scale
  34. self.smooth_mask = smooth_mask
  35. self.depth_interpolation_mode = depth_interpolation_mode
  36. self.mask_depth_loss = mask_depth_loss
  37. self.relative_depth_error_threshold = relative_depth_error_threshold
  38. self.avg_overlap = dict()
  39. self.alpha = alpha
  40. self.c = c
  41. def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
  42. with torch.no_grad():
  43. B, C, H, W = scale_gm_cls.shape
  44. device = x2.device
  45. cls_res = round(math.sqrt(C))
  46. G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)], indexing='ij')
  47. G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
  48. GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
  49. cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99]
  50. certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
  51. if not torch.any(cls_loss):
  52. cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
  53. losses = {
  54. f"gm_certainty_loss_{scale}": certainty_loss.mean(),
  55. f"gm_cls_loss_{scale}": cls_loss.mean(),
  56. }
  57. wandb.log(losses, step = romatch.GLOBAL_STEP)
  58. return losses
  59. def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale):
  60. with torch.no_grad():
  61. B, C, H, W = delta_cls.shape
  62. device = x2.device
  63. cls_res = round(math.sqrt(C))
  64. G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
  65. G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale
  66. GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices
  67. cls_loss = F.cross_entropy(delta_cls, GT, reduction = 'none')[prob > 0.99]
  68. certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
  69. if not torch.any(cls_loss):
  70. cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
  71. losses = {
  72. f"delta_certainty_loss_{scale}": certainty_loss.mean(),
  73. f"delta_cls_loss_{scale}": cls_loss.mean(),
  74. }
  75. wandb.log(losses, step = romatch.GLOBAL_STEP)
  76. return losses
  77. def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
  78. epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
  79. if scale == 1:
  80. pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
  81. wandb.log({"train_pck_05": pck_05}, step = romatch.GLOBAL_STEP)
  82. ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
  83. a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha
  84. cs = self.c * scale
  85. x = epe[prob > 0.99]
  86. reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
  87. if not torch.any(reg_loss):
  88. reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere
  89. losses = {
  90. f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
  91. f"{mode}_regression_loss_{scale}": reg_loss.mean(),
  92. }
  93. wandb.log(losses, step = romatch.GLOBAL_STEP)
  94. return losses
  95. def forward(self, corresps, batch):
  96. scales = list(corresps.keys())
  97. tot_loss = 0.0
  98. # scale_weights due to differences in scale for regression gradients and classification gradients
  99. scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1}
  100. for scale in scales:
  101. scale_corresps = corresps[scale]
  102. scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = (
  103. scale_corresps["certainty"],
  104. scale_corresps.get("flow_pre_delta"),
  105. scale_corresps.get("delta_cls"),
  106. scale_corresps.get("offset_scale"),
  107. scale_corresps.get("gm_cls"),
  108. scale_corresps.get("gm_certainty"),
  109. scale_corresps["flow"],
  110. scale_corresps.get("gm_flow"),
  111. )
  112. if flow_pre_delta is not None:
  113. flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
  114. b, h, w, d = flow_pre_delta.shape
  115. else:
  116. # _ = 1
  117. b, _, h, w = scale_certainty.shape
  118. gt_warp, gt_prob = get_gt_warp(
  119. batch["im_A_depth"],
  120. batch["im_B_depth"],
  121. batch["T_1to2"],
  122. batch["K1"],
  123. batch["K2"],
  124. H=h,
  125. W=w,
  126. )
  127. x2 = gt_warp.float()
  128. prob = gt_prob
  129. if self.local_largest_scale >= scale:
  130. prob = prob * (
  131. F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0]
  132. < (2 / 512) * (self.local_dist[scale] * scale))
  133. if scale_gm_cls is not None:
  134. gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale)
  135. gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"]
  136. tot_loss = tot_loss + scale_weights[scale] * gm_loss
  137. elif scale_gm_flow is not None:
  138. gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
  139. gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
  140. tot_loss = tot_loss + scale_weights[scale] * gm_loss
  141. if delta_cls is not None:
  142. delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale)
  143. delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"]
  144. tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
  145. else:
  146. delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
  147. reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
  148. tot_loss = tot_loss + scale_weights[scale] * reg_loss
  149. prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach()
  150. return tot_loss