utils.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
  1. import warnings
  2. import numpy as np
  3. import cv2
  4. import math
  5. import torch
  6. from torchvision import transforms
  7. from torchvision.transforms.functional import InterpolationMode
  8. import torch.nn.functional as F
  9. from PIL import Image
  10. import kornia
  11. def recover_pose(E, kpts0, kpts1, K0, K1, mask):
  12. best_num_inliers = 0
  13. K0inv = np.linalg.inv(K0[:2,:2])
  14. K1inv = np.linalg.inv(K1[:2,:2])
  15. kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T
  16. kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
  17. for _E in np.split(E, len(E) / 3):
  18. n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
  19. if n > best_num_inliers:
  20. best_num_inliers = n
  21. ret = (R, t, mask.ravel() > 0)
  22. return ret
  23. # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
  24. # --- GEOMETRY ---
  25. def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
  26. if len(kpts0) < 5:
  27. return None
  28. K0inv = np.linalg.inv(K0[:2,:2])
  29. K1inv = np.linalg.inv(K1[:2,:2])
  30. kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T
  31. kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
  32. E, mask = cv2.findEssentialMat(
  33. kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
  34. )
  35. ret = None
  36. if E is not None:
  37. best_num_inliers = 0
  38. for _E in np.split(E, len(E) / 3):
  39. n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
  40. if n > best_num_inliers:
  41. best_num_inliers = n
  42. ret = (R, t, mask.ravel() > 0)
  43. return ret
  44. def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
  45. if len(kpts0) < 5:
  46. return None
  47. method = cv2.USAC_ACCURATE
  48. F, mask = cv2.findFundamentalMat(
  49. kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000
  50. )
  51. E = K1.T@F@K0
  52. ret = None
  53. if E is not None:
  54. best_num_inliers = 0
  55. K0inv = np.linalg.inv(K0[:2,:2])
  56. K1inv = np.linalg.inv(K1[:2,:2])
  57. kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T
  58. kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
  59. for _E in np.split(E, len(E) / 3):
  60. n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
  61. if n > best_num_inliers:
  62. best_num_inliers = n
  63. ret = (R, t, mask.ravel() > 0)
  64. return ret
  65. def unnormalize_coords(x_n,h,w):
  66. x = torch.stack(
  67. (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
  68. ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
  69. return x
  70. def rotate_intrinsic(K, n):
  71. base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
  72. rot = np.linalg.matrix_power(base_rot, n)
  73. return rot @ K
  74. def rotate_pose_inplane(i_T_w, rot):
  75. rotation_matrices = [
  76. np.array(
  77. [
  78. [np.cos(r), -np.sin(r), 0.0, 0.0],
  79. [np.sin(r), np.cos(r), 0.0, 0.0],
  80. [0.0, 0.0, 1.0, 0.0],
  81. [0.0, 0.0, 0.0, 1.0],
  82. ],
  83. dtype=np.float32,
  84. )
  85. for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
  86. ]
  87. return np.dot(rotation_matrices[rot], i_T_w)
  88. def scale_intrinsics(K, scales):
  89. scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
  90. return np.dot(scales, K)
  91. def to_homogeneous(points):
  92. return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
  93. def angle_error_mat(R1, R2):
  94. cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
  95. cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds
  96. return np.rad2deg(np.abs(np.arccos(cos)))
  97. def angle_error_vec(v1, v2):
  98. n = np.linalg.norm(v1) * np.linalg.norm(v2)
  99. return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
  100. def compute_pose_error(T_0to1, R, t):
  101. R_gt = T_0to1[:3, :3]
  102. t_gt = T_0to1[:3, 3]
  103. error_t = angle_error_vec(t.squeeze(), t_gt)
  104. error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
  105. error_R = angle_error_mat(R, R_gt)
  106. return error_t, error_R
  107. def pose_auc(errors, thresholds):
  108. sort_idx = np.argsort(errors)
  109. errors = np.array(errors.copy())[sort_idx]
  110. recall = (np.arange(len(errors)) + 1) / len(errors)
  111. errors = np.r_[0.0, errors]
  112. recall = np.r_[0.0, recall]
  113. aucs = []
  114. for t in thresholds:
  115. last_index = np.searchsorted(errors, t)
  116. r = np.r_[recall[:last_index], recall[last_index - 1]]
  117. e = np.r_[errors[:last_index], t]
  118. aucs.append(np.trapz(r, x=e) / t)
  119. return aucs
  120. # From Patch2Pix https://github.com/GrumpyZhou/patch2pix
  121. def get_depth_tuple_transform_ops_nearest_exact(resize=None):
  122. ops = []
  123. if resize:
  124. ops.append(TupleResizeNearestExact(resize))
  125. return TupleCompose(ops)
  126. def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
  127. ops = []
  128. if resize:
  129. ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR))
  130. return TupleCompose(ops)
  131. def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None):
  132. ops = []
  133. if resize:
  134. ops.append(TupleResize(resize))
  135. ops.append(TupleToTensorScaled())
  136. if normalize:
  137. ops.append(
  138. TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  139. ) # Imagenet mean/std
  140. return TupleCompose(ops)
  141. class ToTensorScaled(object):
  142. """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
  143. def __call__(self, im):
  144. if not isinstance(im, torch.Tensor):
  145. im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
  146. im /= 255.0
  147. return torch.from_numpy(im)
  148. else:
  149. return im
  150. def __repr__(self):
  151. return "ToTensorScaled(./255)"
  152. class TupleToTensorScaled(object):
  153. def __init__(self):
  154. self.to_tensor = ToTensorScaled()
  155. def __call__(self, im_tuple):
  156. return [self.to_tensor(im) for im in im_tuple]
  157. def __repr__(self):
  158. return "TupleToTensorScaled(./255)"
  159. class ToTensorUnscaled(object):
  160. """Convert a RGB PIL Image to a CHW ordered Tensor"""
  161. def __call__(self, im):
  162. return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
  163. def __repr__(self):
  164. return "ToTensorUnscaled()"
  165. class TupleToTensorUnscaled(object):
  166. """Convert a RGB PIL Image to a CHW ordered Tensor"""
  167. def __init__(self):
  168. self.to_tensor = ToTensorUnscaled()
  169. def __call__(self, im_tuple):
  170. return [self.to_tensor(im) for im in im_tuple]
  171. def __repr__(self):
  172. return "TupleToTensorUnscaled()"
  173. class TupleResizeNearestExact:
  174. def __init__(self, size):
  175. self.size = size
  176. def __call__(self, im_tuple):
  177. return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple]
  178. def __repr__(self):
  179. return "TupleResizeNearestExact(size={})".format(self.size)
  180. class TupleResize(object):
  181. def __init__(self, size, mode=InterpolationMode.BICUBIC):
  182. self.size = size
  183. self.resize = transforms.Resize(size, mode)
  184. def __call__(self, im_tuple):
  185. return [self.resize(im) for im in im_tuple]
  186. def __repr__(self):
  187. return "TupleResize(size={})".format(self.size)
  188. class Normalize:
  189. def __call__(self,im):
  190. mean = im.mean(dim=(1,2), keepdims=True)
  191. std = im.std(dim=(1,2), keepdims=True)
  192. return (im-mean)/std
  193. class TupleNormalize(object):
  194. def __init__(self, mean, std):
  195. self.mean = mean
  196. self.std = std
  197. self.normalize = transforms.Normalize(mean=mean, std=std)
  198. def __call__(self, im_tuple):
  199. c,h,w = im_tuple[0].shape
  200. if c > 3:
  201. warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb")
  202. return [self.normalize(im[:3]) for im in im_tuple]
  203. def __repr__(self):
  204. return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
  205. class TupleCompose(object):
  206. def __init__(self, transforms):
  207. self.transforms = transforms
  208. def __call__(self, im_tuple):
  209. for t in self.transforms:
  210. im_tuple = t(im_tuple)
  211. return im_tuple
  212. def __repr__(self):
  213. format_string = self.__class__.__name__ + "("
  214. for t in self.transforms:
  215. format_string += "\n"
  216. format_string += " {0}".format(t)
  217. format_string += "\n)"
  218. return format_string
  219. @torch.no_grad()
  220. def cls_to_flow(cls, deterministic_sampling = True):
  221. B,C,H,W = cls.shape
  222. device = cls.device
  223. res = round(math.sqrt(C))
  224. G = torch.meshgrid(
  225. *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)],
  226. indexing = 'ij'
  227. )
  228. G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
  229. if deterministic_sampling:
  230. sampled_cls = cls.max(dim=1).indices
  231. else:
  232. sampled_cls = torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W)
  233. flow = G[sampled_cls]
  234. return flow
  235. @torch.no_grad()
  236. def cls_to_flow_refine(cls):
  237. B,C,H,W = cls.shape
  238. device = cls.device
  239. res = round(math.sqrt(C))
  240. G = torch.meshgrid(
  241. *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)],
  242. indexing = 'ij'
  243. )
  244. G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
  245. # FIXME: below softmax line causes mps to bug, don't know why.
  246. if device.type == 'mps':
  247. cls = cls.log_softmax(dim=1).exp()
  248. else:
  249. cls = cls.softmax(dim=1)
  250. mode = cls.max(dim=1).indices
  251. index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long()
  252. neighbours = torch.gather(cls, dim = 1, index = index)[...,None]
  253. flow = neighbours[:,0] * G[index[:,0]] + neighbours[:,1] * G[index[:,1]] + neighbours[:,2] * G[index[:,2]] + neighbours[:,3] * G[index[:,3]] + neighbours[:,4] * G[index[:,4]]
  254. tot_prob = neighbours.sum(dim=1)
  255. flow = flow / tot_prob
  256. return flow
  257. def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
  258. if H is None:
  259. B,H,W = depth1.shape
  260. else:
  261. B = depth1.shape[0]
  262. with torch.no_grad():
  263. x1_n = torch.meshgrid(
  264. *[
  265. torch.linspace(
  266. -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
  267. )
  268. for n in (B, H, W)
  269. ],
  270. indexing = 'ij'
  271. )
  272. x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
  273. mask, x2 = warp_kpts(
  274. x1_n.double(),
  275. depth1.double(),
  276. depth2.double(),
  277. T_1to2.double(),
  278. K1.double(),
  279. K2.double(),
  280. depth_interpolation_mode = depth_interpolation_mode,
  281. relative_depth_error_threshold = relative_depth_error_threshold,
  282. )
  283. prob = mask.float().reshape(B, H, W)
  284. x2 = x2.reshape(B, H, W, 2)
  285. return x2, prob
  286. @torch.no_grad()
  287. def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
  288. """Warp kpts0 from I0 to I1 with depth, K and Rt
  289. Also check covisibility and depth consistency.
  290. Depth is consistent if relative error < 0.2 (hard-coded).
  291. # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
  292. Args:
  293. kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1,1)
  294. depth0 (torch.Tensor): [N, H, W],
  295. depth1 (torch.Tensor): [N, H, W],
  296. T_0to1 (torch.Tensor): [N, 3, 4],
  297. K0 (torch.Tensor): [N, 3, 3],
  298. K1 (torch.Tensor): [N, 3, 3],
  299. Returns:
  300. calculable_mask (torch.Tensor): [N, L]
  301. warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
  302. """
  303. (
  304. n,
  305. h,
  306. w,
  307. ) = depth0.shape
  308. if depth_interpolation_mode == "combined":
  309. # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
  310. if smooth_mask:
  311. raise NotImplementedError("Combined bilinear and NN warp not implemented")
  312. valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
  313. smooth_mask = smooth_mask,
  314. return_relative_depth_error = return_relative_depth_error,
  315. depth_interpolation_mode = "bilinear",
  316. relative_depth_error_threshold = relative_depth_error_threshold)
  317. valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
  318. smooth_mask = smooth_mask,
  319. return_relative_depth_error = return_relative_depth_error,
  320. depth_interpolation_mode = "nearest-exact",
  321. relative_depth_error_threshold = relative_depth_error_threshold)
  322. nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
  323. warp = warp_bilinear.clone()
  324. warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
  325. valid = valid_bilinear | valid_nearest
  326. return valid, warp
  327. kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
  328. :, 0, :, 0
  329. ]
  330. kpts0 = torch.stack(
  331. (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
  332. ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
  333. # Sample depth, get calculable_mask on depth != 0
  334. nonzero_mask = kpts0_depth != 0
  335. # Unproject
  336. kpts0_h = (
  337. torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
  338. * kpts0_depth[..., None]
  339. ) # (N, L, 3)
  340. kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
  341. kpts0_cam = kpts0_n
  342. # Rigid Transform
  343. w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
  344. w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
  345. # Project
  346. w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
  347. w_kpts0 = w_kpts0_h[:, :, :2] / (
  348. w_kpts0_h[:, :, [2]] + 1e-4
  349. ) # (N, L, 2), +1e-4 to avoid zero depth
  350. # Covisible Check
  351. h, w = depth1.shape[1:3]
  352. covisible_mask = (
  353. (w_kpts0[:, :, 0] > 0)
  354. * (w_kpts0[:, :, 0] < w - 1)
  355. * (w_kpts0[:, :, 1] > 0)
  356. * (w_kpts0[:, :, 1] < h - 1)
  357. )
  358. w_kpts0 = torch.stack(
  359. (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
  360. ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
  361. # w_kpts0[~covisible_mask, :] = -5 # xd
  362. w_kpts0_depth = F.grid_sample(
  363. depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
  364. )[:, 0, :, 0]
  365. relative_depth_error = (
  366. (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
  367. ).abs()
  368. if not smooth_mask:
  369. consistent_mask = relative_depth_error < relative_depth_error_threshold
  370. else:
  371. consistent_mask = (-relative_depth_error/smooth_mask).exp()
  372. valid_mask = nonzero_mask * covisible_mask * consistent_mask
  373. if return_relative_depth_error:
  374. return relative_depth_error, w_kpts0
  375. else:
  376. return valid_mask, w_kpts0
  377. imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
  378. imagenet_std = torch.tensor([0.229, 0.224, 0.225])
  379. def numpy_to_pil(x: np.ndarray):
  380. """
  381. Args:
  382. x: Assumed to be of shape (h,w,c)
  383. """
  384. if isinstance(x, torch.Tensor):
  385. x = x.detach().cpu().numpy()
  386. if x.max() <= 1.01:
  387. x *= 255
  388. x = x.astype(np.uint8)
  389. return Image.fromarray(x)
  390. def tensor_to_pil(x, unnormalize=False):
  391. if unnormalize:
  392. x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device))
  393. x = x.detach().permute(1, 2, 0).cpu().numpy()
  394. x = np.clip(x, 0.0, 1.0)
  395. return numpy_to_pil(x)
  396. def to_cuda(batch):
  397. for key, value in batch.items():
  398. if isinstance(value, torch.Tensor):
  399. batch[key] = value.cuda()
  400. return batch
  401. def to_cpu(batch):
  402. for key, value in batch.items():
  403. if isinstance(value, torch.Tensor):
  404. batch[key] = value.cpu()
  405. return batch
  406. def get_pose(calib):
  407. w, h = np.array(calib["imsize"])[0]
  408. return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w
  409. def compute_relative_pose(R1, t1, R2, t2):
  410. rots = R2 @ (R1.T)
  411. trans = -rots @ t1 + t2
  412. return rots, trans
  413. @torch.no_grad()
  414. def reset_opt(opt):
  415. for group in opt.param_groups:
  416. for p in group['params']:
  417. if p.requires_grad:
  418. state = opt.state[p]
  419. # State initialization
  420. # Exponential moving average of gradient values
  421. state['exp_avg'] = torch.zeros_like(p)
  422. # Exponential moving average of squared gradient values
  423. state['exp_avg_sq'] = torch.zeros_like(p)
  424. # Exponential moving average of gradient difference
  425. state['exp_avg_diff'] = torch.zeros_like(p)
  426. def flow_to_pixel_coords(flow, h1, w1):
  427. flow = (
  428. torch.stack(
  429. (
  430. w1 * (flow[..., 0] + 1) / 2,
  431. h1 * (flow[..., 1] + 1) / 2,
  432. ),
  433. axis=-1,
  434. )
  435. )
  436. return flow
  437. to_pixel_coords = flow_to_pixel_coords # just an alias
  438. def flow_to_normalized_coords(flow, h1, w1):
  439. flow = (
  440. torch.stack(
  441. (
  442. 2 * (flow[..., 0]) / w1 - 1,
  443. 2 * (flow[..., 1]) / h1 - 1,
  444. ),
  445. axis=-1,
  446. )
  447. )
  448. return flow
  449. to_normalized_coords = flow_to_normalized_coords # just an alias
  450. def warp_to_pixel_coords(warp, h1, w1, h2, w2):
  451. warp1 = warp[..., :2]
  452. warp1 = (
  453. torch.stack(
  454. (
  455. w1 * (warp1[..., 0] + 1) / 2,
  456. h1 * (warp1[..., 1] + 1) / 2,
  457. ),
  458. axis=-1,
  459. )
  460. )
  461. warp2 = warp[..., 2:]
  462. warp2 = (
  463. torch.stack(
  464. (
  465. w2 * (warp2[..., 0] + 1) / 2,
  466. h2 * (warp2[..., 1] + 1) / 2,
  467. ),
  468. axis=-1,
  469. )
  470. )
  471. return torch.cat((warp1,warp2), dim=-1)
  472. def signed_point_line_distance(point, line, eps: float = 1e-9):
  473. r"""Return the distance from points to lines.
  474. Args:
  475. point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`.
  476. line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`.
  477. eps: Small constant for safe sqrt.
  478. Returns:
  479. the computed distance with shape :math:`(*, N)`.
  480. """
  481. if not point.shape[-1] in (2, 3):
  482. raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}")
  483. if not line.shape[-1] == 3:
  484. raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}")
  485. numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2])
  486. denominator = line[..., :2].norm(dim=-1)
  487. return numerator / (denominator + eps)
  488. def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
  489. r"""Return one-sided epipolar distance for correspondences given the fundamental matrix.
  490. This method measures the distance from points in the right images to the epilines
  491. of the corresponding points in the left images as they reflect in the right images.
  492. Args:
  493. pts1: correspondences from the left images with shape
  494. :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
  495. pts2: correspondences from the right images with shape
  496. :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
  497. Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to
  498. avoid ambiguity with torch.nn.functional.
  499. Returns:
  500. the computed Symmetrical distance with shape :math:`(*, N)`.
  501. """
  502. import kornia
  503. if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3):
  504. raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}")
  505. if pts1.shape[-1] == 2:
  506. pts1 = kornia.geometry.convert_points_to_homogeneous(pts1)
  507. F_t = Fm.transpose(dim0=-2, dim1=-1)
  508. line1_in_2 = pts1 @ F_t
  509. return signed_point_line_distance(pts2, line1_in_2)
  510. def get_grid(b, h, w, device):
  511. grid = torch.meshgrid(
  512. *[
  513. torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)
  514. for n in (b, h, w)
  515. ],
  516. indexing = 'ij'
  517. )
  518. grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2)
  519. return grid
  520. def get_autocast_params(device=None, enabled=False, dtype=None):
  521. if device is None:
  522. autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
  523. else:
  524. #strip :X from device
  525. autocast_device = str(device).split(":")[0]
  526. if 'cuda' in str(device):
  527. out_dtype = dtype
  528. enabled = True
  529. else:
  530. out_dtype = torch.bfloat16
  531. enabled = False
  532. # mps is not supported
  533. autocast_device = "cpu"
  534. return autocast_device, enabled, out_dtype