utils.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import math
  18. from typing import Tuple
  19. import torch
  20. from kornia.core import Tensor
  21. def arange_sequence(ranges: Tensor) -> Tensor:
  22. """Return a sequence of the ranges specified by the argument.
  23. Example:
  24. [2, 5, 1, 2] -> [0, 1, 0, 1, 2, 3, 4, 0, 0, 1]
  25. """
  26. maxcnt = torch.max(ranges).item()
  27. numuni = ranges.shape[0]
  28. complete_ranges = torch.arange(maxcnt, device=ranges.device).unsqueeze(0).expand(numuni, -1)
  29. return complete_ranges[complete_ranges < ranges.unsqueeze(-1)]
  30. def dist_matrix(d1: Tensor, d2: Tensor, is_normalized: bool = False) -> Tensor:
  31. """Distance between two tensors."""
  32. if is_normalized:
  33. return 2 - 2.0 * d1 @ d2.t()
  34. x_norm = (d1**2).sum(1).view(-1, 1)
  35. y_norm = (d2**2).sum(1).view(1, -1)
  36. # print(x_norm, y_norm)
  37. distmat = x_norm + y_norm - 2.0 * d1 @ d2.t()
  38. # distmat[torch.isnan(distmat)] = np.inf
  39. return distmat
  40. def orientation_diff(o1: Tensor, o2: Tensor) -> Tensor:
  41. """Orientation difference between two tensors."""
  42. diff = o2 - o1
  43. diff[diff < -180] += 360
  44. diff[diff >= 180] -= 360
  45. return diff
  46. def piecewise_arange(piecewise_idxer: Tensor) -> Tensor:
  47. """Count repeated indices.
  48. Example:
  49. [0, 0, 0, 3, 3, 3, 3, 1, 1, 2] -> [0, 1, 2, 0, 1, 2, 3, 0, 1, 0]
  50. """
  51. dv = piecewise_idxer.device
  52. # print(piecewise_idxer)
  53. uni: Tensor
  54. uni, counts = torch.unique_consecutive(piecewise_idxer, return_counts=True)
  55. # print(counts)
  56. maxcnt = int(torch.max(counts).item())
  57. numuni = uni.shape[0]
  58. tmp = torch.zeros(size=(numuni, maxcnt), device=dv).bool()
  59. ranges = torch.arange(maxcnt, device=dv).unsqueeze(0).expand(numuni, -1)
  60. tmp[ranges < counts.unsqueeze(-1)] = True
  61. return ranges[tmp]
  62. def batch_2x2_inv(m: Tensor, check_dets: bool = False) -> Tensor:
  63. """Returns inverse of batch of 2x2 matrices."""
  64. a = m[..., 0, 0]
  65. b = m[..., 0, 1]
  66. c = m[..., 1, 0]
  67. d = m[..., 1, 1]
  68. minv = torch.empty_like(m)
  69. det = a * d - b * c
  70. if check_dets:
  71. det[torch.abs(det) < 1e-10] = 1e-10
  72. minv[..., 0, 0] = d
  73. minv[..., 1, 1] = a
  74. minv[..., 0, 1] = -b
  75. minv[..., 1, 0] = -c
  76. return minv / det.unsqueeze(-1).unsqueeze(-1)
  77. def batch_2x2_Q(m: Tensor) -> Tensor:
  78. """Returns Q of batch of 2x2 matrices."""
  79. return batch_2x2_inv(batch_2x2_invQ(m), check_dets=True)
  80. def batch_2x2_invQ(m: Tensor) -> Tensor:
  81. """Returns inverse Q of batch of 2x2 matrices."""
  82. return m @ m.transpose(-1, -2)
  83. def batch_2x2_det(m: Tensor) -> Tensor:
  84. """Returns determinant of batch of 2x2 matrices."""
  85. a = m[..., 0, 0]
  86. b = m[..., 0, 1]
  87. c = m[..., 1, 0]
  88. d = m[..., 1, 1]
  89. return a * d - b * c
  90. def batch_2x2_ellipse(m: Tensor, *, eps: float = 0.0) -> Tuple[Tensor, Tensor]:
  91. """Returns Eigenvalues and Eigenvectors of batch of 2x2 matrices."""
  92. am = m[..., 0, 0]
  93. bm = m[..., 0, 1]
  94. cm = m[..., 1, 0]
  95. dm = m[..., 1, 1]
  96. a = am * am + bm * bm
  97. b = am * cm + bm * dm
  98. d = cm * cm + dm * dm
  99. trh = 0.5 * (a + d)
  100. diff = 0.5 * (a - d)
  101. # stable hypot
  102. sqrtdisc = torch.hypot(diff, b)
  103. e1 = trh + sqrtdisc
  104. e2 = trh - sqrtdisc
  105. if eps > 0:
  106. e1 = e1.clamp(min=eps)
  107. e2 = e2.clamp(min=eps)
  108. else:
  109. e1 = e1.clamp(min=0.0)
  110. e2 = e2.clamp(min=0.0)
  111. eigenvals = torch.stack([e1, e2], dim=-1)
  112. theta = 0.5 * torch.atan2(2.0 * b, a - d)
  113. c = torch.cos(theta)
  114. s = torch.sin(theta)
  115. ev1 = torch.stack([c, s], dim=-1) # (...,2)
  116. ev2 = torch.stack([-s, c], dim=-1) # orthogonal (...,2)
  117. eigenvecs = torch.stack([ev1, ev2], dim=-1) # (...,2,2) columns are eigenvectors
  118. return eigenvals, eigenvecs
  119. def draw_first_k_couples(k: int, rdims: Tensor, dv: torch.device) -> Tensor:
  120. """Returns first k couples.
  121. Exhaustive search over the first n samples:
  122. * n(n+1)/2 = n2/2 + n/2 couples
  123. Max n for which we can exhaustively sample with k couples:
  124. * n2/2 + n/2 = k
  125. * n = sqrt(1/4 + 2k)-1/2 = (sqrt(8k+1)-1)/2
  126. """
  127. max_exhaustive_search = int(math.sqrt(2 * k + 0.25) - 0.5)
  128. residual_search = int(k - max_exhaustive_search * (max_exhaustive_search + 1) / 2)
  129. repeats = torch.cat(
  130. [
  131. torch.arange(max_exhaustive_search, dtype=torch.long, device=dv) + 1,
  132. torch.tensor([residual_search], dtype=torch.long, device=dv),
  133. ]
  134. )
  135. idx_sequence = torch.stack([repeats.repeat_interleave(repeats), arange_sequence(repeats)], dim=-1)
  136. return torch.remainder(idx_sequence.unsqueeze(-1), rdims)
  137. def random_samples_indices(iters: int, rdims: Tensor, dv: torch.device) -> Tensor:
  138. """Randomly sample indices of tensor."""
  139. rands = torch.rand(size=(iters, 2, rdims.shape[0]), device=dv)
  140. scaled_rands = rands * (rdims - 1e-8).float()
  141. rand_samples_rel = scaled_rands.long()
  142. return rand_samples_rel