utils.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648
  1. import warnings
  2. import numpy as np
  3. import cv2
  4. import math
  5. import torch
  6. from torchvision import transforms
  7. from torchvision.transforms.functional import InterpolationMode
  8. import torch.nn.functional as F
  9. from PIL import Image
  10. import kornia
  11. def recover_pose(E, kpts0, kpts1, K0, K1, mask):
  12. best_num_inliers = 0
  13. K0inv = np.linalg.inv(K0[:2,:2])
  14. K1inv = np.linalg.inv(K1[:2,:2])
  15. kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T
  16. kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
  17. for _E in np.split(E, len(E) / 3):
  18. n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
  19. if n > best_num_inliers:
  20. best_num_inliers = n
  21. ret = (R, t, mask.ravel() > 0)
  22. return ret
  23. # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
  24. # --- GEOMETRY ---
  25. def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
  26. if len(kpts0) < 5:
  27. return None
  28. K0inv = np.linalg.inv(K0[:2,:2])
  29. K1inv = np.linalg.inv(K1[:2,:2])
  30. kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T
  31. kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
  32. E, mask = cv2.findEssentialMat(
  33. kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
  34. )
  35. ret = None
  36. if E is not None:
  37. best_num_inliers = 0
  38. for _E in np.split(E, len(E) / 3):
  39. n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
  40. if n > best_num_inliers:
  41. best_num_inliers = n
  42. ret = (R, t, mask.ravel() > 0)
  43. return ret
  44. def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
  45. if len(kpts0) < 5:
  46. return None
  47. method = cv2.USAC_ACCURATE
  48. F, mask = cv2.findFundamentalMat(
  49. kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000
  50. )
  51. E = K1.T@F@K0
  52. ret = None
  53. if E is not None:
  54. best_num_inliers = 0
  55. K0inv = np.linalg.inv(K0[:2,:2])
  56. K1inv = np.linalg.inv(K1[:2,:2])
  57. kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T
  58. kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
  59. for _E in np.split(E, len(E) / 3):
  60. n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
  61. if n > best_num_inliers:
  62. best_num_inliers = n
  63. ret = (R, t, mask.ravel() > 0)
  64. return ret
  65. def unnormalize_coords(x_n,h,w):
  66. x = torch.stack(
  67. (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
  68. ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
  69. return x
  70. def rotate_intrinsic(K, n):
  71. base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
  72. rot = np.linalg.matrix_power(base_rot, n)
  73. return rot @ K
  74. def rotate_pose_inplane(i_T_w, rot):
  75. rotation_matrices = [
  76. np.array(
  77. [
  78. [np.cos(r), -np.sin(r), 0.0, 0.0],
  79. [np.sin(r), np.cos(r), 0.0, 0.0],
  80. [0.0, 0.0, 1.0, 0.0],
  81. [0.0, 0.0, 0.0, 1.0],
  82. ],
  83. dtype=np.float32,
  84. )
  85. for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
  86. ]
  87. return np.dot(rotation_matrices[rot], i_T_w)
  88. def scale_intrinsics(K, scales):
  89. scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
  90. return np.dot(scales, K)
  91. def to_homogeneous(points):
  92. return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
  93. def angle_error_mat(R1, R2):
  94. cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
  95. cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds
  96. return np.rad2deg(np.abs(np.arccos(cos)))
  97. def angle_error_vec(v1, v2):
  98. n = np.linalg.norm(v1) * np.linalg.norm(v2)
  99. return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
  100. def compute_pose_error(T_0to1, R, t):
  101. R_gt = T_0to1[:3, :3]
  102. t_gt = T_0to1[:3, 3]
  103. error_t = angle_error_vec(t.squeeze(), t_gt)
  104. error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
  105. error_R = angle_error_mat(R, R_gt)
  106. return error_t, error_R
  107. def pose_auc(errors, thresholds):
  108. sort_idx = np.argsort(errors)
  109. errors = np.array(errors.copy())[sort_idx]
  110. recall = (np.arange(len(errors)) + 1) / len(errors)
  111. errors = np.r_[0.0, errors]
  112. recall = np.r_[0.0, recall]
  113. aucs = []
  114. for t in thresholds:
  115. last_index = np.searchsorted(errors, t)
  116. r = np.r_[recall[:last_index], recall[last_index - 1]]
  117. e = np.r_[errors[:last_index], t]
  118. aucs.append(np.trapz(r, x=e) / t)
  119. return aucs
  120. # From Patch2Pix https://github.com/GrumpyZhou/patch2pix
  121. def get_depth_tuple_transform_ops_nearest_exact(resize=None):
  122. ops = []
  123. if resize:
  124. ops.append(TupleResizeNearestExact(resize))
  125. return TupleCompose(ops)
  126. def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
  127. ops = []
  128. if resize:
  129. ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR))
  130. return TupleCompose(ops)
  131. def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None):
  132. ops = []
  133. if resize:
  134. ops.append(TupleResize(resize))
  135. ops.append(TupleToTensorScaled())
  136. if normalize:
  137. ops.append(
  138. TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  139. ) # Imagenet mean/std
  140. return TupleCompose(ops)
  141. class ToTensorScaled(object):
  142. """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
  143. def __call__(self, im):
  144. if not isinstance(im, torch.Tensor):
  145. im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
  146. im /= 255.0
  147. return torch.from_numpy(im)
  148. else:
  149. return im
  150. def __repr__(self):
  151. return "ToTensorScaled(./255)"
  152. class TupleToTensorScaled(object):
  153. def __init__(self):
  154. self.to_tensor = ToTensorScaled()
  155. def __call__(self, im_tuple):
  156. return [self.to_tensor(im) for im in im_tuple]
  157. def __repr__(self):
  158. return "TupleToTensorScaled(./255)"
  159. class ToTensorUnscaled(object):
  160. """Convert a RGB PIL Image to a CHW ordered Tensor"""
  161. def __call__(self, im):
  162. return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
  163. def __repr__(self):
  164. return "ToTensorUnscaled()"
  165. class TupleToTensorUnscaled(object):
  166. """Convert a RGB PIL Image to a CHW ordered Tensor"""
  167. def __init__(self):
  168. self.to_tensor = ToTensorUnscaled()
  169. def __call__(self, im_tuple):
  170. return [self.to_tensor(im) for im in im_tuple]
  171. def __repr__(self):
  172. return "TupleToTensorUnscaled()"
  173. class TupleResizeNearestExact:
  174. def __init__(self, size):
  175. self.size = size
  176. def __call__(self, im_tuple):
  177. return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple]
  178. def __repr__(self):
  179. return "TupleResizeNearestExact(size={})".format(self.size)
  180. class TupleResize(object):
  181. def __init__(self, size, mode=InterpolationMode.BICUBIC):
  182. self.size = size
  183. self.resize = transforms.Resize(size, mode)
  184. def __call__(self, im_tuple):
  185. return [self.resize(im) for im in im_tuple]
  186. def __repr__(self):
  187. return "TupleResize(size={})".format(self.size)
  188. class Normalize:
  189. def __call__(self,im):
  190. mean = im.mean(dim=(1,2), keepdims=True)
  191. std = im.std(dim=(1,2), keepdims=True)
  192. return (im-mean)/std
  193. class TupleNormalize(object):
  194. def __init__(self, mean, std):
  195. self.mean = mean
  196. self.std = std
  197. self.normalize = transforms.Normalize(mean=mean, std=std)
  198. def __call__(self, im_tuple):
  199. c,h,w = im_tuple[0].shape
  200. if c > 3:
  201. warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb")
  202. return [self.normalize(im[:3]) for im in im_tuple]
  203. def __repr__(self):
  204. return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
  205. class TupleCompose(object):
  206. def __init__(self, transforms):
  207. self.transforms = transforms
  208. def __call__(self, im_tuple):
  209. for t in self.transforms:
  210. im_tuple = t(im_tuple)
  211. return im_tuple
  212. def __repr__(self):
  213. format_string = self.__class__.__name__ + "("
  214. for t in self.transforms:
  215. format_string += "\n"
  216. format_string += " {0}".format(t)
  217. format_string += "\n)"
  218. return format_string
  219. @torch.no_grad()
  220. def cls_to_flow(cls, deterministic_sampling = True):
  221. B,C,H,W = cls.shape
  222. device = cls.device
  223. res = round(math.sqrt(C))
  224. G = torch.meshgrid(
  225. *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)],
  226. indexing = 'ij'
  227. )
  228. G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
  229. if deterministic_sampling:
  230. sampled_cls = cls.max(dim=1).indices
  231. else:
  232. sampled_cls = torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W)
  233. flow = G[sampled_cls]
  234. return flow
  235. @torch.no_grad()
  236. def cls_to_flow_refine(cls):
  237. B,C,H,W = cls.shape
  238. device = cls.device
  239. res = round(math.sqrt(C))
  240. G = torch.meshgrid(
  241. *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)],
  242. indexing = 'ij'
  243. )
  244. G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
  245. cls = cls.softmax(dim=1)
  246. mode = cls.max(dim=1).indices
  247. index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long()
  248. neighbours = torch.gather(cls, dim = 1, index = index)[...,None]
  249. flow = neighbours[:,0] * G[index[:,0]] + neighbours[:,1] * G[index[:,1]] + neighbours[:,2] * G[index[:,2]] + neighbours[:,3] * G[index[:,3]] + neighbours[:,4] * G[index[:,4]]
  250. tot_prob = neighbours.sum(dim=1)
  251. flow = flow / tot_prob
  252. return flow
  253. def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
  254. if H is None:
  255. B,H,W = depth1.shape
  256. else:
  257. B = depth1.shape[0]
  258. with torch.no_grad():
  259. x1_n = torch.meshgrid(
  260. *[
  261. torch.linspace(
  262. -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
  263. )
  264. for n in (B, H, W)
  265. ],
  266. indexing = 'ij'
  267. )
  268. x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
  269. mask, x2 = warp_kpts(
  270. x1_n.double(),
  271. depth1.double(),
  272. depth2.double(),
  273. T_1to2.double(),
  274. K1.double(),
  275. K2.double(),
  276. depth_interpolation_mode = depth_interpolation_mode,
  277. relative_depth_error_threshold = relative_depth_error_threshold,
  278. )
  279. prob = mask.float().reshape(B, H, W)
  280. x2 = x2.reshape(B, H, W, 2)
  281. return x2, prob
  282. @torch.no_grad()
  283. def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
  284. """Warp kpts0 from I0 to I1 with depth, K and Rt
  285. Also check covisibility and depth consistency.
  286. Depth is consistent if relative error < 0.2 (hard-coded).
  287. # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
  288. Args:
  289. kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1,1)
  290. depth0 (torch.Tensor): [N, H, W],
  291. depth1 (torch.Tensor): [N, H, W],
  292. T_0to1 (torch.Tensor): [N, 3, 4],
  293. K0 (torch.Tensor): [N, 3, 3],
  294. K1 (torch.Tensor): [N, 3, 3],
  295. Returns:
  296. calculable_mask (torch.Tensor): [N, L]
  297. warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
  298. """
  299. (
  300. n,
  301. h,
  302. w,
  303. ) = depth0.shape
  304. if depth_interpolation_mode == "combined":
  305. # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
  306. if smooth_mask:
  307. raise NotImplementedError("Combined bilinear and NN warp not implemented")
  308. valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
  309. smooth_mask = smooth_mask,
  310. return_relative_depth_error = return_relative_depth_error,
  311. depth_interpolation_mode = "bilinear",
  312. relative_depth_error_threshold = relative_depth_error_threshold)
  313. valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
  314. smooth_mask = smooth_mask,
  315. return_relative_depth_error = return_relative_depth_error,
  316. depth_interpolation_mode = "nearest-exact",
  317. relative_depth_error_threshold = relative_depth_error_threshold)
  318. nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
  319. warp = warp_bilinear.clone()
  320. warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
  321. valid = valid_bilinear | valid_nearest
  322. return valid, warp
  323. kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
  324. :, 0, :, 0
  325. ]
  326. kpts0 = torch.stack(
  327. (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
  328. ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
  329. # Sample depth, get calculable_mask on depth != 0
  330. nonzero_mask = kpts0_depth != 0
  331. # Unproject
  332. kpts0_h = (
  333. torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
  334. * kpts0_depth[..., None]
  335. ) # (N, L, 3)
  336. kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
  337. kpts0_cam = kpts0_n
  338. # Rigid Transform
  339. w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
  340. w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
  341. # Project
  342. w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
  343. w_kpts0 = w_kpts0_h[:, :, :2] / (
  344. w_kpts0_h[:, :, [2]] + 1e-4
  345. ) # (N, L, 2), +1e-4 to avoid zero depth
  346. # Covisible Check
  347. h, w = depth1.shape[1:3]
  348. covisible_mask = (
  349. (w_kpts0[:, :, 0] > 0)
  350. * (w_kpts0[:, :, 0] < w - 1)
  351. * (w_kpts0[:, :, 1] > 0)
  352. * (w_kpts0[:, :, 1] < h - 1)
  353. )
  354. w_kpts0 = torch.stack(
  355. (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
  356. ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
  357. # w_kpts0[~covisible_mask, :] = -5 # xd
  358. w_kpts0_depth = F.grid_sample(
  359. depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
  360. )[:, 0, :, 0]
  361. relative_depth_error = (
  362. (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
  363. ).abs()
  364. if not smooth_mask:
  365. consistent_mask = relative_depth_error < relative_depth_error_threshold
  366. else:
  367. consistent_mask = (-relative_depth_error/smooth_mask).exp()
  368. valid_mask = nonzero_mask * covisible_mask * consistent_mask
  369. if return_relative_depth_error:
  370. return relative_depth_error, w_kpts0
  371. else:
  372. return valid_mask, w_kpts0
  373. imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
  374. imagenet_std = torch.tensor([0.229, 0.224, 0.225])
  375. def numpy_to_pil(x: np.ndarray):
  376. """
  377. Args:
  378. x: Assumed to be of shape (h,w,c)
  379. """
  380. if isinstance(x, torch.Tensor):
  381. x = x.detach().cpu().numpy()
  382. if x.max() <= 1.01:
  383. x *= 255
  384. x = x.astype(np.uint8)
  385. return Image.fromarray(x)
  386. def tensor_to_pil(x, unnormalize=False):
  387. if unnormalize:
  388. x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device))
  389. x = x.detach().permute(1, 2, 0).cpu().numpy()
  390. x = np.clip(x, 0.0, 1.0)
  391. return numpy_to_pil(x)
  392. def to_cuda(batch):
  393. for key, value in batch.items():
  394. if isinstance(value, torch.Tensor):
  395. batch[key] = value.cuda()
  396. return batch
  397. def to_cpu(batch):
  398. for key, value in batch.items():
  399. if isinstance(value, torch.Tensor):
  400. batch[key] = value.cpu()
  401. return batch
  402. def get_pose(calib):
  403. w, h = np.array(calib["imsize"])[0]
  404. return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w
  405. def compute_relative_pose(R1, t1, R2, t2):
  406. rots = R2 @ (R1.T)
  407. trans = -rots @ t1 + t2
  408. return rots, trans
  409. @torch.no_grad()
  410. def reset_opt(opt):
  411. for group in opt.param_groups:
  412. for p in group['params']:
  413. if p.requires_grad:
  414. state = opt.state[p]
  415. # State initialization
  416. # Exponential moving average of gradient values
  417. state['exp_avg'] = torch.zeros_like(p)
  418. # Exponential moving average of squared gradient values
  419. state['exp_avg_sq'] = torch.zeros_like(p)
  420. # Exponential moving average of gradient difference
  421. state['exp_avg_diff'] = torch.zeros_like(p)
  422. def flow_to_pixel_coords(flow, h1, w1):
  423. flow = (
  424. torch.stack(
  425. (
  426. w1 * (flow[..., 0] + 1) / 2,
  427. h1 * (flow[..., 1] + 1) / 2,
  428. ),
  429. axis=-1,
  430. )
  431. )
  432. return flow
  433. to_pixel_coords = flow_to_pixel_coords # just an alias
  434. def flow_to_normalized_coords(flow, h1, w1):
  435. flow = (
  436. torch.stack(
  437. (
  438. 2 * (flow[..., 0]) / w1 - 1,
  439. 2 * (flow[..., 1]) / h1 - 1,
  440. ),
  441. axis=-1,
  442. )
  443. )
  444. return flow
  445. to_normalized_coords = flow_to_normalized_coords # just an alias
  446. def warp_to_pixel_coords(warp, h1, w1, h2, w2):
  447. warp1 = warp[..., :2]
  448. warp1 = (
  449. torch.stack(
  450. (
  451. w1 * (warp1[..., 0] + 1) / 2,
  452. h1 * (warp1[..., 1] + 1) / 2,
  453. ),
  454. axis=-1,
  455. )
  456. )
  457. warp2 = warp[..., 2:]
  458. warp2 = (
  459. torch.stack(
  460. (
  461. w2 * (warp2[..., 0] + 1) / 2,
  462. h2 * (warp2[..., 1] + 1) / 2,
  463. ),
  464. axis=-1,
  465. )
  466. )
  467. return torch.cat((warp1,warp2), dim=-1)
  468. def signed_point_line_distance(point, line, eps: float = 1e-9):
  469. r"""Return the distance from points to lines.
  470. Args:
  471. point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`.
  472. line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`.
  473. eps: Small constant for safe sqrt.
  474. Returns:
  475. the computed distance with shape :math:`(*, N)`.
  476. """
  477. if not point.shape[-1] in (2, 3):
  478. raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}")
  479. if not line.shape[-1] == 3:
  480. raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}")
  481. numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2])
  482. denominator = line[..., :2].norm(dim=-1)
  483. return numerator / (denominator + eps)
  484. def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
  485. r"""Return one-sided epipolar distance for correspondences given the fundamental matrix.
  486. This method measures the distance from points in the right images to the epilines
  487. of the corresponding points in the left images as they reflect in the right images.
  488. Args:
  489. pts1: correspondences from the left images with shape
  490. :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
  491. pts2: correspondences from the right images with shape
  492. :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
  493. Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to
  494. avoid ambiguity with torch.nn.functional.
  495. Returns:
  496. the computed Symmetrical distance with shape :math:`(*, N)`.
  497. """
  498. import kornia
  499. if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3):
  500. raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}")
  501. if pts1.shape[-1] == 2:
  502. pts1 = kornia.geometry.convert_points_to_homogeneous(pts1)
  503. F_t = Fm.transpose(dim0=-2, dim1=-1)
  504. line1_in_2 = pts1 @ F_t
  505. return signed_point_line_distance(pts2, line1_in_2)
  506. def get_grid(b, h, w, device):
  507. grid = torch.meshgrid(
  508. *[
  509. torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)
  510. for n in (b, h, w)
  511. ],
  512. indexing = 'ij'
  513. )
  514. grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2)
  515. return grid
  516. def get_autocast_params(device=None, enabled=False, dtype=None):
  517. if device is None:
  518. autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
  519. else:
  520. #strip :X from device
  521. autocast_device = str(device).split(":")[0]
  522. if 'cuda' in str(device):
  523. out_dtype = dtype
  524. enabled = True
  525. else:
  526. out_dtype = torch.bfloat16
  527. enabled = False
  528. return autocast_device, enabled, out_dtype