| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from kornia.core import Tensor
- # TODO: Temporary shape check functions until KORNIA_CHECK_SHAPE is ready
def check_so2_z_shape(z: Tensor) -> None:
    """Validate the shape of the SO(2) ``z`` tensor.

    Accepted shapes are a 0-d tensor, a 1-D tensor ``[B]`` and a 2-D tensor
    ``[B, 1]``.

    Args:
        z: tensor whose shape is checked.

    Raises:
        ValueError: if ``z`` has more than two dimensions, or is 2-D with a
            trailing dimension other than 1.
    """
    # NOTE: the previous ``len(shape) == 0 and not z.numel()`` clause was dead
    # code -- a 0-d tensor always holds exactly one element -- so 0-d inputs
    # were (and still are) accepted.
    ndim = z.ndim
    if ndim > 2 or (ndim == 2 and z.shape[1] != 1):
        raise ValueError(f"Invalid input size, we expect [B]. Got: {z.shape}")
def check_so2_t_shape(t: Tensor) -> None:
    """Validate that ``t`` is a 2-vector or a batch of 2-vectors.

    Accepted shapes are ``[2]`` and ``[B, 2]``.

    Args:
        t: tensor whose shape is checked.

    Raises:
        ValueError: if ``t`` is not 1-D or 2-D, or its last dimension is not 2.
    """
    # ``ndim not in (1, 2)`` also rejects 0-d tensors, which the previous
    # clause-by-clause check let through even though they match neither
    # accepted shape.
    if t.ndim not in (1, 2) or t.shape[-1] != 2:
        raise ValueError(f"Invalid translation shape, we expect [B, 2], or [2] Got: {t.shape}")
def check_so2_theta_shape(theta: Tensor) -> None:
    """Validate the shape of an SO(2) ``theta`` tensor.

    Accepted shapes are a 0-d tensor, a 1-D tensor ``[B]`` and a 2-D tensor
    ``[B, 1]``.

    Args:
        theta: tensor whose shape is checked.

    Raises:
        ValueError: if ``theta`` has more than two dimensions, or is 2-D with a
            trailing dimension other than 1.
    """
    # NOTE: the previous ``len(shape) == 0 and not theta.numel()`` clause could
    # never be true (a 0-d tensor always has one element) and was removed;
    # 0-d inputs remain accepted.
    theta_shape = theta.shape
    ndim = theta.ndim
    if ndim > 2 or (ndim == 2 and theta_shape[1] != 1):
        raise ValueError(f"Invalid input size, we expect [B]. Got: {theta_shape}")
def check_so2_matrix_shape(matrix: Tensor) -> None:
    """Ensure ``matrix`` is a single 2x2 matrix or a batch of them.

    Args:
        matrix: tensor expected to have shape ``[2, 2]`` or ``[B, 2, 2]``.

    Raises:
        ValueError: for any other rank or trailing dimensions.
    """
    dims = matrix.shape
    rank_ok = len(dims) in (2, 3)
    # Only inspect the trailing two sizes once the rank is known to be valid.
    if not rank_ok or tuple(dims[-2:]) != (2, 2):
        raise ValueError(f"Invalid input size, we expect [B, 2, 2] or [2, 2]. Got: {dims}")
def check_so2_matrix(matrix: Tensor) -> None:
    """Check that every 2x2 block of ``matrix`` has rotation-matrix structure.

    Each block ``m`` must satisfy ``m[0, 0] == m[1, 1]`` and
    ``m[0, 1] == -m[1, 0]`` (exact equality, as before). Orthonormality /
    unit determinant is not checked here.

    Args:
        matrix: tensor whose elements are reshaped into ``[-1, 2, 2]``.

    Raises:
        ValueError: if any block violates the structure above.
    """
    blocks = matrix.reshape(-1, 2, 2)
    # Compare all blocks at once: the former per-block Python loop converted
    # each comparison to a Python bool, forcing one device sync per matrix.
    diag_mismatch = blocks[:, 0, 0] != blocks[:, 1, 1]
    offdiag_mismatch = blocks[:, 0, 1] != -blocks[:, 1, 0]
    if (diag_mismatch | offdiag_mismatch).any():
        raise ValueError("Invalid rotation matrix")
def check_se2_t_shape(t: Tensor) -> None:
    """Validate the shape of an SE(2) translation tensor.

    The SE(2) translation shape contract is the same as the SO(2) one, so
    validation is delegated to ``check_so2_t_shape``.

    Args:
        t: translation tensor to validate.

    Raises:
        ValueError: propagated from ``check_so2_t_shape`` on a bad shape.
    """
    return check_so2_t_shape(t)
def check_v_shape(v: Tensor) -> None:
    """Validate that ``v`` is a 3-vector or a batch of 3-vectors.

    Accepted shapes are ``[3]`` and ``[B, 3]``.

    Args:
        v: tensor whose shape is checked.

    Raises:
        ValueError: if ``v`` is not 1-D or 2-D, or its last dimension is not 3.
    """
    # The previous rank guard was ``len(shape) > 3``, which let any 3-D tensor
    # (e.g. [B, 4, 7]) pass unchecked, and it also accepted 0-d tensors --
    # neither matches the documented [B, 3] / [3] contract.
    if v.ndim not in (1, 2) or v.shape[-1] != 3:
        raise ValueError(f"Invalid input shape, we expect [B, 3], [3] Got: {v.shape}")
def check_se2_omega_shape(matrix: Tensor) -> None:
    """Ensure ``matrix`` is a single 3x3 matrix or a batch of them.

    Args:
        matrix: tensor expected to have shape ``[3, 3]`` or ``[B, 3, 3]``.

    Raises:
        ValueError: for any other rank or trailing dimensions.
    """
    dims = matrix.shape
    rank_ok = len(dims) in (2, 3)
    # Trailing sizes are only meaningful once the rank has been validated.
    if not rank_ok or tuple(dims[-2:]) != (3, 3):
        raise ValueError(f"Invalid input size, we expect [B, 3, 3] or [3, 3]. Got: {dims}")
|