_utils.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  17. from kornia.core import Tensor
# TODO: Temporary shape check functions until KORNIA_CHECK_SHAPE is ready
  19. def check_so2_z_shape(z: Tensor) -> None:
  20. z_shape = z.shape
  21. len_z_shape = len(z_shape)
  22. if (len_z_shape == 2 and z_shape[1] != 1) or (len_z_shape == 0 and not z.numel()) or (len_z_shape > 2):
  23. raise ValueError(f"Invalid input size, we expect [B]. Got: {z.shape}")
  24. def check_so2_t_shape(t: Tensor) -> None:
  25. t_shape = t.shape
  26. len_t_shape = len(t_shape)
  27. if ((len_t_shape == 2) and (t_shape[1] != 2)) or ((len_t_shape == 1) and (t_shape[0] != 2)) or (len_t_shape > 2):
  28. raise ValueError(f"Invalid translation shape, we expect [B, 2], or [2] Got: {t_shape}")
  29. def check_so2_theta_shape(theta: Tensor) -> None:
  30. theta_shape = theta.shape
  31. len_theta_shape = len(theta_shape)
  32. if (
  33. (len_theta_shape == 2 and theta_shape[1] != 1)
  34. or (len_theta_shape == 0 and not theta.numel())
  35. or (len_theta_shape > 2)
  36. ):
  37. raise ValueError(f"Invalid input size, we expect [B]. Got: {theta_shape}")
  38. def check_so2_matrix_shape(matrix: Tensor) -> None:
  39. matrix_shape = matrix.shape
  40. len_matrix_shape = len(matrix_shape)
  41. if (
  42. (len_matrix_shape == 3 and (matrix_shape[1] != 2 or matrix_shape[2] != 2))
  43. or (len_matrix_shape == 2 and (matrix_shape[0] != 2 or matrix_shape[1] != 2))
  44. or (len_matrix_shape > 3 or len_matrix_shape < 2)
  45. ):
  46. raise ValueError(f"Invalid input size, we expect [B, 2, 2] or [2, 2]. Got: {matrix_shape}")
  47. def check_so2_matrix(matrix: Tensor) -> None:
  48. for m in matrix.reshape(-1, 2, 2):
  49. if m[0, 0] != m[1, 1] or m[0, 1] != -m[1, 0]:
  50. raise ValueError("Invalid rotation matrix")
  51. def check_se2_t_shape(t: Tensor) -> None:
  52. check_so2_t_shape(t)
  53. def check_v_shape(v: Tensor) -> None:
  54. v_shape = v.shape
  55. len_v_shape = len(v_shape)
  56. if ((len_v_shape == 2) and (v_shape[1] != 3)) or ((len_v_shape == 1) and (v_shape[0] != 3)) or (len_v_shape > 3):
  57. raise ValueError(f"Invalid input shape, we expect [B, 3], [3] Got: {v_shape}")
  58. def check_se2_omega_shape(matrix: Tensor) -> None:
  59. matrix_shape = matrix.shape
  60. len_matrix_shape = len(matrix_shape)
  61. if (
  62. (len_matrix_shape == 3 and (matrix_shape[1] != 3 or matrix_shape[2] != 3))
  63. or (len_matrix_shape == 2 and (matrix_shape[0] != 3 or matrix_shape[1] != 3))
  64. or (len_matrix_shape > 3 or len_matrix_shape < 2)
  65. ):
  66. raise ValueError(f"Invalid input size, we expect [B, 3, 3] or [3, 3]. Got: {matrix_shape}")