param_validation.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, List, Optional, Tuple, Union
  18. import torch
  19. from kornia.core import Tensor, as_tensor, tensor
  20. def _common_param_check(batch_size: int, same_on_batch: Optional[bool] = None) -> None:
  21. """Check valid batch_size and same_on_batch params."""
  22. if not (isinstance(batch_size, int) and batch_size >= 0):
  23. raise AssertionError(f"`batch_size` shall be a positive integer. Got {batch_size}.")
  24. if same_on_batch is not None and not isinstance(same_on_batch, bool):
  25. raise AssertionError(f"`same_on_batch` shall be boolean. Got {same_on_batch}.")
  26. def _range_bound(
  27. factor: Union[Tensor, float, Tuple[float, float], List[float]],
  28. name: str,
  29. center: Optional[float] = 0.0,
  30. bounds: Optional[Tuple[float, float]] = (0, float("inf")),
  31. check: Optional[str] = "joint",
  32. device: Optional[torch.device] = None,
  33. dtype: Optional[torch.dtype] = None,
  34. ) -> Tensor:
  35. r"""Check inputs and compute the corresponding factor bounds."""
  36. if device is None:
  37. device = torch.device("cpu")
  38. if dtype is None:
  39. dtype = torch.get_default_dtype()
  40. if not isinstance(factor, (Tensor)):
  41. factor = tensor(factor, device=device, dtype=dtype)
  42. factor_bound: Tensor
  43. if factor.dim() == 0:
  44. if factor < 0:
  45. raise ValueError(f"If {name} is a single number, it must be non negative. Got {factor}.")
  46. if center is None or bounds is None:
  47. raise ValueError(f"`center` and `bounds` cannot be None for single number. Got {center}, {bounds}.")
  48. # Should be something other than clamp
  49. # Currently, single value factor will not out of scope as long as the user provided it.
  50. # Note: I personally think throw an error will be better than a coarse clamp.
  51. factor_bound = factor.repeat(2) * tensor([-1.0, 1.0], device=factor.device, dtype=factor.dtype) + center
  52. factor_bound = factor_bound.clamp(bounds[0], bounds[1]).to(device=device, dtype=dtype)
  53. else:
  54. factor_bound = as_tensor(factor, device=device, dtype=dtype)
  55. if check is not None:
  56. if check == "joint":
  57. _joint_range_check(factor_bound, name, bounds)
  58. elif check == "singular":
  59. _singular_range_check(factor_bound, name, bounds)
  60. else:
  61. raise NotImplementedError(f"methods '{check}' not implemented.")
  62. return factor_bound
  63. def _joint_range_check(ranged_factor: Tensor, name: str, bounds: Optional[Tuple[float, float]] = None) -> None:
  64. """Check if bounds[0] <= ranged_factor[0] <= ranged_factor[1] <= bounds[1]."""
  65. if bounds is None:
  66. bounds = (float("-inf"), float("inf"))
  67. if ranged_factor.dim() == 1 and len(ranged_factor) == 2:
  68. if not bounds[0] <= ranged_factor[0] or not bounds[1] >= ranged_factor[1]:
  69. raise ValueError(f"{name} out of bounds. Expected inside {bounds}, got {ranged_factor}.")
  70. if not bounds[0] <= ranged_factor[0] <= ranged_factor[1] <= bounds[1]:
  71. raise ValueError(f"{name}[0] should be smaller than {name}[1] got {ranged_factor}")
  72. else:
  73. raise TypeError(f"{name} should be a tensor with length 2 whose values between {bounds}. Got {ranged_factor}.")
  74. def _singular_range_check(
  75. ranged_factor: Tensor,
  76. name: str,
  77. bounds: Optional[Tuple[float, float]] = None,
  78. skip_none: bool = False,
  79. mode: str = "2d",
  80. ) -> None:
  81. """Check if bounds[0] <= ranged_factor[0] <= bounds[1] and bounds[0] <= ranged_factor[1] <= bounds[1]."""
  82. if mode == "2d":
  83. dim_size = 2
  84. elif mode == "3d":
  85. dim_size = 3
  86. else:
  87. raise ValueError(f"'mode' shall be either 2d or 3d. Got {mode}")
  88. if skip_none and ranged_factor is None:
  89. return
  90. if bounds is None:
  91. bounds = (float("-inf"), float("inf"))
  92. if ranged_factor.dim() == 1 and len(ranged_factor) == dim_size:
  93. for f in ranged_factor:
  94. if not bounds[0] <= f <= bounds[1]:
  95. raise ValueError(f"{name} out of bounds. Expected inside {bounds}, got {ranged_factor}.")
  96. else:
  97. raise TypeError(
  98. f"{name} should be a float number or a tuple with length {dim_size} whose values between {bounds}."
  99. f"Got {ranged_factor}"
  100. )
  101. def _tuple_range_reader(
  102. input_range: Union[Tensor, float, Tuple[Any, ...]],
  103. target_size: int,
  104. device: Optional[torch.device] = None,
  105. dtype: Optional[torch.dtype] = None,
  106. ) -> Tensor:
  107. """Given target_size, it will generate the corresponding (target_size, 2) range tensor for element-wise params.
  108. Example:
  109. >>> degree = tensor([0.2, 0.3])
  110. >>> _tuple_range_reader(degree, 3) # read degree for yaw, pitch and roll.
  111. tensor([[0.2000, 0.3000],
  112. [0.2000, 0.3000],
  113. [0.2000, 0.3000]])
  114. """
  115. target_shape = torch.Size([target_size, 2])
  116. if isinstance(input_range, Tensor):
  117. if (len(input_range.shape) == 0) or (len(input_range.shape) == 1 and len(input_range) == 1):
  118. if input_range < 0:
  119. raise ValueError(f"If input_range is only one number it must be a positive number. Got{input_range}")
  120. input_range_tmp = input_range.repeat(2).to(device=device, dtype=dtype) * tensor(
  121. [-1, 1], device=device, dtype=dtype
  122. )
  123. input_range_tmp = input_range_tmp.repeat(target_shape[0], 1)
  124. elif len(input_range.shape) == 1 and len(input_range) == 2:
  125. input_range_tmp = input_range.repeat(target_shape[0], 1).to(device=device, dtype=dtype)
  126. elif len(input_range.shape) == 1 and len(input_range) == target_shape[0]:
  127. input_range_tmp = input_range.unsqueeze(1).repeat(1, 2).to(device=device, dtype=dtype) * tensor(
  128. [-1, 1], device=device, dtype=dtype
  129. )
  130. elif input_range.shape == target_shape:
  131. input_range_tmp = input_range.to(device=device, dtype=dtype)
  132. else:
  133. raise ValueError(
  134. f"Degrees must be a {list(target_shape)} tensor for the degree range for independent operation."
  135. f"Got {input_range}"
  136. )
  137. elif isinstance(input_range, (float, int)):
  138. if input_range < 0:
  139. raise ValueError(f"If input_range is only one number it must be a positive number. Got{input_range}")
  140. input_range_tmp = tensor([-input_range, input_range], device=device, dtype=dtype).repeat(target_shape[0], 1)
  141. elif (
  142. isinstance(input_range, (tuple, list))
  143. and len(input_range) == 2
  144. and isinstance(input_range[0], (float, int))
  145. and isinstance(input_range[1], (float, int))
  146. ):
  147. input_range_tmp = tensor(input_range, device=device, dtype=dtype).repeat(target_shape[0], 1)
  148. elif (
  149. isinstance(input_range, (tuple, list))
  150. and len(input_range) == target_shape[0]
  151. and all(isinstance(x, (float, int)) for x in input_range)
  152. ):
  153. input_range_tmp = tensor([(-s, s) for s in input_range], device=device, dtype=dtype)
  154. elif (
  155. isinstance(input_range, (tuple, list))
  156. and len(input_range) == target_shape[0]
  157. and all(isinstance(x, (tuple, list)) for x in input_range)
  158. ):
  159. input_range_tmp = tensor(input_range, device=device, dtype=dtype)
  160. else:
  161. raise TypeError(
  162. "If not pass a tensor, it must be float, (float, float) for isotropic operation or a tuple of "
  163. f"{target_size} floats or {target_size} (float, float) for independent operation. Got {input_range}."
  164. )
  165. return input_range_tmp