vector.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Optional, Tuple, Union, cast
  18. from kornia.core import Device, Dtype, Tensor, as_tensor, normalize, rand, stack
  19. from kornia.core.check import KORNIA_CHECK
  20. from kornia.core.tensor_wrapper import TensorWrapper, wrap # type: ignore[attr-defined]
  21. from kornia.geometry.linalg import batched_dot_product, batched_squared_norm
  22. __all__ = ["Scalar", "Vector2", "Vector3"]
  23. # TODO: implement more functionality to validate
  24. class Scalar(TensorWrapper):
  25. def __init__(self, data: Tensor) -> None:
  26. super().__init__(data)
  27. class Vector3(TensorWrapper):
  28. def __init__(self, vector: Tensor) -> None:
  29. super().__init__(vector)
  30. KORNIA_CHECK(vector.shape[-1] == 3)
  31. def __repr__(self) -> str:
  32. return f"x: {self.x}\ny: {self.y}\nz: {self.z}"
  33. def __getitem__(self, idx: Union[slice, int, Tensor]) -> "Vector3":
  34. return Vector3(self.data[idx, ...])
  35. @property
  36. def x(self) -> Tensor:
  37. return self.data[..., 0]
  38. @property
  39. def y(self) -> Tensor:
  40. return self.data[..., 1]
  41. @property
  42. def z(self) -> Tensor:
  43. return self.data[..., 2]
  44. def normalized(self) -> "Vector3":
  45. return Vector3(normalize(self.data, p=2, dim=-1))
  46. def dot(self, right: "Vector3") -> Scalar:
  47. return Scalar(batched_dot_product(self.data, right.data))
  48. def squared_norm(self) -> Scalar:
  49. return Scalar(batched_squared_norm(self.data))
  50. @classmethod
  51. def random(
  52. cls, shape: Optional[Tuple[int, ...]] = None, device: Optional[Device] = None, dtype: Dtype = None
  53. ) -> "Vector3":
  54. if shape is None:
  55. shape = ()
  56. return cls(rand((*shape, 3), device=device, dtype=dtype))
  57. # TODO: polish overload
  58. # @overload
  59. # @classmethod
  60. # def from_coords(
  61. # cls, x: Tensor, y: Tensor, z: Tensor, device=None, dtype=None
  62. # ) -> "Vector3":
  63. # KORNIA_CHECK(isinstance(x, Tensor))
  64. # KORNIA_CHECK(type(x) == type(y) == type(z))
  65. # return wrap(as_tensor((x, y, z), device=device, dtype=dtype), Vector3)
  66. # TODO: polish overload
  67. # @overload
  68. # @classmethod
  69. # def from_coords(
  70. # cls, x: float, y: float, z: float, device=None, dtype=None
  71. # ) -> "Vector3":
  72. # KORNIA_CHECK(isinstance(x, float))
  73. # KORNIA_CHECK(type(x) == type(y) == type(z))
  74. # return wrap(as_tensor((x, y, z), device=device, dtype=dtype), Vector3)
  75. @classmethod
  76. def from_coords(
  77. cls,
  78. x: Union[float, Tensor],
  79. y: Union[float, Tensor],
  80. z: Union[float, Tensor],
  81. device: Optional[Device] = None,
  82. dtype: Dtype = None,
  83. ) -> "Vector3":
  84. KORNIA_CHECK(type(x) is type(y) is type(z))
  85. KORNIA_CHECK(isinstance(x, (Tensor, float)))
  86. if isinstance(x, float):
  87. return wrap(as_tensor((x, y, z), device=device, dtype=dtype), Vector3)
  88. # TODO: this is totally insane ...
  89. tensors: Tuple[Tensor, ...] = (x, cast(Tensor, y), cast(Tensor, z))
  90. return wrap(stack(tensors, -1), Vector3)
  91. class Vector2(TensorWrapper):
  92. def __init__(self, vector: Tensor) -> None:
  93. super().__init__(vector)
  94. KORNIA_CHECK(vector.shape[-1] == 2)
  95. def __repr__(self) -> str:
  96. return f"x: {self.x}\ny: {self.y}"
  97. def __getitem__(self, idx: Union[slice, int, Tensor]) -> "Vector2":
  98. return Vector2(self.data[idx, ...])
  99. @property
  100. def x(self) -> Tensor:
  101. return self.data[..., 0]
  102. @property
  103. def y(self) -> Tensor:
  104. return self.data[..., 1]
  105. def normalized(self) -> "Vector2":
  106. return Vector2(normalize(self.data, p=2, dim=-1))
  107. def dot(self, right: "Vector2") -> Scalar:
  108. return Scalar(batched_dot_product(self.data, right.data))
  109. def squared_norm(self) -> Scalar:
  110. return Scalar(batched_squared_norm(self.data))
  111. @classmethod
  112. def random(cls, shape: Optional[Tuple[int, ...]] = None, device: Device = None, dtype: Dtype = None) -> "Vector2":
  113. if shape is None:
  114. shape = ()
  115. return cls(rand((*shape, 2), device=device, dtype=dtype))
  116. @classmethod
  117. def from_coords(
  118. cls, x: Union[float, Tensor], y: Union[float, Tensor], device: Device = None, dtype: Dtype = None
  119. ) -> "Vector2":
  120. KORNIA_CHECK(type(x) is type(y))
  121. KORNIA_CHECK(isinstance(x, (Tensor, float)))
  122. if isinstance(x, float):
  123. return wrap(as_tensor((x, y), device=device, dtype=dtype), Vector2)
  124. # TODO: this is totally insane ...
  125. tensors: Tuple[Tensor, ...] = (x, cast(Tensor, y))
  126. return wrap(stack(tensors, -1), Vector2)
  127. Vec3 = Vector3
  128. Vec2 = Vector2
  129. # TODO: adapt to TensorWrapper
  130. # class UnitVector(Module):
  131. # def __init__(self, vector: Tensor) -> None:
  132. # super().__init__()
  133. # KORNIA_CHECK_SHAPE(vector, ["B", "N"])
  134. # self._vector = Parameter(vector)
  135. #
  136. # @property
  137. # def vector(self) -> Tensor:
  138. # return self._vector
  139. #
  140. # @classmethod
  141. # def from_unit_vector(cls, v: Tensor) -> "UnitVector":
  142. # # TODO: add checks https://github.com/strasdat/Sophus/blob/23.04-beta/cpp/sophus/geometry/ray.h#L59
  143. # return UnitVector(_VectorType(v))
  144. #
  145. # @classmethod
  146. # def from_vector(cls, v: Tensor) -> "UnitVector":
  147. # """From a vector and normalize."""
  148. # return UnitVector(_VectorType(v).normalized())
  149. #