| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661 |
- import warnings
- import numpy as np
- import cv2
- import math
- import torch
- from torchvision import transforms
- from torchvision.transforms.functional import InterpolationMode
- import torch.nn.functional as F
- from PIL import Image
def recover_pose(E, kpts0, kpts1, K0, K1, mask):
    """Pick the relative pose (R, t) from essential-matrix candidates.

    Args:
        E: stacked essential-matrix candidates; split into chunks of 3 rows.
        kpts0, kpts1: pixel keypoints, one row per correspondence.
        K0, K1: 3x3 intrinsics; the top-left 2x2 block and the principal
            point (last column) are used to normalize the keypoints.
        mask: inlier mask handed to ``cv2.recoverPose`` (cv2 may update it
            in place).

    Returns:
        ``(R, t, inlier_mask)`` for the candidate with the most inliers, or
        ``None`` when no candidate has any inliers.
    """
    K0inv = np.linalg.inv(K0[:2, :2])
    K1inv = np.linalg.inv(K1[:2, :2])
    # Normalize pixels so cv2 can be called with an identity camera matrix.
    kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
    kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T

    # Initialise the result so "no candidate had inliers" yields None
    # instead of an UnboundLocalError.
    ret = None
    best_num_inliers = 0
    for _E in np.split(E, len(E) // 3):
        n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
        if n > best_num_inliers:
            best_num_inliers = n
            ret = (R, t, mask.ravel() > 0)
    return ret
- # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
- # --- GEOMETRY ---
def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
    """Estimate relative pose from calibrated correspondences.

    Keypoints are normalized with K0/K1, an essential matrix is fitted with
    ``cv2.findEssentialMat`` (identity camera matrix, ``norm_thresh`` as the
    threshold, ``conf`` as the probability) and the pose is recovered from
    the candidate with the most inliers.

    Returns:
        ``(R, t, inlier_mask)`` or ``None`` when fewer than 5 keypoints are
        given, no essential matrix is found, or no candidate has inliers.
    """
    if len(kpts0) < 5:
        return None

    K0inv = np.linalg.inv(K0[:2, :2])
    K1inv = np.linalg.inv(K1[:2, :2])
    # Subtract the principal point, then undo the focal/skew block.
    kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
    kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T

    E, mask = cv2.findEssentialMat(
        kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
    )
    if E is None:
        return None

    best_pose = None
    most_inliers = 0
    # E may hold several stacked 3x3 candidates; try each one.
    for candidate in np.split(E, len(E) / 3):
        n, R, t, _ = cv2.recoverPose(candidate, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
        if n > most_inliers:
            most_inliers = n
            best_pose = (R, t, mask.ravel() > 0)
    return best_pose
def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
    """Estimate relative pose by first fitting a fundamental matrix.

    A fundamental matrix is fitted on the raw keypoints with
    ``cv2.findFundamentalMat`` (USAC_ACCURATE, 10000 max iterations,
    ``norm_thresh`` as ``ransacReprojThreshold``), turned into an essential
    matrix via ``K1.T @ F @ K0``, and the pose is then recovered on
    K-normalized keypoints.

    Returns:
        ``(R, t, inlier_mask)`` or ``None`` when fewer than 5 keypoints are
        given, the fundamental-matrix fit fails, or no candidate has inliers.
    """
    if len(kpts0) < 5:
        return None
    # Named Fm so it does not shadow the module-level torch.nn.functional alias F.
    Fm, mask = cv2.findFundamentalMat(
        kpts0,
        kpts1,
        ransacReprojThreshold=norm_thresh,
        confidence=conf,
        method=cv2.USAC_ACCURATE,
        maxIters=10000,
    )
    # The failure check must happen before the matrix product; the original
    # tested E afterwards, which could never be None.
    if Fm is None:
        return None
    E = K1.T @ Fm @ K0

    K0inv = np.linalg.inv(K0[:2, :2])
    K1inv = np.linalg.inv(K1[:2, :2])
    kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
    kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T

    ret = None
    best_num_inliers = 0
    for _E in np.split(E, len(E) // 3):
        n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
        if n > best_num_inliers:
            best_num_inliers = n
            ret = (R, t, mask.ravel() > 0)
    return ret
def unnormalize_coords(x_n, h, w):
    """Map normalized (x, y) coordinates in [-1, 1] to pixel coordinates.

    The first channel is scaled by ``w`` and the second by ``h``;
    e.g. [-1+1/h, 1-1/h] -> [0.5, h-0.5].
    """
    x_px = w * (x_n[..., 0] + 1) / 2
    y_px = h * (x_n[..., 1] + 1) / 2
    return torch.stack((x_px, y_px), dim=-1)
def rotate_intrinsic(K, n):
    """Left-multiply K by the n-th power of a fixed 90-degree in-plane rotation."""
    quarter_turn = np.array(
        [
            [0, 1, 0],
            [-1, 0, 0],
            [0, 0, 1],
        ]
    )
    return np.linalg.matrix_power(quarter_turn, n) @ K
def rotate_pose_inplane(i_T_w, rot):
    """Left-multiply a 4x4 pose by a rotation about the z axis.

    ``rot`` indexes the angle table (0, 270, 180, 90) degrees.
    """

    def _z_rotation(theta):
        # Homogeneous 4x4 rotation about z, stored as float32.
        return np.array(
            [
                [np.cos(theta), -np.sin(theta), 0.0, 0.0],
                [np.sin(theta), np.cos(theta), 0.0, 0.0],
                [0.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            dtype=np.float32,
        )

    rotation_matrices = [_z_rotation(np.deg2rad(deg)) for deg in (0, 270, 180, 90)]
    return np.dot(rotation_matrices[rot], i_T_w)
def scale_intrinsics(K, scales):
    """Divide the first row of K by ``scales[0]`` and the second by ``scales[1]``."""
    inv_x = 1.0 / scales[0]
    inv_y = 1.0 / scales[1]
    scaling = np.diag([inv_x, inv_y, 1.0])
    return np.dot(scaling, K)
def to_homogeneous(points):
    """Append a column of ones to a 2-D array of points."""
    ones_column = np.ones_like(points[:, :1])
    return np.concatenate((points, ones_column), axis=-1)
def angle_error_mat(R1, R2):
    """Return the angle in degrees of the relative rotation ``R1.T @ R2``."""
    relative = np.dot(R1.T, R2)
    cos_angle = (np.trace(relative) - 1) / 2
    # Numerical errors can push the cosine slightly out of [-1, 1].
    cos_angle = np.clip(cos_angle, -1.0, 1.0)
    angle_rad = np.arccos(cos_angle)
    return np.rad2deg(np.abs(angle_rad))
def angle_error_vec(v1, v2):
    """Return the angle in degrees between vectors ``v1`` and ``v2``."""
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    cos_angle = np.clip(np.dot(v1, v2) / norm_product, -1.0, 1.0)
    return np.rad2deg(np.arccos(cos_angle))
def compute_pose_error(T_0to1, R, t):
    """Return ``(translation_error, rotation_error)`` in degrees.

    Ground truth comes from ``T_0to1``: rotation from ``[:3, :3]`` and
    translation from ``[:3, 3]``. The translation error is folded with
    ``min(e, 180 - e)`` because the sign of t is ambiguous when it is
    estimated from an essential matrix.
    """
    gt_rotation = T_0to1[:3, :3]
    gt_translation = T_0to1[:3, 3]

    raw_t_error = angle_error_vec(t.squeeze(), gt_translation)
    error_t = np.minimum(raw_t_error, 180 - raw_t_error)  # ambiguity of E estimation
    error_R = angle_error_mat(R, gt_rotation)
    return error_t, error_R
def pose_auc(errors, thresholds):
    """Area under the recall-vs-error curve, normalized per threshold.

    Args:
        errors: sequence (list, tuple or array) of per-sample errors.
        thresholds: iterable of error thresholds.

    Returns:
        A list with one float per threshold: the integral of the recall
        curve from 0 to the threshold, divided by the threshold.
    """
    # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid; pick
    # whichever this NumPy version provides.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    # np.asarray accepts any sequence; the original's errors.copy() rejected tuples.
    sorted_errors = np.sort(np.asarray(errors))
    recall = (np.arange(len(sorted_errors)) + 1) / len(sorted_errors)
    # Prepend the origin so the curve starts at (0, 0).
    sorted_errors = np.r_[0.0, sorted_errors]
    recall = np.r_[0.0, recall]

    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(sorted_errors, t)
        # Extend the curve horizontally from the last error below t up to t.
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[sorted_errors[:last_index], t]
        aucs.append(integrate(r, x=e).item() / t)
    return aucs
- # From Patch2Pix https://github.com/GrumpyZhou/patch2pix
def get_depth_tuple_transform_ops_nearest_exact(resize=None):
    """Build the depth pipeline: an optional nearest-exact resize, nothing else."""
    steps = [TupleResizeNearestExact(resize)] if resize else []
    return TupleCompose(steps)
def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
    """Build the depth pipeline: an optional bilinear resize, nothing else.

    ``normalize`` and ``unscale`` are accepted but not used by this function.
    """
    steps = [TupleResize(resize, mode=InterpolationMode.BILINEAR)] if resize else []
    return TupleCompose(steps)
def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe=False, colorjiggle_params=None):
    """Build the image pipeline: optional resize, to-tensor, optional normalize.

    ``unscale``, ``clahe`` and ``colorjiggle_params`` are accepted but not
    used by this function.
    """
    steps = [TupleResize(resize)] if resize else []
    steps.append(TupleToTensorScaled())
    if normalize:
        # Imagenet mean/std
        imagenet_norm = TupleNormalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        steps.append(imagenet_norm)
    return TupleCompose(steps)
class ToTensorScaled(object):
    """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""

    def __call__(self, im):
        # Tensors are passed through untouched.
        if isinstance(im, torch.Tensor):
            return im
        chw = np.array(im, dtype=np.float32).transpose((2, 0, 1))
        chw /= 255.0
        return torch.from_numpy(chw)

    def __repr__(self):
        return "ToTensorScaled(./255)"
class TupleToTensorScaled(object):
    """Apply ``ToTensorScaled`` to every image of a tuple; returns a list."""

    def __init__(self):
        self.to_tensor = ToTensorScaled()

    def __call__(self, im_tuple):
        convert = self.to_tensor
        return list(map(convert, im_tuple))

    def __repr__(self):
        return "TupleToTensorScaled(./255)"
class ToTensorUnscaled(object):
    """Convert a RGB PIL Image to a CHW ordered Tensor"""

    def __call__(self, im):
        hwc = np.array(im, dtype=np.float32)
        chw = hwc.transpose((2, 0, 1))
        return torch.from_numpy(chw)

    def __repr__(self):
        return "ToTensorUnscaled()"
class TupleToTensorUnscaled(object):
    """Apply ``ToTensorUnscaled`` to every image of a tuple; returns a list."""

    def __init__(self):
        self.to_tensor = ToTensorUnscaled()

    def __call__(self, im_tuple):
        convert = self.to_tensor
        return list(map(convert, im_tuple))

    def __repr__(self):
        return "TupleToTensorUnscaled()"
class TupleResizeNearestExact:
    """Resize every tensor of a tuple with ``F.interpolate(mode='nearest-exact')``."""

    def __init__(self, size):
        self.size = size

    def __call__(self, im_tuple):
        resized = []
        for im in im_tuple:
            resized.append(F.interpolate(im, size=self.size, mode='nearest-exact'))
        return resized

    def __repr__(self):
        return f"TupleResizeNearestExact(size={self.size})"
class TupleResize(object):
    """Resize every image of a tuple with ``torchvision.transforms.Resize``.

    The interpolation mode defaults to bicubic.
    """

    def __init__(self, size, mode=InterpolationMode.BICUBIC):
        self.size = size
        self.resize = transforms.Resize(size, mode)

    def __call__(self, im_tuple):
        resize = self.resize
        return list(map(resize, im_tuple))

    def __repr__(self):
        return f"TupleResize(size={self.size})"
-
class Normalize:
    """Standardize a tensor using its own mean and std over dims 1 and 2."""

    def __call__(self, im):
        reduce_dims = (1, 2)
        centered = im - im.mean(dim=reduce_dims, keepdim=True)
        return centered / im.std(dim=reduce_dims, keepdim=True)
class TupleNormalize(object):
    """Normalize the first three channels of every image in a tuple.

    Uses ``torchvision.transforms.Normalize`` with the given mean/std and
    warns when the first image has more than three channels.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std
        self.normalize = transforms.Normalize(mean=mean, std=std)

    def __call__(self, im_tuple):
        # The channel count is read from the first image only.
        num_channels, _, _ = im_tuple[0].shape
        if num_channels > 3:
            warnings.warn(f"Number of channels c={num_channels} > 3, assuming first 3 are rgb")
        normalized = []
        for im in im_tuple:
            normalized.append(self.normalize(im[:3]))
        return normalized

    def __repr__(self):
        return f"TupleNormalize(mean={self.mean}, std={self.std})"
class TupleCompose(object):
    """Chain tuple transforms: each one receives the previous one's output."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, im_tuple):
        result = im_tuple
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self):
        # One line per transform, wrapped in "ClassName(" ... ")".
        lines = "".join("\n" + " {0}".format(t) for t in self.transforms)
        return self.__class__.__name__ + "(" + lines + "\n)"
@torch.no_grad()
def cls_to_flow(cls, deterministic_sampling=True):
    """Turn per-pixel class scores into 2-D coordinates on a regular grid.

    ``cls`` has shape (B, C, H, W). The C classes index a res x res grid of
    (x, y) cell centres in [-1, 1], with res = round(sqrt(C)). With
    ``deterministic_sampling`` the highest-scoring class is used; otherwise
    a class is drawn from the softmax of the scores.

    Returns:
        Tensor of shape (B, H, W, 2).
    """
    B, C, H, W = cls.shape
    res = round(math.sqrt(C))
    axis = torch.linspace(-1 + 1 / res, 1 - 1 / res, steps=res, device=cls.device)
    rows, cols = torch.meshgrid(axis, axis, indexing='ij')
    # Stored as (x, y): the column coordinate comes first.
    G = torch.stack([cols, rows], dim=-1).reshape(C, 2)

    if deterministic_sampling:
        sampled_cls = cls.max(dim=1).indices
    else:
        probs = cls.permute(0, 2, 3, 1).reshape(B * H * W, C).softmax(dim=-1)
        sampled_cls = torch.multinomial(probs, 1).reshape(B, H, W)
    return G[sampled_cls]
@torch.no_grad()
def cls_to_flow_refine(cls):
    """Turn class scores into coordinates, refined around the best class.

    ``cls`` has shape (B, C, H, W). After a softmax over C, the result is
    the probability-weighted average of the grid coordinates of five
    classes: the argmax and the indices at offsets -1, +1, -res and +res,
    each clamped to [0, C - 1].

    Returns:
        Tensor of shape (B, H, W, 2).
    """
    B, C, H, W = cls.shape
    device = cls.device
    res = round(math.sqrt(C))
    axis = torch.linspace(-1 + 1 / res, 1 - 1 / res, steps=res, device=device)
    rows, cols = torch.meshgrid(axis, axis, indexing='ij')
    G = torch.stack([cols, rows], dim=-1).reshape(C, 2)

    # FIXME: plain softmax causes mps to bug, don't know why.
    if device.type == 'mps':
        cls = cls.log_softmax(dim=1).exp()
    else:
        cls = cls.softmax(dim=1)

    mode = cls.max(dim=1).indices
    offsets = (mode - 1, mode, mode + 1, mode - res, mode + res)
    index = torch.stack(offsets, dim=1).clamp(0, C - 1).long()
    neighbours = torch.gather(cls, dim=1, index=index)[..., None]

    # Accumulate in the same left-to-right order as a written-out sum.
    flow = neighbours[:, 0] * G[index[:, 0]]
    for k in range(1, 5):
        flow = flow + neighbours[:, k] * G[index[:, k]]
    tot_prob = neighbours.sum(dim=1)
    return flow / tot_prob
def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode='bilinear', relative_depth_error_threshold=0.05, H=None, W=None):
    """Warp a dense grid from image 1 into image 2 using depth and pose.

    A regular H x W grid of normalized coordinates is built for each batch
    element and passed to ``warp_kpts`` (all inputs cast to double).

    Args:
        depth1, depth2: depth maps; ``depth1`` has shape (B, H, W).
        T_1to2, K1, K2: relative pose and intrinsics, forwarded to ``warp_kpts``.
        depth_interpolation_mode, relative_depth_error_threshold: forwarded
            to ``warp_kpts``.
        H, W: optional grid resolution. When ``H`` is omitted both values are
            taken from ``depth1``; when only ``W`` is omitted it is taken
            from ``depth1``.

    Returns:
        ``(x2, prob)``: warped coordinates of shape (B, H, W, 2) and the
        validity mask as float with shape (B, H, W).
    """
    if H is None:
        B, H, W = depth1.shape
    else:
        B = depth1.shape[0]
        if W is None:
            # H alone used to leave W=None, which crashed in torch.linspace.
            W = depth1.shape[2]
    with torch.no_grad():
        x1_n = torch.meshgrid(
            *[
                torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=depth1.device)
                for n in (B, H, W)
            ],
            indexing='ij',
        )
        # Keep (x, y) = (width coordinate, height coordinate).
        x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
        mask, x2 = warp_kpts(
            x1_n.double(),
            depth1.double(),
            depth2.double(),
            T_1to2.double(),
            K1.double(),
            K2.double(),
            depth_interpolation_mode=depth_interpolation_mode,
            relative_depth_error_threshold=relative_depth_error_threshold,
        )
        prob = mask.float().reshape(B, H, W)
        x2 = x2.reshape(B, H, W, 2)
        return x2, prob
@torch.no_grad()
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask=False, return_relative_depth_error=False, depth_interpolation_mode="bilinear", relative_depth_error_threshold=0.05):
    """Warp kpts0 from I0 to I1 with depth, K and Rt.

    Also checks covisibility and depth consistency. Depth is consistent if
    the relative error is below ``relative_depth_error_threshold``.
    Adapted from
    https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py

    Args:
        kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1,1)
        depth0 (torch.Tensor): [N, H, W],
        depth1 (torch.Tensor): [N, H, W],
        T_0to1 (torch.Tensor): [N, 3, 4],
        K0 (torch.Tensor): [N, 3, 3],
        K1 (torch.Tensor): [N, 3, 3],
        smooth_mask: if falsy, the consistency mask is a hard threshold;
            otherwise it is ``exp(-relative_error / smooth_mask)``.
        return_relative_depth_error: return the relative depth error in
            place of the validity mask.
        depth_interpolation_mode: a ``F.grid_sample`` mode ("bilinear",
            "nearest", "bicubic"), "nearest-exact" (treated as "nearest"),
            or "combined" (bilinear, with invalid points filled from nearest).
        relative_depth_error_threshold: threshold for the hard consistency mask.

    Returns:
        calculable_mask (torch.Tensor): [N, L] (or the relative depth error)
        warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>, normalized

    Raises:
        NotImplementedError: "combined" mode together with ``smooth_mask``
            or ``return_relative_depth_error``.
    """
    _, h, w = depth0.shape
    if depth_interpolation_mode == "combined":
        # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
        if smooth_mask:
            raise NotImplementedError("Combined bilinear and NN warp not implemented")
        if return_relative_depth_error:
            # The merge below needs boolean masks, not error values.
            raise NotImplementedError(
                "Combined bilinear and NN warp cannot return the relative depth error"
            )
        valid_bilinear, warp_bilinear = warp_kpts(
            kpts0, depth0, depth1, T_0to1, K0, K1,
            smooth_mask=smooth_mask,
            return_relative_depth_error=return_relative_depth_error,
            depth_interpolation_mode="bilinear",
            relative_depth_error_threshold=relative_depth_error_threshold,
        )
        # grid_sample has no "nearest-exact" mode, so the nearest-neighbour
        # pass must request "nearest".
        valid_nearest, warp_nearest = warp_kpts(
            kpts0, depth0, depth1, T_0to1, K0, K1,
            smooth_mask=smooth_mask,
            return_relative_depth_error=return_relative_depth_error,
            depth_interpolation_mode="nearest",
            relative_depth_error_threshold=relative_depth_error_threshold,
        )
        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
        warp = warp_bilinear.clone()
        warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
        valid = valid_bilinear | valid_nearest
        return valid, warp

    # Accept "nearest-exact" from callers by mapping it to grid_sample's "nearest".
    sample_mode = "nearest" if depth_interpolation_mode == "nearest-exact" else depth_interpolation_mode

    # Sample depth at the keypoints; calculable_mask requires depth != 0
    kpts0_depth = F.grid_sample(
        depth0[:, None], kpts0[:, :, None], mode=sample_mode, align_corners=False
    )[:, 0, :, 0]
    kpts0 = torch.stack(
        (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
    )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
    nonzero_mask = kpts0_depth != 0

    # Unproject
    kpts0_h = (
        torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
        * kpts0_depth[..., None]
    )  # (N, L, 3)
    kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)

    # Rigid Transform
    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]

    # Project
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
    w_kpts0 = w_kpts0_h[:, :, :2] / (
        w_kpts0_h[:, :, [2]] + 1e-4
    )  # (N, L, 2), +1e-4 to avoid zero depth

    # Covisible Check
    h, w = depth1.shape[1:3]
    covisible_mask = (
        (w_kpts0[:, :, 0] > 0)
        * (w_kpts0[:, :, 0] < w - 1)
        * (w_kpts0[:, :, 1] > 0)
        * (w_kpts0[:, :, 1] < h - 1)
    )
    w_kpts0 = torch.stack(
        (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
    )  # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
    w_kpts0_depth = F.grid_sample(
        depth1[:, None], w_kpts0[:, :, None], mode=sample_mode, align_corners=False
    )[:, 0, :, 0]

    # Compare the depth sampled in image 1 with the depth of the transformed point.
    relative_depth_error = (
        (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
    ).abs()
    if not smooth_mask:
        consistent_mask = relative_depth_error < relative_depth_error_threshold
    else:
        consistent_mask = (-relative_depth_error / smooth_mask).exp()
    valid_mask = nonzero_mask * covisible_mask * consistent_mask
    if return_relative_depth_error:
        return relative_depth_error, w_kpts0
    return valid_mask, w_kpts0
# Per-channel ImageNet statistics used to undo input normalization.
imagenet_mean = torch.tensor((0.485, 0.456, 0.406))
imagenet_std = torch.tensor((0.229, 0.224, 0.225))
def numpy_to_pil(x: np.ndarray):
    """Convert an array (or tensor) to a PIL image.

    Args:
        x: Assumed to be of shape (h,w,c). Tensors are detached, moved to
            CPU and converted to numpy first.

    Values are multiplied by 255 when the maximum is <= 1.01, then cast to
    uint8. The input is never modified.
    """
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    if x.max() <= 1.01:
        # Out-of-place on purpose: `x *= 255` would rescale the caller's
        # array (and a CPU tensor sharing its memory) and fails on bool arrays.
        x = x * 255
    return Image.fromarray(x.astype(np.uint8))
def tensor_to_pil(x, unnormalize=False):
    """Convert a channel-first tensor to a PIL image.

    With ``unnormalize`` the ImageNet mean/std normalization is undone
    first. Values are clipped to [0, 1] before conversion.
    """
    if unnormalize:
        std = imagenet_std[:, None, None].to(x.device)
        mean = imagenet_mean[:, None, None].to(x.device)
        x = x * (std) + (mean)
    hwc = x.detach().permute(1, 2, 0).cpu().numpy()
    return numpy_to_pil(np.clip(hwc, 0.0, 1.0))
def to_cuda(batch):
    """Move every tensor value of ``batch`` to CUDA, in place; return ``batch``."""
    tensor_keys = [key for key, value in batch.items() if isinstance(value, torch.Tensor)]
    for key in tensor_keys:
        batch[key] = batch[key].cuda()
    return batch
def to_cpu(batch):
    """Move every tensor value of ``batch`` to the CPU, in place; return ``batch``."""
    tensor_keys = [key for key, value in batch.items() if isinstance(value, torch.Tensor)]
    for key in tensor_keys:
        batch[key] = batch[key].cpu()
    return batch
def get_pose(calib):
    """Unpack a calibration dict into ``(K, R, T transposed, h, w)``.

    The first row of ``calib["imsize"]`` is read as (w, h).
    """
    width, height = np.array(calib["imsize"])[0]
    K = np.array(calib["K"])
    R = np.array(calib["R"])
    T = np.array(calib["T"]).T
    return K, R, T, height, width
def compute_relative_pose(R1, t1, R2, t2):
    """Return ``(R2 @ R1.T, -(R2 @ R1.T) @ t1 + t2)`` for the poses (R1, t1), (R2, t2)."""
    relative_rotation = R2 @ (R1.T)
    relative_translation = (-relative_rotation) @ t1 + t2
    return relative_rotation, relative_translation
@torch.no_grad()
def reset_opt(opt):
    """Zero three per-parameter optimizer state buffers.

    For every parameter with ``requires_grad`` the entries ``exp_avg``
    (moving average of gradients), ``exp_avg_sq`` (of squared gradients) and
    ``exp_avg_diff`` (of gradient differences) are replaced by zero tensors
    shaped like the parameter.
    """
    buffer_names = ('exp_avg', 'exp_avg_sq', 'exp_avg_diff')
    for group in opt.param_groups:
        for p in group['params']:
            if not p.requires_grad:
                continue
            state = opt.state[p]
            for name in buffer_names:
                state[name] = torch.zeros_like(p)
def flow_to_pixel_coords(flow, h1, w1):
    """Map normalized (x, y) coordinates in [-1, 1] to pixel coordinates."""
    x_px = w1 * (flow[..., 0] + 1) / 2
    y_px = h1 * (flow[..., 1] + 1) / 2
    return torch.stack((x_px, y_px), dim=-1)


to_pixel_coords = flow_to_pixel_coords  # just an alias
def flow_to_normalized_coords(flow, h1, w1):
    """Map pixel (x, y) coordinates to normalized coordinates in [-1, 1]."""
    x_n = 2 * (flow[..., 0]) / w1 - 1
    y_n = 2 * (flow[..., 1]) / h1 - 1
    return torch.stack((x_n, y_n), dim=-1)


to_normalized_coords = flow_to_normalized_coords  # just an alias
def warp_to_pixel_coords(warp, h1, w1, h2, w2):
    """Convert a 4-channel normalized warp to pixel coordinates.

    Channels 0-1 are scaled with (w1, h1) and channels 2-3 with (w2, h2);
    the result is the concatenation of both halves along the last dim.
    """

    def _to_pixels(coords, height, width):
        # Normalized [-1, 1] -> pixels, x scaled by width and y by height.
        return torch.stack(
            (
                width * (coords[..., 0] + 1) / 2,
                height * (coords[..., 1] + 1) / 2,
            ),
            dim=-1,
        )

    first_half = _to_pixels(warp[..., :2], h1, w1)
    second_half = _to_pixels(warp[..., 2:], h2, w2)
    return torch.cat((first_half, second_half), dim=-1)
def signed_point_line_distance(point, line, eps: float = 1e-9):
    r"""Return the signed distance from points to lines.

    Args:
        point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`; only the
            first two coordinates are used.
        line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`.
        eps: Small constant added to the denominator.

    Returns:
        the computed distance with shape :math:`(*, N)`.

    Raises:
        ValueError: if the last dimension of ``point`` is not 2 or 3, or the
            last dimension of ``line`` is not 3.
    """
    if point.shape[-1] not in (2, 3):
        raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}")
    if line.shape[-1] != 3:
        raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}")

    a, b, c = line[..., 0], line[..., 1], line[..., 2]
    numerator = (a * point[..., 0] + b * point[..., 1] + c)
    denominator = line[..., :2].norm(dim=-1)
    return numerator / (denominator + eps)
def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
    r"""Return one-sided signed epipolar distance for correspondences given the fundamental matrix.

    Measures the distance from points in the right images to the epilines
    induced in the right images by the corresponding left-image points.

    Args:
        pts1: correspondences from the left images with shape
          :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
        pts2: correspondences from the right images with shape
          :math:`(*, N, 2 or 3)`; only the first two coordinates are used.
        Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to
          avoid ambiguity with torch.nn.functional.

    Returns:
        the computed signed distance with shape :math:`(*, N)`.

    Raises:
        ValueError: if ``Fm`` is not of shape :math:`(*, 3, 3)`.
    """
    import kornia

    has_valid_shape = len(Fm.shape) >= 3 and Fm.shape[-2:] == (3, 3)
    if not has_valid_shape:
        raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}")

    if pts1.shape[-1] == 2:
        pts1 = kornia.geometry.convert_points_to_homogeneous(pts1)

    # Row-vector form of l = Fm @ p: epilines of the left points in the right image.
    epilines_in_right = pts1 @ Fm.transpose(dim0=-2, dim1=-1)
    return signed_point_line_distance(pts2, epilines_in_right)
def get_grid(b, h, w, device):
    """Return a (b, h, w, 2) grid of normalized (x, y) cell-centre coordinates.

    Along an axis of length n the coordinates run from -1 + 1/n to 1 - 1/n.
    """

    def _centres(n):
        return torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)

    _, ys, xs = torch.meshgrid(_centres(b), _centres(h), _centres(w), indexing='ij')
    return torch.stack((xs, ys), dim=-1).reshape(b, h, w, 2)
def get_autocast_params(device=None, enabled=False, dtype=None):
    """Return ``(autocast_device, enabled, dtype)`` for the given device.

    When ``str(device)`` contains "cuda" the result is
    ``(device without ":X" suffix, True, dtype)``; for anything else
    (mps is not supported) it is ``("cpu", False, torch.bfloat16)``.

    NOTE(review): the ``enabled`` argument is always overwritten, and
    ``device=None`` always resolves to the CPU result even when CUDA is
    available - confirm whether this is intended.
    """
    device_name = str(device)
    if device is None:
        autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        # strip :X from device
        autocast_device = device_name.split(":")[0]

    if 'cuda' not in device_name:
        # mps is not supported
        return "cpu", False, torch.bfloat16
    return autocast_device, True, dtype
def check_not_i16(im):
    """Raise NotImplementedError when ``im.mode`` is "I;16"."""
    is_16_bit = im.mode == "I;16"
    if is_16_bit:
        raise NotImplementedError("Can't handle 16 bit images")
def check_rgb(im):
    """Raise NotImplementedError unless ``im.mode`` is "RGB"."""
    is_rgb = im.mode == "RGB"
    if not is_rgb:
        raise NotImplementedError("Can't handle non-RGB images")
|