# LICENSE HEADER MANAGED BY add-license-header # # Copyright 2018 Kornia Team # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import math from typing import Tuple import torch from kornia.core import Tensor def arange_sequence(ranges: Tensor) -> Tensor: """Return a sequence of the ranges specified by the argument. Example: [2, 5, 1, 2] -> [0, 1, 0, 1, 2, 3, 4, 0, 0, 1] """ maxcnt = torch.max(ranges).item() numuni = ranges.shape[0] complete_ranges = torch.arange(maxcnt, device=ranges.device).unsqueeze(0).expand(numuni, -1) return complete_ranges[complete_ranges < ranges.unsqueeze(-1)] def dist_matrix(d1: Tensor, d2: Tensor, is_normalized: bool = False) -> Tensor: """Distance between two tensors.""" if is_normalized: return 2 - 2.0 * d1 @ d2.t() x_norm = (d1**2).sum(1).view(-1, 1) y_norm = (d2**2).sum(1).view(1, -1) # print(x_norm, y_norm) distmat = x_norm + y_norm - 2.0 * d1 @ d2.t() # distmat[torch.isnan(distmat)] = np.inf return distmat def orientation_diff(o1: Tensor, o2: Tensor) -> Tensor: """Orientation difference between two tensors.""" diff = o2 - o1 diff[diff < -180] += 360 diff[diff >= 180] -= 360 return diff def piecewise_arange(piecewise_idxer: Tensor) -> Tensor: """Count repeated indices. Example: [0, 0, 0, 3, 3, 3, 3, 1, 1, 2] -> [0, 1, 2, 0, 1, 2, 3, 0, 1, 0] """ dv = piecewise_idxer.device # print(piecewise_idxer) uni: Tensor uni, counts = torch.unique_consecutive(piecewise_idxer, return_counts=True) # print(counts) maxcnt = int(torch.max(counts).item()) numuni = uni.shape[0] tmp = torch.zeros(size=(numuni, maxcnt), device=dv).bool() ranges = torch.arange(maxcnt, device=dv).unsqueeze(0).expand(numuni, -1) tmp[ranges < counts.unsqueeze(-1)] = True return ranges[tmp] def batch_2x2_inv(m: Tensor, check_dets: bool = False) -> Tensor: """Returns inverse of batch of 2x2 matrices.""" a = m[..., 0, 0] b = m[..., 0, 1] c = m[..., 1, 0] d = m[..., 1, 1] minv = torch.empty_like(m) det = a * d - b * c if check_dets: det[torch.abs(det) < 1e-10] = 1e-10 minv[..., 0, 0] = d minv[..., 1, 1] = a minv[..., 0, 1] = -b minv[..., 1, 0] = -c return minv / det.unsqueeze(-1).unsqueeze(-1) def batch_2x2_Q(m: Tensor) -> Tensor: """Returns Q of batch of 2x2 matrices.""" return batch_2x2_inv(batch_2x2_invQ(m), check_dets=True) def batch_2x2_invQ(m: Tensor) -> Tensor: """Returns inverse Q of batch of 2x2 matrices.""" return m @ m.transpose(-1, -2) def batch_2x2_det(m: Tensor) -> Tensor: """Returns determinant of batch of 2x2 matrices.""" a = m[..., 0, 0] b = m[..., 0, 1] c = m[..., 1, 0] d = m[..., 1, 1] return a * d - b * c def batch_2x2_ellipse(m: Tensor, *, eps: float = 0.0) -> Tuple[Tensor, Tensor]: """Returns Eigenvalues and Eigenvectors of batch of 2x2 matrices.""" am = m[..., 0, 0] bm = m[..., 0, 1] cm = m[..., 1, 0] dm = m[..., 1, 1] a = am * am + bm * bm b = am * cm + bm * dm d = cm * cm + dm * dm trh = 0.5 * (a + d) diff = 0.5 * (a - d) # stable hypot sqrtdisc = torch.hypot(diff, b) e1 = trh + sqrtdisc e2 = trh - sqrtdisc if eps > 0: e1 = e1.clamp(min=eps) e2 = e2.clamp(min=eps) else: e1 = e1.clamp(min=0.0) e2 = e2.clamp(min=0.0) eigenvals = torch.stack([e1, e2], dim=-1) theta = 0.5 * torch.atan2(2.0 * b, a - d) c = torch.cos(theta) s = torch.sin(theta) ev1 = torch.stack([c, s], dim=-1) # (...,2) ev2 = torch.stack([-s, c], dim=-1) # orthogonal (...,2) eigenvecs = torch.stack([ev1, ev2], dim=-1) # (...,2,2) columns are eigenvectors return eigenvals, eigenvecs def draw_first_k_couples(k: int, rdims: Tensor, dv: torch.device) -> Tensor: """Returns first k couples. Exhaustive search over the first n samples: * n(n+1)/2 = n2/2 + n/2 couples Max n for which we can exhaustively sample with k couples: * n2/2 + n/2 = k * n = sqrt(1/4 + 2k)-1/2 = (sqrt(8k+1)-1)/2 """ max_exhaustive_search = int(math.sqrt(2 * k + 0.25) - 0.5) residual_search = int(k - max_exhaustive_search * (max_exhaustive_search + 1) / 2) repeats = torch.cat( [ torch.arange(max_exhaustive_search, dtype=torch.long, device=dv) + 1, torch.tensor([residual_search], dtype=torch.long, device=dv), ] ) idx_sequence = torch.stack([repeats.repeat_interleave(repeats), arange_sequence(repeats)], dim=-1) return torch.remainder(idx_sequence.unsqueeze(-1), rdims) def random_samples_indices(iters: int, rdims: Tensor, dv: torch.device) -> Tensor: """Randomly sample indices of tensor.""" rands = torch.rand(size=(iters, 2, rdims.shape[0]), device=dv) scaled_rands = rands * (rdims - 1e-8).float() rand_samples_rel = scaled_rands.long() return rand_samples_rel