# _validators.py
  1. """Validator functions for standard library types.
  2. Import of this module is deferred since it contains imports of many standard library modules.
  3. """
  4. from __future__ import annotations as _annotations
  5. import collections.abc
  6. import math
  7. import re
  8. import typing
  9. from collections.abc import Sequence
  10. from decimal import Decimal
  11. from fractions import Fraction
  12. from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
  13. from typing import Any, Callable, TypeVar, Union, cast
  14. from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
  15. import typing_extensions
  16. from pydantic_core import PydanticCustomError, PydanticKnownError, core_schema
  17. from typing_extensions import get_args, get_origin
  18. from typing_inspection import typing_objects
  19. from pydantic._internal._import_utils import import_cached_field_info
  20. from pydantic.errors import PydanticSchemaGenerationError
  21. def sequence_validator(
  22. input_value: Sequence[Any],
  23. /,
  24. validator: core_schema.ValidatorFunctionWrapHandler,
  25. ) -> Sequence[Any]:
  26. """Validator for `Sequence` types, isinstance(v, Sequence) has already been called."""
  27. value_type = type(input_value)
  28. # We don't accept any plain string as a sequence
  29. # Relevant issue: https://github.com/pydantic/pydantic/issues/5595
  30. if issubclass(value_type, (str, bytes)):
  31. raise PydanticCustomError(
  32. 'sequence_str',
  33. "'{type_name}' instances are not allowed as a Sequence value",
  34. {'type_name': value_type.__name__},
  35. )
  36. # TODO: refactor sequence validation to validate with either a list or a tuple
  37. # schema, depending on the type of the value.
  38. # Additionally, we should be able to remove one of either this validator or the
  39. # SequenceValidator in _std_types_schema.py (preferably this one, while porting over some logic).
  40. # Effectively, a refactor for sequence validation is needed.
  41. if value_type is tuple:
  42. input_value = list(input_value)
  43. v_list = validator(input_value)
  44. # the rest of the logic is just re-creating the original type from `v_list`
  45. if value_type is list:
  46. return v_list
  47. elif issubclass(value_type, range):
  48. # return the list as we probably can't re-create the range
  49. return v_list
  50. elif value_type is tuple:
  51. return tuple(v_list)
  52. else:
  53. # best guess at how to re-create the original type, more custom construction logic might be required
  54. return value_type(v_list) # type: ignore[call-arg]
  55. def import_string(value: Any) -> Any:
  56. if isinstance(value, str):
  57. try:
  58. return _import_string_logic(value)
  59. except ImportError as e:
  60. raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e
  61. else:
  62. # otherwise we just return the value and let the next validator do the rest of the work
  63. return value
  64. def _import_string_logic(dotted_path: str) -> Any:
  65. """Inspired by uvicorn — dotted paths should include a colon before the final item if that item is not a module.
  66. (This is necessary to distinguish between a submodule and an attribute when there is a conflict.).
  67. If the dotted path does not include a colon and the final item is not a valid module, importing as an attribute
  68. rather than a submodule will be attempted automatically.
  69. So, for example, the following values of `dotted_path` result in the following returned values:
  70. * 'collections': <module 'collections'>
  71. * 'collections.abc': <module 'collections.abc'>
  72. * 'collections.abc:Mapping': <class 'collections.abc.Mapping'>
  73. * `collections.abc.Mapping`: <class 'collections.abc.Mapping'> (though this is a bit slower than the previous line)
  74. An error will be raised under any of the following scenarios:
  75. * `dotted_path` contains more than one colon (e.g., 'collections:abc:Mapping')
  76. * the substring of `dotted_path` before the colon is not a valid module in the environment (e.g., '123:Mapping')
  77. * the substring of `dotted_path` after the colon is not an attribute of the module (e.g., 'collections:abc123')
  78. """
  79. from importlib import import_module
  80. components = dotted_path.strip().split(':')
  81. if len(components) > 2:
  82. raise ImportError(f"Import strings should have at most one ':'; received {dotted_path!r}")
  83. module_path = components[0]
  84. if not module_path:
  85. raise ImportError(f'Import strings should have a nonempty module name; received {dotted_path!r}')
  86. try:
  87. module = import_module(module_path)
  88. except ModuleNotFoundError as e:
  89. if '.' in module_path:
  90. # Check if it would be valid if the final item was separated from its module with a `:`
  91. maybe_module_path, maybe_attribute = dotted_path.strip().rsplit('.', 1)
  92. try:
  93. return _import_string_logic(f'{maybe_module_path}:{maybe_attribute}')
  94. except ImportError:
  95. pass
  96. raise ImportError(f'No module named {module_path!r}') from e
  97. raise e
  98. if len(components) > 1:
  99. attribute = components[1]
  100. try:
  101. return getattr(module, attribute)
  102. except AttributeError as e:
  103. raise ImportError(f'cannot import name {attribute!r} from {module_path!r}') from e
  104. else:
  105. return module
  106. def pattern_either_validator(input_value: Any, /) -> re.Pattern[Any]:
  107. if isinstance(input_value, re.Pattern):
  108. return input_value
  109. elif isinstance(input_value, (str, bytes)):
  110. # todo strict mode
  111. return compile_pattern(input_value) # type: ignore
  112. else:
  113. raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
  114. def pattern_str_validator(input_value: Any, /) -> re.Pattern[str]:
  115. if isinstance(input_value, re.Pattern):
  116. if isinstance(input_value.pattern, str):
  117. return input_value
  118. else:
  119. raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
  120. elif isinstance(input_value, str):
  121. return compile_pattern(input_value)
  122. elif isinstance(input_value, bytes):
  123. raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
  124. else:
  125. raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
  126. def pattern_bytes_validator(input_value: Any, /) -> re.Pattern[bytes]:
  127. if isinstance(input_value, re.Pattern):
  128. if isinstance(input_value.pattern, bytes):
  129. return input_value
  130. else:
  131. raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
  132. elif isinstance(input_value, bytes):
  133. return compile_pattern(input_value)
  134. elif isinstance(input_value, str):
  135. raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
  136. else:
  137. raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
  138. PatternType = TypeVar('PatternType', str, bytes)
  139. def compile_pattern(pattern: PatternType) -> re.Pattern[PatternType]:
  140. try:
  141. return re.compile(pattern)
  142. except re.error:
  143. raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression')
  144. def ip_v4_address_validator(input_value: Any, /) -> IPv4Address:
  145. if isinstance(input_value, IPv4Address):
  146. return input_value
  147. try:
  148. return IPv4Address(input_value)
  149. except ValueError:
  150. raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address')
  151. def ip_v6_address_validator(input_value: Any, /) -> IPv6Address:
  152. if isinstance(input_value, IPv6Address):
  153. return input_value
  154. try:
  155. return IPv6Address(input_value)
  156. except ValueError:
  157. raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address')
  158. def ip_v4_network_validator(input_value: Any, /) -> IPv4Network:
  159. """Assume IPv4Network initialised with a default `strict` argument.
  160. See more:
  161. https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
  162. """
  163. if isinstance(input_value, IPv4Network):
  164. return input_value
  165. try:
  166. return IPv4Network(input_value)
  167. except ValueError:
  168. raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network')
  169. def ip_v6_network_validator(input_value: Any, /) -> IPv6Network:
  170. """Assume IPv6Network initialised with a default `strict` argument.
  171. See more:
  172. https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
  173. """
  174. if isinstance(input_value, IPv6Network):
  175. return input_value
  176. try:
  177. return IPv6Network(input_value)
  178. except ValueError:
  179. raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network')
  180. def ip_v4_interface_validator(input_value: Any, /) -> IPv4Interface:
  181. if isinstance(input_value, IPv4Interface):
  182. return input_value
  183. try:
  184. return IPv4Interface(input_value)
  185. except ValueError:
  186. raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface')
  187. def ip_v6_interface_validator(input_value: Any, /) -> IPv6Interface:
  188. if isinstance(input_value, IPv6Interface):
  189. return input_value
  190. try:
  191. return IPv6Interface(input_value)
  192. except ValueError:
  193. raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface')
  194. def fraction_validator(input_value: Any, /) -> Fraction:
  195. if isinstance(input_value, Fraction):
  196. return input_value
  197. try:
  198. return Fraction(input_value)
  199. except ValueError:
  200. raise PydanticCustomError('fraction_parsing', 'Input is not a valid fraction')
  201. def forbid_inf_nan_check(x: Any) -> Any:
  202. if not math.isfinite(x):
  203. raise PydanticKnownError('finite_number')
  204. return x
  205. def _safe_repr(v: Any) -> int | float | str:
  206. """The context argument for `PydanticKnownError` requires a number or str type, so we do a simple repr() coercion for types like timedelta.
  207. See tests/test_types.py::test_annotated_metadata_any_order for some context.
  208. """
  209. if isinstance(v, (int, float, str)):
  210. return v
  211. return repr(v)
  212. def greater_than_validator(x: Any, gt: Any) -> Any:
  213. try:
  214. if not (x > gt):
  215. raise PydanticKnownError('greater_than', {'gt': _safe_repr(gt)})
  216. return x
  217. except TypeError:
  218. raise TypeError(f"Unable to apply constraint 'gt' to supplied value {x}")
  219. def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
  220. try:
  221. if not (x >= ge):
  222. raise PydanticKnownError('greater_than_equal', {'ge': _safe_repr(ge)})
  223. return x
  224. except TypeError:
  225. raise TypeError(f"Unable to apply constraint 'ge' to supplied value {x}")
  226. def less_than_validator(x: Any, lt: Any) -> Any:
  227. try:
  228. if not (x < lt):
  229. raise PydanticKnownError('less_than', {'lt': _safe_repr(lt)})
  230. return x
  231. except TypeError:
  232. raise TypeError(f"Unable to apply constraint 'lt' to supplied value {x}")
  233. def less_than_or_equal_validator(x: Any, le: Any) -> Any:
  234. try:
  235. if not (x <= le):
  236. raise PydanticKnownError('less_than_equal', {'le': _safe_repr(le)})
  237. return x
  238. except TypeError:
  239. raise TypeError(f"Unable to apply constraint 'le' to supplied value {x}")
  240. def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
  241. try:
  242. if x % multiple_of:
  243. raise PydanticKnownError('multiple_of', {'multiple_of': _safe_repr(multiple_of)})
  244. return x
  245. except TypeError:
  246. raise TypeError(f"Unable to apply constraint 'multiple_of' to supplied value {x}")
  247. def min_length_validator(x: Any, min_length: Any) -> Any:
  248. try:
  249. if not (len(x) >= min_length):
  250. raise PydanticKnownError(
  251. 'too_short', {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)}
  252. )
  253. return x
  254. except TypeError:
  255. raise TypeError(f"Unable to apply constraint 'min_length' to supplied value {x}")
  256. def max_length_validator(x: Any, max_length: Any) -> Any:
  257. try:
  258. if len(x) > max_length:
  259. raise PydanticKnownError(
  260. 'too_long',
  261. {'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
  262. )
  263. return x
  264. except TypeError:
  265. raise TypeError(f"Unable to apply constraint 'max_length' to supplied value {x}")
  266. def _extract_decimal_digits_info(decimal: Decimal) -> tuple[int, int]:
  267. """Compute the total number of digits and decimal places for a given [`Decimal`][decimal.Decimal] instance.
  268. This function handles both normalized and non-normalized Decimal instances.
  269. Example: Decimal('1.230') -> 4 digits, 3 decimal places
  270. Args:
  271. decimal (Decimal): The decimal number to analyze.
  272. Returns:
  273. tuple[int, int]: A tuple containing the number of decimal places and total digits.
  274. Though this could be divided into two separate functions, the logic is easier to follow if we couple the computation
  275. of the number of decimals and digits together.
  276. """
  277. try:
  278. decimal_tuple = decimal.as_tuple()
  279. assert isinstance(decimal_tuple.exponent, int)
  280. exponent = decimal_tuple.exponent
  281. num_digits = len(decimal_tuple.digits)
  282. if exponent >= 0:
  283. # A positive exponent adds that many trailing zeros
  284. # Ex: digit_tuple=(1, 2, 3), exponent=2 -> 12300 -> 0 decimal places, 5 digits
  285. num_digits += exponent
  286. decimal_places = 0
  287. else:
  288. # If the absolute value of the negative exponent is larger than the
  289. # number of digits, then it's the same as the number of digits,
  290. # because it'll consume all the digits in digit_tuple and then
  291. # add abs(exponent) - len(digit_tuple) leading zeros after the decimal point.
  292. # Ex: digit_tuple=(1, 2, 3), exponent=-2 -> 1.23 -> 2 decimal places, 3 digits
  293. # Ex: digit_tuple=(1, 2, 3), exponent=-4 -> 0.0123 -> 4 decimal places, 4 digits
  294. decimal_places = abs(exponent)
  295. num_digits = max(num_digits, decimal_places)
  296. return decimal_places, num_digits
  297. except (AssertionError, AttributeError):
  298. raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}')
  299. def max_digits_validator(x: Any, max_digits: Any) -> Any:
  300. try:
  301. _, num_digits = _extract_decimal_digits_info(x)
  302. _, normalized_num_digits = _extract_decimal_digits_info(x.normalize())
  303. if (num_digits > max_digits) and (normalized_num_digits > max_digits):
  304. raise PydanticKnownError(
  305. 'decimal_max_digits',
  306. {'max_digits': max_digits},
  307. )
  308. return x
  309. except TypeError:
  310. raise TypeError(f"Unable to apply constraint 'max_digits' to supplied value {x}")
  311. def decimal_places_validator(x: Any, decimal_places: Any) -> Any:
  312. try:
  313. decimal_places_, _ = _extract_decimal_digits_info(x)
  314. if decimal_places_ > decimal_places:
  315. normalized_decimal_places, _ = _extract_decimal_digits_info(x.normalize())
  316. if normalized_decimal_places > decimal_places:
  317. raise PydanticKnownError(
  318. 'decimal_max_places',
  319. {'decimal_places': decimal_places},
  320. )
  321. return x
  322. except TypeError:
  323. raise TypeError(f"Unable to apply constraint 'decimal_places' to supplied value {x}")
  324. def deque_validator(input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> collections.deque[Any]:
  325. return collections.deque(handler(input_value), maxlen=getattr(input_value, 'maxlen', None))
  326. def defaultdict_validator(
  327. input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
  328. ) -> collections.defaultdict[Any, Any]:
  329. if isinstance(input_value, collections.defaultdict):
  330. default_factory = input_value.default_factory
  331. return collections.defaultdict(default_factory, handler(input_value))
  332. else:
  333. return collections.defaultdict(default_default_factory, handler(input_value))
  334. def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
  335. FieldInfo = import_cached_field_info()
  336. values_type_origin = get_origin(values_source_type)
  337. def infer_default() -> Callable[[], Any]:
  338. allowed_default_types: dict[Any, Any] = {
  339. tuple: tuple,
  340. collections.abc.Sequence: tuple,
  341. collections.abc.MutableSequence: list,
  342. list: list,
  343. typing.Sequence: list,
  344. set: set,
  345. typing.MutableSet: set,
  346. collections.abc.MutableSet: set,
  347. collections.abc.Set: frozenset,
  348. typing.MutableMapping: dict,
  349. typing.Mapping: dict,
  350. collections.abc.Mapping: dict,
  351. collections.abc.MutableMapping: dict,
  352. float: float,
  353. int: int,
  354. str: str,
  355. bool: bool,
  356. }
  357. values_type = values_type_origin or values_source_type
  358. instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
  359. if typing_objects.is_typevar(values_type):
  360. def type_var_default_factory() -> None:
  361. raise RuntimeError(
  362. 'Generic defaultdict cannot be used without a concrete value type or an'
  363. ' explicit default factory, ' + instructions
  364. )
  365. return type_var_default_factory
  366. elif values_type not in allowed_default_types:
  367. # a somewhat subjective set of types that have reasonable default values
  368. allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
  369. raise PydanticSchemaGenerationError(
  370. f'Unable to infer a default factory for keys of type {values_source_type}.'
  371. f' Only {allowed_msg} are supported, other types require an explicit default factory'
  372. ' ' + instructions
  373. )
  374. return allowed_default_types[values_type]
  375. # Assume Annotated[..., Field(...)]
  376. if typing_objects.is_annotated(values_type_origin):
  377. field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None)
  378. else:
  379. field_info = None
  380. if field_info and field_info.default_factory:
  381. # Assume the default factory does not take any argument:
  382. default_default_factory = cast(Callable[[], Any], field_info.default_factory)
  383. else:
  384. default_default_factory = infer_default()
  385. return default_default_factory
  386. def validate_str_is_valid_iana_tz(value: Any, /) -> ZoneInfo:
  387. if isinstance(value, ZoneInfo):
  388. return value
  389. try:
  390. return ZoneInfo(value)
  391. except (ZoneInfoNotFoundError, ValueError, TypeError):
  392. raise PydanticCustomError('zoneinfo_str', 'invalid timezone: {value}', {'value': value})
