pydantic.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. """Module containing Pydantic validation utilities for Albumentations.
  2. This module provides a collection of validators and utility functions used for validating
  3. parameters in the Pydantic models throughout the Albumentations library. It includes
  4. functions for ensuring numeric ranges are valid, handling type conversions, and creating
  5. standardized validation patterns that are reused across the codebase.
  6. """
  7. from __future__ import annotations
  8. from collections.abc import Callable
  9. from typing import Annotated, TypeVar, Union, overload
  10. from pydantic.functional_validators import AfterValidator
  11. from albumentations.core.type_definitions import Number
  12. from albumentations.core.utils import to_tuple
  13. def nondecreasing(value: tuple[Number, Number]) -> tuple[Number, Number]:
  14. """Ensure a tuple of two numbers is in non-decreasing order.
  15. Args:
  16. value (tuple[Number, Number]): Tuple of two numeric values to validate.
  17. Returns:
  18. tuple[Number, Number]: The original tuple if valid.
  19. Raises:
  20. ValueError: If the first value is greater than the second value.
  21. """
  22. if not value[0] <= value[1]:
  23. raise ValueError(f"First value should be less than the second value, got {value} instead")
  24. return value
  25. def process_non_negative_range(value: tuple[float, float] | float | None) -> tuple[float, float]:
  26. """Process and validate a non-negative range.
  27. Args:
  28. value (tuple[float, float] | float | None): Value to process. Can be:
  29. - A tuple of two floats
  30. - A single float (converted to symmetric range)
  31. - None (defaults to 0)
  32. Returns:
  33. tuple[float, float]: Validated non-negative range.
  34. Raises:
  35. ValueError: If any values in the range are negative.
  36. """
  37. result = to_tuple(value if value is not None else 0, 0)
  38. if not all(x >= 0 for x in result):
  39. msg = "All values in the non negative range should be non negative"
  40. raise ValueError(msg)
  41. return result
  42. def float2int(value: tuple[float, float]) -> tuple[int, int]:
  43. """Convert a tuple of floats to a tuple of integers.
  44. Args:
  45. value (tuple[float, float]): Tuple of two float values.
  46. Returns:
  47. tuple[int, int]: Tuple of two integer values.
  48. """
  49. return int(value[0]), int(value[1])
# Non-negative float range. Once pydantic has accepted a float or a pair of floats,
# the after-validators run in the order listed (order matters):
#   1. process_non_negative_range: expands a scalar into a pair via to_tuple (with 0
#      as the other bound) and raises ValueError if any element fails `>= 0`.
#   2. nondecreasing: raises ValueError if the first element exceeds the second.
# NOTE(review): `Union[...]` rather than `X | Y` is presumably deliberate, because this
# alias is evaluated at runtime (it is not an annotation) and `|` on builtin generics
# needs Python 3.10+.
NonNegativeFloatRangeType = Annotated[
    Union[tuple[float, float], float],
    AfterValidator(process_non_negative_range),
    AfterValidator(nondecreasing),
]

# Integer counterpart of NonNegativeFloatRangeType: the same validation chain, followed
# by float2int, which casts both ends with int() so the validated value is a pair of ints.
NonNegativeIntRangeType = Annotated[
    Union[tuple[int, int], int],
    AfterValidator(process_non_negative_range),
    AfterValidator(nondecreasing),
    AfterValidator(float2int),
]
  61. @overload
  62. def create_symmetric_range(value: tuple[int, int] | int) -> tuple[int, int]: ...
  63. @overload
  64. def create_symmetric_range(value: tuple[float, float] | float) -> tuple[float, float]: ...
  65. def create_symmetric_range(value: tuple[float, float] | float) -> tuple[float, float]:
  66. """Create a symmetric range around zero or use provided range.
  67. Args:
  68. value (tuple[float, float] | float): Input value, either:
  69. - A tuple of two floats (used directly)
  70. - A single float (converted to (-value, value))
  71. Returns:
  72. tuple[float, float]: Symmetric range.
  73. """
  74. return to_tuple(value)
