utils.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661
  1. import warnings
  2. import numpy as np
  3. import cv2
  4. import math
  5. import torch
  6. from torchvision import transforms
  7. from torchvision.transforms.functional import InterpolationMode
  8. import torch.nn.functional as F
  9. from PIL import Image
  10. def recover_pose(E, kpts0, kpts1, K0, K1, mask):
  11. best_num_inliers = 0
  12. K0inv = np.linalg.inv(K0[:2,:2])
  13. K1inv = np.linalg.inv(K1[:2,:2])
  14. kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T
  15. kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
  16. for _E in np.split(E, len(E) / 3):
  17. n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
  18. if n > best_num_inliers:
  19. best_num_inliers = n
  20. ret = (R, t, mask.ravel() > 0)
  21. return ret
  22. # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
  23. # --- GEOMETRY ---
  24. def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
  25. if len(kpts0) < 5:
  26. return None
  27. K0inv = np.linalg.inv(K0[:2,:2])
  28. K1inv = np.linalg.inv(K1[:2,:2])
  29. kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T
  30. kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
  31. E, mask = cv2.findEssentialMat(
  32. kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
  33. )
  34. ret = None
  35. if E is not None:
  36. best_num_inliers = 0
  37. for _E in np.split(E, len(E) / 3):
  38. n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
  39. if n > best_num_inliers:
  40. best_num_inliers = n
  41. ret = (R, t, mask.ravel() > 0)
  42. return ret
  43. def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
  44. if len(kpts0) < 5:
  45. return None
  46. method = cv2.USAC_ACCURATE
  47. F, mask = cv2.findFundamentalMat(
  48. kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000
  49. )
  50. E = K1.T@F@K0
  51. ret = None
  52. if E is not None:
  53. best_num_inliers = 0
  54. K0inv = np.linalg.inv(K0[:2,:2])
  55. K1inv = np.linalg.inv(K1[:2,:2])
  56. kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T
  57. kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
  58. for _E in np.split(E, len(E) / 3):
  59. n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
  60. if n > best_num_inliers:
  61. best_num_inliers = n
  62. ret = (R, t, mask.ravel() > 0)
  63. return ret
  64. def unnormalize_coords(x_n,h,w):
  65. x = torch.stack(
  66. (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
  67. ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
  68. return x
  69. def rotate_intrinsic(K, n):
  70. base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
  71. rot = np.linalg.matrix_power(base_rot, n)
  72. return rot @ K
  73. def rotate_pose_inplane(i_T_w, rot):
  74. rotation_matrices = [
  75. np.array(
  76. [
  77. [np.cos(r), -np.sin(r), 0.0, 0.0],
  78. [np.sin(r), np.cos(r), 0.0, 0.0],
  79. [0.0, 0.0, 1.0, 0.0],
  80. [0.0, 0.0, 0.0, 1.0],
  81. ],
  82. dtype=np.float32,
  83. )
  84. for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
  85. ]
  86. return np.dot(rotation_matrices[rot], i_T_w)
  87. def scale_intrinsics(K, scales):
  88. scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
  89. return np.dot(scales, K)
  90. def to_homogeneous(points):
  91. return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
  92. def angle_error_mat(R1, R2):
  93. cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
  94. cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds
  95. return np.rad2deg(np.abs(np.arccos(cos)))
  96. def angle_error_vec(v1, v2):
  97. n = np.linalg.norm(v1) * np.linalg.norm(v2)
  98. return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
  99. def compute_pose_error(T_0to1, R, t):
  100. R_gt = T_0to1[:3, :3]
  101. t_gt = T_0to1[:3, 3]
  102. error_t = angle_error_vec(t.squeeze(), t_gt)
  103. error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
  104. error_R = angle_error_mat(R, R_gt)
  105. return error_t, error_R
  106. def pose_auc(errors, thresholds):
  107. sort_idx = np.argsort(errors)
  108. errors = np.array(errors.copy())[sort_idx]
  109. recall = (np.arange(len(errors)) + 1) / len(errors)
  110. errors = np.r_[0.0, errors]
  111. recall = np.r_[0.0, recall]
  112. aucs = []
  113. for t in thresholds:
  114. last_index = np.searchsorted(errors, t)
  115. r = np.r_[recall[:last_index], recall[last_index - 1]]
  116. e = np.r_[errors[:last_index], t]
  117. aucs.append(np.trapz(r, x=e).item() / t)
  118. return aucs
  119. # From Patch2Pix https://github.com/GrumpyZhou/patch2pix
  120. def get_depth_tuple_transform_ops_nearest_exact(resize=None):
  121. ops = []
  122. if resize:
  123. ops.append(TupleResizeNearestExact(resize))
  124. return TupleCompose(ops)
  125. def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
  126. ops = []
  127. if resize:
  128. ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR))
  129. return TupleCompose(ops)
  130. def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None):
  131. ops = []
  132. if resize:
  133. ops.append(TupleResize(resize))
  134. ops.append(TupleToTensorScaled())
  135. if normalize:
  136. ops.append(
  137. TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  138. ) # Imagenet mean/std
  139. return TupleCompose(ops)
  140. class ToTensorScaled(object):
  141. """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
  142. def __call__(self, im):
  143. if not isinstance(im, torch.Tensor):
  144. im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
  145. im /= 255.0
  146. return torch.from_numpy(im)
  147. else:
  148. return im
  149. def __repr__(self):
  150. return "ToTensorScaled(./255)"
  151. class TupleToTensorScaled(object):
  152. def __init__(self):
  153. self.to_tensor = ToTensorScaled()
  154. def __call__(self, im_tuple):
  155. return [self.to_tensor(im) for im in im_tuple]
  156. def __repr__(self):
  157. return "TupleToTensorScaled(./255)"
  158. class ToTensorUnscaled(object):
  159. """Convert a RGB PIL Image to a CHW ordered Tensor"""
  160. def __call__(self, im):
  161. return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
  162. def __repr__(self):
  163. return "ToTensorUnscaled()"
  164. class TupleToTensorUnscaled(object):
  165. """Convert a RGB PIL Image to a CHW ordered Tensor"""
  166. def __init__(self):
  167. self.to_tensor = ToTensorUnscaled()
  168. def __call__(self, im_tuple):
  169. return [self.to_tensor(im) for im in im_tuple]
  170. def __repr__(self):
  171. return "TupleToTensorUnscaled()"
  172. class TupleResizeNearestExact:
  173. def __init__(self, size):
  174. self.size = size
  175. def __call__(self, im_tuple):
  176. return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple]
  177. def __repr__(self):
  178. return "TupleResizeNearestExact(size={})".format(self.size)
  179. class TupleResize(object):
  180. def __init__(self, size, mode=InterpolationMode.BICUBIC):
  181. self.size = size
  182. self.resize = transforms.Resize(size, mode)
  183. def __call__(self, im_tuple):
  184. return [self.resize(im) for im in im_tuple]
  185. def __repr__(self):
  186. return "TupleResize(size={})".format(self.size)
  187. class Normalize:
  188. def __call__(self,im):
  189. mean = im.mean(dim=(1,2), keepdims=True)
  190. std = im.std(dim=(1,2), keepdims=True)
  191. return (im-mean)/std
  192. class TupleNormalize(object):
  193. def __init__(self, mean, std):
  194. self.mean = mean
  195. self.std = std
  196. self.normalize = transforms.Normalize(mean=mean, std=std)
  197. def __call__(self, im_tuple):
  198. c,h,w = im_tuple[0].shape
  199. if c > 3:
  200. warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb")
  201. return [self.normalize(im[:3]) for im in im_tuple]
  202. def __repr__(self):
  203. return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
  204. class TupleCompose(object):
  205. def __init__(self, transforms):
  206. self.transforms = transforms
  207. def __call__(self, im_tuple):
  208. for t in self.transforms:
  209. im_tuple = t(im_tuple)
  210. return im_tuple
  211. def __repr__(self):
  212. format_string = self.__class__.__name__ + "("
  213. for t in self.transforms:
  214. format_string += "\n"
  215. format_string += " {0}".format(t)
  216. format_string += "\n)"
  217. return format_string
  218. @torch.no_grad()
  219. def cls_to_flow(cls, deterministic_sampling = True):
  220. B,C,H,W = cls.shape
  221. device = cls.device
  222. res = round(math.sqrt(C))
  223. G = torch.meshgrid(
  224. *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)],
  225. indexing = 'ij'
  226. )
  227. G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
  228. if deterministic_sampling:
  229. sampled_cls = cls.max(dim=1).indices
  230. else:
  231. sampled_cls = torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W)
  232. flow = G[sampled_cls]
  233. return flow
  234. @torch.no_grad()
  235. def cls_to_flow_refine(cls):
  236. B,C,H,W = cls.shape
  237. device = cls.device
  238. res = round(math.sqrt(C))
  239. G = torch.meshgrid(
  240. *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)],
  241. indexing = 'ij'
  242. )
  243. G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
  244. # FIXME: below softmax line causes mps to bug, don't know why.
  245. if device.type == 'mps':
  246. cls = cls.log_softmax(dim=1).exp()
  247. else:
  248. cls = cls.softmax(dim=1)
  249. mode = cls.max(dim=1).indices
  250. index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long()
  251. neighbours = torch.gather(cls, dim = 1, index = index)[...,None]
  252. flow = neighbours[:,0] * G[index[:,0]] + neighbours[:,1] * G[index[:,1]] + neighbours[:,2] * G[index[:,2]] + neighbours[:,3] * G[index[:,3]] + neighbours[:,4] * G[index[:,4]]
  253. tot_prob = neighbours.sum(dim=1)
  254. flow = flow / tot_prob
  255. return flow
  256. def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
  257. if H is None:
  258. B,H,W = depth1.shape
  259. else:
  260. B = depth1.shape[0]
  261. with torch.no_grad():
  262. x1_n = torch.meshgrid(
  263. *[
  264. torch.linspace(
  265. -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
  266. )
  267. for n in (B, H, W)
  268. ],
  269. indexing = 'ij'
  270. )
  271. x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
  272. mask, x2 = warp_kpts(
  273. x1_n.double(),
  274. depth1.double(),
  275. depth2.double(),
  276. T_1to2.double(),
  277. K1.double(),
  278. K2.double(),
  279. depth_interpolation_mode = depth_interpolation_mode,
  280. relative_depth_error_threshold = relative_depth_error_threshold,
  281. )
  282. prob = mask.float().reshape(B, H, W)
  283. x2 = x2.reshape(B, H, W, 2)
  284. return x2, prob
  285. @torch.no_grad()
  286. def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
  287. """Warp kpts0 from I0 to I1 with depth, K and Rt
  288. Also check covisibility and depth consistency.
  289. Depth is consistent if relative error < 0.2 (hard-coded).
  290. # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
  291. Args:
  292. kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1,1)
  293. depth0 (torch.Tensor): [N, H, W],
  294. depth1 (torch.Tensor): [N, H, W],
  295. T_0to1 (torch.Tensor): [N, 3, 4],
  296. K0 (torch.Tensor): [N, 3, 3],
  297. K1 (torch.Tensor): [N, 3, 3],
  298. Returns:
  299. calculable_mask (torch.Tensor): [N, L]
  300. warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
  301. """
  302. (
  303. n,
  304. h,
  305. w,
  306. ) = depth0.shape
  307. if depth_interpolation_mode == "combined":
  308. # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
  309. if smooth_mask:
  310. raise NotImplementedError("Combined bilinear and NN warp not implemented")
  311. valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
  312. smooth_mask = smooth_mask,
  313. return_relative_depth_error = return_relative_depth_error,
  314. depth_interpolation_mode = "bilinear",
  315. relative_depth_error_threshold = relative_depth_error_threshold)
  316. valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
  317. smooth_mask = smooth_mask,
  318. return_relative_depth_error = return_relative_depth_error,
  319. depth_interpolation_mode = "nearest-exact",
  320. relative_depth_error_threshold = relative_depth_error_threshold)
  321. nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
  322. warp = warp_bilinear.clone()
  323. warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
  324. valid = valid_bilinear | valid_nearest
  325. return valid, warp
  326. kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
  327. :, 0, :, 0
  328. ]
  329. kpts0 = torch.stack(
  330. (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
  331. ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
  332. # Sample depth, get calculable_mask on depth != 0
  333. nonzero_mask = kpts0_depth != 0
  334. # Unproject
  335. kpts0_h = (
  336. torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
  337. * kpts0_depth[..., None]
  338. ) # (N, L, 3)
  339. kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
  340. kpts0_cam = kpts0_n
  341. # Rigid Transform
  342. w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
  343. w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
  344. # Project
  345. w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
  346. w_kpts0 = w_kpts0_h[:, :, :2] / (
  347. w_kpts0_h[:, :, [2]] + 1e-4
  348. ) # (N, L, 2), +1e-4 to avoid zero depth
  349. # Covisible Check
  350. h, w = depth1.shape[1:3]
  351. covisible_mask = (
  352. (w_kpts0[:, :, 0] > 0)
  353. * (w_kpts0[:, :, 0] < w - 1)
  354. * (w_kpts0[:, :, 1] > 0)
  355. * (w_kpts0[:, :, 1] < h - 1)
  356. )
  357. w_kpts0 = torch.stack(
  358. (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
  359. ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
  360. # w_kpts0[~covisible_mask, :] = -5 # xd
  361. w_kpts0_depth = F.grid_sample(
  362. depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
  363. )[:, 0, :, 0]
  364. relative_depth_error = (
  365. (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
  366. ).abs()
  367. if not smooth_mask:
  368. consistent_mask = relative_depth_error < relative_depth_error_threshold
  369. else:
  370. consistent_mask = (-relative_depth_error/smooth_mask).exp()
  371. valid_mask = nonzero_mask * covisible_mask * consistent_mask
  372. if return_relative_depth_error:
  373. return relative_depth_error, w_kpts0
  374. else:
  375. return valid_mask, w_kpts0
  376. imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
  377. imagenet_std = torch.tensor([0.229, 0.224, 0.225])
  378. def numpy_to_pil(x: np.ndarray):
  379. """
  380. Args:
  381. x: Assumed to be of shape (h,w,c)
  382. """
  383. if isinstance(x, torch.Tensor):
  384. x = x.detach().cpu().numpy()
  385. if x.max() <= 1.01:
  386. x *= 255
  387. x = x.astype(np.uint8)
  388. return Image.fromarray(x)
  389. def tensor_to_pil(x, unnormalize=False):
  390. if unnormalize:
  391. x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device))
  392. x = x.detach().permute(1, 2, 0).cpu().numpy()
  393. x = np.clip(x, 0.0, 1.0)
  394. return numpy_to_pil(x)
  395. def to_cuda(batch):
  396. for key, value in batch.items():
  397. if isinstance(value, torch.Tensor):
  398. batch[key] = value.cuda()
  399. return batch
  400. def to_cpu(batch):
  401. for key, value in batch.items():
  402. if isinstance(value, torch.Tensor):
  403. batch[key] = value.cpu()
  404. return batch
  405. def get_pose(calib):
  406. w, h = np.array(calib["imsize"])[0]
  407. return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w
  408. def compute_relative_pose(R1, t1, R2, t2):
  409. rots = R2 @ (R1.T)
  410. trans = -rots @ t1 + t2
  411. return rots, trans
  412. @torch.no_grad()
  413. def reset_opt(opt):
  414. for group in opt.param_groups:
  415. for p in group['params']:
  416. if p.requires_grad:
  417. state = opt.state[p]
  418. # State initialization
  419. # Exponential moving average of gradient values
  420. state['exp_avg'] = torch.zeros_like(p)
  421. # Exponential moving average of squared gradient values
  422. state['exp_avg_sq'] = torch.zeros_like(p)
  423. # Exponential moving average of gradient difference
  424. state['exp_avg_diff'] = torch.zeros_like(p)
  425. def flow_to_pixel_coords(flow, h1, w1):
  426. flow = (
  427. torch.stack(
  428. (
  429. w1 * (flow[..., 0] + 1) / 2,
  430. h1 * (flow[..., 1] + 1) / 2,
  431. ),
  432. axis=-1,
  433. )
  434. )
  435. return flow
  436. to_pixel_coords = flow_to_pixel_coords # just an alias
  437. def flow_to_normalized_coords(flow, h1, w1):
  438. flow = (
  439. torch.stack(
  440. (
  441. 2 * (flow[..., 0]) / w1 - 1,
  442. 2 * (flow[..., 1]) / h1 - 1,
  443. ),
  444. axis=-1,
  445. )
  446. )
  447. return flow
  448. to_normalized_coords = flow_to_normalized_coords # just an alias
  449. def warp_to_pixel_coords(warp, h1, w1, h2, w2):
  450. warp1 = warp[..., :2]
  451. warp1 = (
  452. torch.stack(
  453. (
  454. w1 * (warp1[..., 0] + 1) / 2,
  455. h1 * (warp1[..., 1] + 1) / 2,
  456. ),
  457. axis=-1,
  458. )
  459. )
  460. warp2 = warp[..., 2:]
  461. warp2 = (
  462. torch.stack(
  463. (
  464. w2 * (warp2[..., 0] + 1) / 2,
  465. h2 * (warp2[..., 1] + 1) / 2,
  466. ),
  467. axis=-1,
  468. )
  469. )
  470. return torch.cat((warp1,warp2), dim=-1)
  471. def signed_point_line_distance(point, line, eps: float = 1e-9):
  472. r"""Return the distance from points to lines.
  473. Args:
  474. point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`.
  475. line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`.
  476. eps: Small constant for safe sqrt.
  477. Returns:
  478. the computed distance with shape :math:`(*, N)`.
  479. """
  480. if not point.shape[-1] in (2, 3):
  481. raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}")
  482. if not line.shape[-1] == 3:
  483. raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}")
  484. numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2])
  485. denominator = line[..., :2].norm(dim=-1)
  486. return numerator / (denominator + eps)
  487. def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
  488. r"""Return one-sided epipolar distance for correspondences given the fundamental matrix.
  489. This method measures the distance from points in the right images to the epilines
  490. of the corresponding points in the left images as they reflect in the right images.
  491. Args:
  492. pts1: correspondences from the left images with shape
  493. :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
  494. pts2: correspondences from the right images with shape
  495. :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
  496. Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to
  497. avoid ambiguity with torch.nn.functional.
  498. Returns:
  499. the computed Symmetrical distance with shape :math:`(*, N)`.
  500. """
  501. import kornia
  502. if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3):
  503. raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}")
  504. if pts1.shape[-1] == 2:
  505. pts1 = kornia.geometry.convert_points_to_homogeneous(pts1)
  506. F_t = Fm.transpose(dim0=-2, dim1=-1)
  507. line1_in_2 = pts1 @ F_t
  508. return signed_point_line_distance(pts2, line1_in_2)
  509. def get_grid(b, h, w, device):
  510. grid = torch.meshgrid(
  511. *[
  512. torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)
  513. for n in (b, h, w)
  514. ],
  515. indexing = 'ij'
  516. )
  517. grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2)
  518. return grid
  519. def get_autocast_params(device=None, enabled=False, dtype=None):
  520. if device is None:
  521. autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
  522. else:
  523. #strip :X from device
  524. autocast_device = str(device).split(":")[0]
  525. if 'cuda' in str(device):
  526. out_dtype = dtype
  527. enabled = True
  528. else:
  529. out_dtype = torch.bfloat16
  530. enabled = False
  531. # mps is not supported
  532. autocast_device = "cpu"
  533. return autocast_device, enabled, out_dtype
  534. def check_not_i16(im):
  535. if im.mode == "I;16":
  536. raise NotImplementedError("Can't handle 16 bit images")
  537. def check_rgb(im):
  538. if im.mode != "RGB":
  539. raise NotImplementedError("Can't handle non-RGB images")