# Maps a constraint name to the function in this module that enforces that constraint.
# Despite the name, the table also covers the length constraints (`min_length`, `max_length`)
# and the decimal constraints (`max_digits`, `decimal_places`).
NUMERIC_VALIDATOR_LOOKUP: dict[str, Callable] = {
    'gt': greater_than_validator,
    'ge': greater_than_or_equal_validator,
    'lt': less_than_validator,
    'le': less_than_or_equal_validator,
    'multiple_of': multiple_of_validator,
    'min_length': min_length_validator,
    'max_length': max_length_validator,
    'max_digits': max_digits_validator,
    'decimal_places': decimal_places_validator,
}
# Union of the `ipaddress` classes that have a dedicated validator in this module.
IpType = Union[IPv4Address, IPv6Address, IPv4Network, IPv6Network, IPv4Interface, IPv6Interface]
# Maps each `ipaddress` class to the validator function that returns instances of it.
IP_VALIDATOR_LOOKUP: dict[type[IpType], Callable] = {
    IPv4Address: ip_v4_address_validator,
    IPv6Address: ip_v6_address_validator,
    IPv4Network: ip_v4_network_validator,
    IPv6Network: ip_v6_network_validator,
    IPv4Interface: ip_v4_interface_validator,
    IPv6Interface: ip_v6_interface_validator,
}
# Maps a mapping type (either the `typing` / `typing_extensions` alias or the runtime class) to
# the concrete class associated with it. The abstract `Mapping` / `MutableMapping` types map to
# plain `dict`.
MAPPING_ORIGIN_MAP: dict[Any, Any] = {
    typing.DefaultDict: collections.defaultdict,  # noqa: UP006
    collections.defaultdict: collections.defaultdict,
    typing.OrderedDict: collections.OrderedDict,  # noqa: UP006
    collections.OrderedDict: collections.OrderedDict,
    typing_extensions.OrderedDict: collections.OrderedDict,
    typing.Counter: collections.Counter,
    collections.Counter: collections.Counter,
    # this doesn't handle subclasses of these
    typing.Mapping: dict,
    typing.MutableMapping: dict,
    # parametrized typing.{Mutable}Mapping creates one of these
    collections.abc.Mapping: dict,
    collections.abc.MutableMapping: dict,
}