robust_loss_tiny_roma.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. from einops.einops import rearrange
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from romatch.utils.utils import get_gt_warp
  6. import wandb
  7. import romatch
  8. import math
# This is slightly different from the regular romatch loss because the correspondences are significantly worse.
# The confidence loss is quite tricky here. //Johan
  11. class RobustLosses(nn.Module):
  12. def __init__(
  13. self,
  14. robust=False,
  15. center_coords=False,
  16. scale_normalize=False,
  17. ce_weight=0.01,
  18. local_loss=True,
  19. local_dist=None,
  20. smooth_mask = False,
  21. depth_interpolation_mode = "bilinear",
  22. mask_depth_loss = False,
  23. relative_depth_error_threshold = 0.05,
  24. alpha = 1.,
  25. c = 1e-3,
  26. epe_mask_prob_th = None,
  27. cert_only_on_consistent_depth = False,
  28. ):
  29. super().__init__()
  30. if local_dist is None:
  31. local_dist = {}
  32. self.robust = robust # measured in pixels
  33. self.center_coords = center_coords
  34. self.scale_normalize = scale_normalize
  35. self.ce_weight = ce_weight
  36. self.local_loss = local_loss
  37. self.local_dist = local_dist
  38. self.smooth_mask = smooth_mask
  39. self.depth_interpolation_mode = depth_interpolation_mode
  40. self.mask_depth_loss = mask_depth_loss
  41. self.relative_depth_error_threshold = relative_depth_error_threshold
  42. self.avg_overlap = dict()
  43. self.alpha = alpha
  44. self.c = c
  45. self.epe_mask_prob_th = epe_mask_prob_th
  46. self.cert_only_on_consistent_depth = cert_only_on_consistent_depth
  47. def corr_volume_loss(self, mnn:torch.Tensor, corr_volume:torch.Tensor, scale):
  48. b, h,w, h,w = corr_volume.shape
  49. inv_temp = 10
  50. corr_volume = corr_volume.reshape(-1, h*w, h*w)
  51. nll = -(inv_temp*corr_volume).log_softmax(dim = 1) - (inv_temp*corr_volume).log_softmax(dim = 2)
  52. corr_volume_loss = nll[mnn[:,0], mnn[:,1], mnn[:,2]].mean()
  53. losses = {
  54. f"gm_corr_volume_loss_{scale}": corr_volume_loss.mean(),
  55. }
  56. wandb.log(losses, step = romatch.GLOBAL_STEP)
  57. return losses
  58. def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
  59. epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
  60. if scale in self.local_dist:
  61. prob = prob * (epe < (2 / 512) * (self.local_dist[scale] * scale)).float()
  62. if scale == 1:
  63. pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
  64. wandb.log({"train_pck_05": pck_05}, step = romatch.GLOBAL_STEP)
  65. if self.epe_mask_prob_th is not None:
  66. # if too far away from gt, certainty should be 0
  67. gt_cert = prob * (epe < scale * self.epe_mask_prob_th)
  68. else:
  69. gt_cert = prob
  70. if self.cert_only_on_consistent_depth:
  71. ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0][prob > 0], gt_cert[prob > 0])
  72. else:
  73. ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], gt_cert)
  74. a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha
  75. cs = self.c * scale
  76. x = epe[prob > 0.99]
  77. reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
  78. if not torch.any(reg_loss):
  79. reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere
  80. losses = {
  81. f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
  82. f"{mode}_regression_loss_{scale}": reg_loss.mean(),
  83. }
  84. wandb.log(losses, step = romatch.GLOBAL_STEP)
  85. return losses
  86. def forward(self, corresps, batch):
  87. scales = list(corresps.keys())
  88. tot_loss = 0.0
  89. # scale_weights due to differences in scale for regression gradients and classification gradients
  90. for scale in scales:
  91. scale_corresps = corresps[scale]
  92. scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_corr_volume, scale_gm_certainty, flow, scale_gm_flow = (
  93. scale_corresps["certainty"],
  94. scale_corresps.get("flow_pre_delta"),
  95. scale_corresps.get("delta_cls"),
  96. scale_corresps.get("offset_scale"),
  97. scale_corresps.get("corr_volume"),
  98. scale_corresps.get("gm_certainty"),
  99. scale_corresps["flow"],
  100. scale_corresps.get("gm_flow"),
  101. )
  102. if flow_pre_delta is not None:
  103. flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
  104. b, h, w, d = flow_pre_delta.shape
  105. else:
  106. # _ = 1
  107. b, _, h, w = scale_certainty.shape
  108. gt_warp, gt_prob = get_gt_warp(
  109. batch["im_A_depth"],
  110. batch["im_B_depth"],
  111. batch["T_1to2"],
  112. batch["K1"],
  113. batch["K2"],
  114. H=h,
  115. W=w,
  116. )
  117. x2 = gt_warp.float()
  118. prob = gt_prob
  119. if scale_gm_corr_volume is not None:
  120. gt_warp_back, _ = get_gt_warp(
  121. batch["im_B_depth"],
  122. batch["im_A_depth"],
  123. batch["T_1to2"].inverse(),
  124. batch["K2"],
  125. batch["K1"],
  126. H=h,
  127. W=w,
  128. )
  129. grid = torch.stack(torch.meshgrid(torch.linspace(-1+1/w, 1-1/w, w), torch.linspace(-1+1/h, 1-1/h, h), indexing='xy'), dim =-1).to(gt_warp.device)
  130. #fwd_bck = F.grid_sample(gt_warp_back.permute(0,3,1,2), gt_warp, align_corners=False, mode = 'bilinear').permute(0,2,3,1)
  131. #diff = (fwd_bck - grid).norm(dim = -1)
  132. with torch.no_grad():
  133. D_B = torch.cdist(gt_warp.float().reshape(-1,h*w,2), grid.reshape(-1,h*w,2))
  134. D_A = torch.cdist(grid.reshape(-1,h*w,2), gt_warp_back.float().reshape(-1,h*w,2))
  135. inds = torch.nonzero((D_B == D_B.min(dim=-1, keepdim = True).values)
  136. * (D_A == D_A.min(dim=-2, keepdim = True).values)
  137. * (D_B < 0.01)
  138. * (D_A < 0.01))
  139. gm_cls_losses = self.corr_volume_loss(inds, scale_gm_corr_volume, scale)
  140. gm_loss = gm_cls_losses[f"gm_corr_volume_loss_{scale}"]
  141. tot_loss = tot_loss + gm_loss
  142. elif scale_gm_flow is not None:
  143. gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
  144. gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
  145. tot_loss = tot_loss + gm_loss
  146. delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
  147. reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
  148. tot_loss = tot_loss + reg_loss
  149. return tot_loss