# Range given either as a scalar or as a pair of floats, normalized by
# create_symmetric_range (per its docstring, a scalar v becomes (-v, v) and a pair is used
# directly). No bounds or ordering check is attached here.
SymmetricRangeType = Annotated[Union[tuple[float, float], float], AfterValidator(create_symmetric_range)]
  76. def convert_to_1plus_range(value: tuple[float, float] | float) -> tuple[float, float]:
  77. """Convert value to a range with lower bound of 1.
  78. Args:
  79. value (tuple[float, float] | float): Input value.
  80. Returns:
  81. tuple[float, float]: Range with minimum value of at least 1.
  82. """
  83. return to_tuple(value, low=1)
  84. def convert_to_0plus_range(value: tuple[float, float] | float) -> tuple[float, float]:
  85. """Convert value to a range with lower bound of 0.
  86. Args:
  87. value (tuple[float, float] | float): Input value.
  88. Returns:
  89. tuple[float, float]: Range with minimum value of at least 0.
  90. """
  91. return to_tuple(value, low=0)
  92. def repeat_if_scalar(value: tuple[float, float] | float) -> tuple[float, float]:
  93. """Convert a scalar value to a tuple by repeating it, or return the tuple as is.
  94. Args:
  95. value (tuple[float, float] | float): Input value, either a scalar or tuple.
  96. Returns:
  97. tuple[float, float]: If input is scalar, returns (value, value), otherwise returns input unchanged.
  98. """
  99. return (value, value) if isinstance(value, (int, float)) else value
  100. T = TypeVar("T", int, float)
  101. def check_range_bounds(
  102. min_val: Number,
  103. max_val: Number | None = None,
  104. min_inclusive: bool = True,
  105. max_inclusive: bool = True,
  106. ) -> Callable[[tuple[T, ...] | None], tuple[T, ...] | None]:
  107. """Validates that all values in a tuple are within specified bounds.
  108. Args:
  109. min_val (int | float):
  110. Minimum allowed value.
  111. max_val (int | float | None):
  112. Maximum allowed value. If None, only lower bound is checked.
  113. min_inclusive (bool):
  114. If True, min_val is inclusive (>=). If False, exclusive (>).
  115. max_inclusive (bool):
  116. If True, max_val is inclusive (<=). If False, exclusive (<).
  117. Returns:
  118. Callable[[tuple[T, ...] | None], tuple[T, ...] | None]: Validator function that
  119. checks if all values in tuple are within bounds. Returns None if input is None.
  120. Raises:
  121. ValueError: If any value in tuple is outside the allowed range
  122. Examples:
  123. >>> validator = check_range_bounds(0, 1) # For [0, 1] range
  124. >>> validator((0.1, 0.5)) # Valid 2D
  125. (0.1, 0.5)
  126. >>> validator((0.1, 0.5, 0.7)) # Valid 3D
  127. (0.1, 0.5, 0.7)
  128. >>> validator((1.1, 0.5)) # Raises ValueError - outside range
  129. >>> validator = check_range_bounds(0, 1, max_inclusive=False) # For [0, 1) range
  130. >>> validator((0, 1)) # Raises ValueError - 1 not included
  131. """
  132. def validator(value: tuple[T, ...] | None) -> tuple[T, ...] | None:
  133. if value is None:
  134. return None
  135. min_op = (lambda x, y: x >= y) if min_inclusive else (lambda x, y: x > y)
  136. max_op = (lambda x, y: x <= y) if max_inclusive else (lambda x, y: x < y)
  137. if max_val is None:
  138. if not all(min_op(x, min_val) for x in value):
  139. op_symbol = ">=" if min_inclusive else ">"
  140. raise ValueError(f"All values in {value} must be {op_symbol} {min_val}")
  141. else:
  142. min_symbol = ">=" if min_inclusive else ">"
  143. max_symbol = "<=" if max_inclusive else "<"
  144. if not all(min_op(x, min_val) and max_op(x, max_val) for x in value):
  145. raise ValueError(f"All values in {value} must be {min_symbol} {min_val} and {max_symbol} {max_val}")
  146. return value
  147. return validator
# Range within [0, 1] (both ends inclusive). After-validators run in the order listed:
# convert_to_0plus_range expands a scalar into a pair via to_tuple(value, low=0),
# check_range_bounds(0, 1) rejects any element outside [0, 1], and nondecreasing
# rejects pairs whose first element exceeds the second.
ZeroOneRangeType = Annotated[
    Union[tuple[float, float], float],
    AfterValidator(convert_to_0plus_range),
    AfterValidator(check_range_bounds(0, 1)),
    AfterValidator(nondecreasing),
]

# Float range whose elements must all be >= 1 (inclusive; no upper bound).
# NOTE(review): unlike ZeroOneRangeType, no nondecreasing check is attached, so a pair
# given in descending order is accepted as-is. Confirm that this is intentional.
OnePlusFloatRangeType = Annotated[
    Union[tuple[float, float], float],
    AfterValidator(convert_to_1plus_range),
    AfterValidator(check_range_bounds(1, None)),
]

# Same chain as OnePlusFloatRangeType, then float2int casts both ends with int(), which
# truncates toward zero (e.g. 2.7 -> 2). Fractional inputs are accepted and truncated.
OnePlusIntRangeType = Annotated[
    Union[tuple[float, float], float],
    AfterValidator(convert_to_1plus_range),
    AfterValidator(check_range_bounds(1, None)),
    AfterValidator(float2int),
]

# Explicit integer pair only (no scalar form): every element must be >= 1, and the pair
# must be nondecreasing. NOTE(review): float2int looks redundant for a tuple[int, int]
# input; presumably it is a defensive cast. Verify before removing it.
OnePlusIntNonDecreasingRangeType = Annotated[
    tuple[int, int],
    AfterValidator(check_range_bounds(1, None)),
    AfterValidator(nondecreasing),
    AfterValidator(float2int),
]