| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533 |
- """Validator functions for standard library types.
- Import of this module is deferred since it contains imports of many standard library modules.
- """
- from __future__ import annotations as _annotations
- import collections.abc
- import math
- import re
- import typing
- from collections.abc import Sequence
- from decimal import Decimal
- from fractions import Fraction
- from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
- from typing import Any, Callable, TypeVar, Union, cast
- from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
- import typing_extensions
- from pydantic_core import PydanticCustomError, PydanticKnownError, core_schema
- from typing_extensions import get_args, get_origin
- from typing_inspection import typing_objects
- from pydantic._internal._import_utils import import_cached_field_info
- from pydantic.errors import PydanticSchemaGenerationError
def sequence_validator(
    input_value: Sequence[Any],
    /,
    validator: core_schema.ValidatorFunctionWrapHandler,
) -> Sequence[Any]:
    """Validate a generic `Sequence` and rebuild it as its original type where possible.

    The caller has already checked `isinstance(input_value, Sequence)`.

    Args:
        input_value: The sequence to validate.
        validator: Wrap handler that validates the items; it receives a `list` copy when the
            input is an exact `tuple`, otherwise the input itself.

    Returns:
        The handler's result as-is for `list` and `range` inputs, wrapped in a `tuple` for exact
        tuples, and otherwise passed to the original type's constructor.

    Raises:
        PydanticCustomError: `sequence_str` when the input is a `str` or `bytes` instance.
    """
    input_type = type(input_value)

    # Plain str/bytes values are rejected even though they are sequences.
    # Relevant issue: https://github.com/pydantic/pydantic/issues/5595
    if issubclass(input_type, (str, bytes)):
        raise PydanticCustomError(
            'sequence_str',
            "'{type_name}' instances are not allowed as a Sequence value",
            {'type_name': input_type.__name__},
        )

    # TODO: refactor sequence validation to validate with either a list or a tuple
    # schema, depending on the type of the value.
    # Additionally, we should be able to remove one of either this validator or the
    # SequenceValidator in _std_types_schema.py (preferably this one, while porting over some logic).
    # Effectively, a refactor for sequence validation is needed.
    exact_tuple = input_type is tuple
    validated_items = validator(list(input_value) if exact_tuple else input_value)

    # From here on we only re-create the original container type from the validated list.
    if input_type is list or issubclass(input_type, range):
        # A range probably can't be re-created from the validated items, so return the list.
        return validated_items
    if exact_tuple:
        return tuple(validated_items)
    # Best guess at re-creating the original type; custom types may need more construction logic.
    return input_type(validated_items)  # type: ignore[call-arg]
def import_string(value: Any) -> Any:
    """Import the object referenced by a string import path.

    Non-string values are returned unchanged.

    Raises:
        PydanticCustomError: `import_error` if `value` is a string that cannot be imported.
    """
    if not isinstance(value, str):
        # Leave non-string inputs for the next validator in the chain.
        return value

    try:
        return _import_string_logic(value)
    except ImportError as exc:
        raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(exc)}) from exc
def _import_string_logic(dotted_path: str) -> Any:
    """Inspired by uvicorn — dotted paths should include a colon before the final item if that item is not a module.

    (This is necessary to distinguish between a submodule and an attribute when there is a conflict.)

    If the dotted path does not include a colon and the final item is not a valid module, importing as an attribute
    rather than a submodule will be attempted automatically.

    So, for example, the following values of `dotted_path` result in the following returned values:
    * 'collections': <module 'collections'>
    * 'collections.abc': <module 'collections.abc'>
    * 'collections.abc:Mapping': <class 'collections.abc.Mapping'>
    * `collections.abc.Mapping`: <class 'collections.abc.Mapping'> (though this is a bit slower than the previous line)

    An `ImportError` will be raised under any of the following scenarios:
    * `dotted_path` contains more than one colon (e.g., 'collections:abc:Mapping')
    * the module part of `dotted_path` is empty (e.g., ':Mapping') or relative (e.g., '.abc:Mapping')
    * the substring of `dotted_path` before the colon is not a valid module in the environment (e.g., '123:Mapping')
    * the substring of `dotted_path` after the colon is not an attribute of the module (e.g., 'collections:abc123')
    """
    from importlib import import_module

    components = dotted_path.strip().split(':')
    if len(components) > 2:
        raise ImportError(f"Import strings should have at most one ':'; received {dotted_path!r}")

    module_path = components[0]
    if not module_path:
        raise ImportError(f'Import strings should have a nonempty module name; received {dotted_path!r}')
    if module_path.startswith('.'):
        # `import_module` raises TypeError (not ImportError) for relative names when no `package` is
        # given; reject them here so callers that handle ImportError see a consistent error type.
        raise ImportError(f'Import strings should not be relative; received {dotted_path!r}')

    try:
        module = import_module(module_path)
    except ModuleNotFoundError as e:
        if '.' in module_path:
            if len(components) == 1:
                # No explicit colon: retry treating the final dotted item as an attribute of its parent module.
                maybe_module_path, maybe_attribute = module_path.rsplit('.', 1)
                try:
                    return _import_string_logic(f'{maybe_module_path}:{maybe_attribute}')
                except ImportError:
                    pass
            raise ImportError(f'No module named {module_path!r}') from e
        raise e

    if len(components) > 1:
        attribute = components[1]
        try:
            return getattr(module, attribute)
        except AttributeError as e:
            raise ImportError(f'cannot import name {attribute!r} from {module_path!r}') from e
    return module
def pattern_either_validator(input_value: Any, /) -> re.Pattern[Any]:
    """Accept an existing compiled pattern, or compile a `str`/`bytes` source.

    Raises:
        PydanticCustomError: `pattern_type` for inputs that are neither a compiled pattern nor `str`/`bytes`.
    """
    if isinstance(input_value, re.Pattern):
        return input_value
    if not isinstance(input_value, (str, bytes)):
        raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
    # todo strict mode
    return compile_pattern(input_value)  # type: ignore
def pattern_str_validator(input_value: Any, /) -> re.Pattern[str]:
    """Accept or compile a regular expression whose source is `str`.

    Raises:
        PydanticCustomError: `pattern_str_type` for `bytes` sources or bytes-based compiled patterns;
            `pattern_type` for anything else that is not a pattern or string.
    """
    if isinstance(input_value, re.Pattern) and isinstance(input_value.pattern, str):
        return input_value
    if isinstance(input_value, str):
        return compile_pattern(input_value)
    # What remains is either bytes-flavoured (wrong kind of pattern) or not a pattern at all.
    if isinstance(input_value, (bytes, re.Pattern)):
        raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
    raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
def pattern_bytes_validator(input_value: Any, /) -> re.Pattern[bytes]:
    """Accept or compile a regular expression whose source is `bytes`.

    Raises:
        PydanticCustomError: `pattern_bytes_type` for `str` sources or str-based compiled patterns;
            `pattern_type` for anything else that is not a pattern or bytes.
    """
    if isinstance(input_value, re.Pattern) and isinstance(input_value.pattern, bytes):
        return input_value
    if isinstance(input_value, bytes):
        return compile_pattern(input_value)
    # What remains is either str-flavoured (wrong kind of pattern) or not a pattern at all.
    if isinstance(input_value, (str, re.Pattern)):
        raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
    raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
# Constrained to str or bytes so that the compiled pattern's type follows the input's type.
PatternType = TypeVar('PatternType', str, bytes)


def compile_pattern(pattern: PatternType) -> re.Pattern[PatternType]:
    """Compile `pattern` with `re.compile`.

    Raises:
        PydanticCustomError: `pattern_regex` if `pattern` is not a valid regular expression.
    """
    try:
        compiled = re.compile(pattern)
    except re.error:
        raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression')
    return compiled
def ip_v4_address_validator(input_value: Any, /) -> IPv4Address:
    """Coerce `input_value` to an `IPv4Address`; existing instances are returned unchanged.

    Raises:
        PydanticCustomError: `ip_v4_address` if the constructor rejects the input with a `ValueError`.
    """
    if not isinstance(input_value, IPv4Address):
        try:
            input_value = IPv4Address(input_value)
        except ValueError:
            raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address')
    return input_value
def ip_v6_address_validator(input_value: Any, /) -> IPv6Address:
    """Coerce `input_value` to an `IPv6Address`; existing instances are returned unchanged.

    Raises:
        PydanticCustomError: `ip_v6_address` if the constructor rejects the input with a `ValueError`.
    """
    if not isinstance(input_value, IPv6Address):
        try:
            input_value = IPv6Address(input_value)
        except ValueError:
            raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address')
    return input_value
def ip_v4_network_validator(input_value: Any, /) -> IPv4Network:
    """Coerce `input_value` to an `IPv4Network`; existing instances are returned unchanged.

    The network is constructed with `IPv4Network`'s default `strict` argument.
    See more:
    https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network

    Raises:
        PydanticCustomError: `ip_v4_network` if the constructor rejects the input with a `ValueError`.
    """
    if not isinstance(input_value, IPv4Network):
        try:
            input_value = IPv4Network(input_value)
        except ValueError:
            raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network')
    return input_value
def ip_v6_network_validator(input_value: Any, /) -> IPv6Network:
    """Coerce `input_value` to an `IPv6Network`; existing instances are returned unchanged.

    The network is constructed with `IPv6Network`'s default `strict` argument.
    See more:
    https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network

    Raises:
        PydanticCustomError: `ip_v6_network` if the constructor rejects the input with a `ValueError`.
    """
    if not isinstance(input_value, IPv6Network):
        try:
            input_value = IPv6Network(input_value)
        except ValueError:
            raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network')
    return input_value
def ip_v4_interface_validator(input_value: Any, /) -> IPv4Interface:
    """Coerce `input_value` to an `IPv4Interface`; existing instances are returned unchanged.

    Raises:
        PydanticCustomError: `ip_v4_interface` if the constructor rejects the input with a `ValueError`.
    """
    if not isinstance(input_value, IPv4Interface):
        try:
            input_value = IPv4Interface(input_value)
        except ValueError:
            raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface')
    return input_value
def ip_v6_interface_validator(input_value: Any, /) -> IPv6Interface:
    """Coerce `input_value` to an `IPv6Interface`; existing instances are returned unchanged.

    Raises:
        PydanticCustomError: `ip_v6_interface` if the constructor rejects the input with a `ValueError`.
    """
    if not isinstance(input_value, IPv6Interface):
        try:
            input_value = IPv6Interface(input_value)
        except ValueError:
            raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface')
    return input_value
def fraction_validator(input_value: Any, /) -> Fraction:
    """Coerce `input_value` to a `Fraction`; existing instances are returned unchanged.

    Raises:
        PydanticCustomError: `fraction_parsing` if `Fraction(input_value)` fails.
    """
    if isinstance(input_value, Fraction):
        return input_value

    try:
        return Fraction(input_value)
    except (ValueError, TypeError, ZeroDivisionError, OverflowError):
        # `Fraction` raises more than ValueError for bad input:
        #   ValueError        - malformed strings, NaN
        #   TypeError         - unsupported input types (e.g. None, list)
        #   ZeroDivisionError - a zero denominator such as '1/0'
        #   OverflowError     - infinite floats / Decimals
        # All of these are invalid user input and must surface as a validation error.
        raise PydanticCustomError('fraction_parsing', 'Input is not a valid fraction')
def forbid_inf_nan_check(x: Any) -> Any:
    """Return `x` unchanged when `math.isfinite(x)` holds.

    Raises:
        PydanticKnownError: `finite_number` for infinities and NaN.
    """
    if math.isfinite(x):
        return x
    raise PydanticKnownError('finite_number')
def _safe_repr(v: Any) -> int | float | str:
    """Coerce a constraint value into something usable as `PydanticKnownError` context.

    That context only accepts numbers or strings, so `int`, `float` and `str` values (including
    subclasses such as `bool`) are returned as-is and anything else (e.g. `timedelta`) becomes its `repr()`.
    See tests/test_types.py::test_annotated_metadata_any_order for some context.
    """
    return v if isinstance(v, (int, float, str)) else repr(v)
def greater_than_validator(x: Any, gt: Any) -> Any:
    """Return `x` when `x > gt`.

    Raises:
        PydanticKnownError: `greater_than` when the comparison is falsy.
        TypeError: when `x` and `gt` cannot be compared.
    """
    try:
        if x > gt:
            return x
        raise PydanticKnownError('greater_than', {'gt': _safe_repr(gt)})
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'gt' to supplied value {x}")
def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
    """Return `x` when `x >= ge`.

    Raises:
        PydanticKnownError: `greater_than_equal` when the comparison is falsy.
        TypeError: when `x` and `ge` cannot be compared.
    """
    try:
        if x >= ge:
            return x
        raise PydanticKnownError('greater_than_equal', {'ge': _safe_repr(ge)})
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'ge' to supplied value {x}")
def less_than_validator(x: Any, lt: Any) -> Any:
    """Return `x` when `x < lt`.

    Raises:
        PydanticKnownError: `less_than` when the comparison is falsy.
        TypeError: when `x` and `lt` cannot be compared.
    """
    try:
        if x < lt:
            return x
        raise PydanticKnownError('less_than', {'lt': _safe_repr(lt)})
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'lt' to supplied value {x}")
def less_than_or_equal_validator(x: Any, le: Any) -> Any:
    """Return `x` when `x <= le`.

    Raises:
        PydanticKnownError: `less_than_equal` when the comparison is falsy.
        TypeError: when `x` and `le` cannot be compared.
    """
    try:
        if x <= le:
            return x
        raise PydanticKnownError('less_than_equal', {'le': _safe_repr(le)})
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'le' to supplied value {x}")
def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
    """Return `x` when `x % multiple_of` is falsy (i.e. no remainder).

    Raises:
        PydanticKnownError: `multiple_of` when a remainder is left.
        TypeError: when the modulo operation is not supported for these operands.
    """
    try:
        remainder = x % multiple_of
        if not remainder:
            return x
        raise PydanticKnownError('multiple_of', {'multiple_of': _safe_repr(multiple_of)})
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'multiple_of' to supplied value {x}")
def min_length_validator(x: Any, min_length: Any) -> Any:
    """Return `x` when `len(x) >= min_length`.

    Raises:
        PydanticKnownError: `too_short` (with the actual length in its context) otherwise.
        TypeError: when `x` has no length or the comparison is unsupported.
    """
    try:
        if len(x) >= min_length:
            return x
        raise PydanticKnownError(
            'too_short',
            {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)},
        )
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'min_length' to supplied value {x}")
def max_length_validator(x: Any, max_length: Any) -> Any:
    """Return `x` unless `len(x) > max_length`.

    Raises:
        PydanticKnownError: `too_long` (with the actual length in its context) when `x` is too long.
        TypeError: when `x` has no length or the comparison is unsupported.
    """
    try:
        too_long = len(x) > max_length
        if not too_long:
            return x
        raise PydanticKnownError(
            'too_long',
            {'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
        )
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'max_length' to supplied value {x}")
def _extract_decimal_digits_info(decimal: Decimal) -> tuple[int, int]:
    """Compute the number of decimal places and total digits for a given [`Decimal`][decimal.Decimal] instance.

    This function handles both normalized and non-normalized Decimal instances.
    Example: Decimal('1.230') -> 4 digits, 3 decimal places

    Args:
        decimal (Decimal): The decimal number to analyze.

    Returns:
        tuple[int, int]: A tuple containing the number of decimal places and total digits.

    Raises:
        TypeError: If `decimal` has no usable `as_tuple()` result, or is a special value
            (NaN, sNaN, Infinity) whose exponent is not an integer.

    Though this could be divided into two separate functions, the logic is easier to follow if we couple the computation
    of the number of decimals and digits together.
    """
    try:
        decimal_tuple = decimal.as_tuple()
        exponent = decimal_tuple.exponent
        num_digits = len(decimal_tuple.digits)
    except AttributeError:
        raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}')

    # Special values report a string exponent ('n', 'N' or 'F'). This is an explicit check rather
    # than an `assert` so that it still applies when Python runs with `-O`.
    if not isinstance(exponent, int):
        raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}')

    if exponent >= 0:
        # A positive exponent adds that many trailing zeros
        # Ex: digit_tuple=(1, 2, 3), exponent=2 -> 12300 -> 0 decimal places, 5 digits
        num_digits += exponent
        decimal_places = 0
    else:
        # If the absolute value of the negative exponent is larger than the
        # number of digits, then it's the same as the number of digits,
        # because it'll consume all the digits in digit_tuple and then
        # add abs(exponent) - len(digit_tuple) leading zeros after the decimal point.
        # Ex: digit_tuple=(1, 2, 3), exponent=-2 -> 1.23 -> 2 decimal places, 3 digits
        # Ex: digit_tuple=(1, 2, 3), exponent=-4 -> 0.0123 -> 4 decimal places, 4 digits
        decimal_places = abs(exponent)
        num_digits = max(num_digits, decimal_places)

    return decimal_places, num_digits
def max_digits_validator(x: Any, max_digits: Any) -> Any:
    """Enforce the `max_digits` constraint on a Decimal value.

    `x` is rejected only when both its own digit count and that of `x.normalize()` exceed `max_digits`.

    Raises:
        PydanticKnownError: `decimal_max_digits` when the constraint is violated.
        TypeError: when the digit count cannot be determined for `x`.
    """
    try:
        raw_digit_count = _extract_decimal_digits_info(x)[1]
        normalized_digit_count = _extract_decimal_digits_info(x.normalize())[1]
        exceeds = (raw_digit_count > max_digits) and (normalized_digit_count > max_digits)
        if not exceeds:
            return x
        raise PydanticKnownError('decimal_max_digits', {'max_digits': max_digits})
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'max_digits' to supplied value {x}")
def decimal_places_validator(x: Any, decimal_places: Any) -> Any:
    """Enforce the `decimal_places` constraint on a Decimal value.

    `x.normalize()` is only consulted when `x` itself has too many decimal places; `x` is rejected
    only if the normalized form still has too many.

    Raises:
        PydanticKnownError: `decimal_max_places` when the constraint is violated.
        TypeError: when the decimal places cannot be determined for `x`.
    """
    try:
        places = _extract_decimal_digits_info(x)[0]
        if not (places > decimal_places):
            return x
        normalized_places = _extract_decimal_digits_info(x.normalize())[0]
        if not (normalized_places > decimal_places):
            return x
        raise PydanticKnownError('decimal_max_places', {'decimal_places': decimal_places})
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'decimal_places' to supplied value {x}")
def deque_validator(input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> collections.deque[Any]:
    """Validate the items with `handler` and wrap the result in a `deque`.

    The input's `maxlen` attribute is carried over when present (`None`, i.e. unbounded, otherwise).
    """
    validated_items = handler(input_value)
    return collections.deque(validated_items, maxlen=getattr(input_value, 'maxlen', None))
def defaultdict_validator(
    input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
) -> collections.defaultdict[Any, Any]:
    """Validate the mapping with `handler` and wrap the result in a `defaultdict`.

    An input that is already a `defaultdict` keeps its own `default_factory`; any other input uses
    `default_default_factory`.
    """
    factory = (
        input_value.default_factory
        if isinstance(input_value, collections.defaultdict)
        else default_default_factory
    )
    return collections.defaultdict(factory, handler(input_value))
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
    """Determine the default factory for a `defaultdict` whose values are of type `values_source_type`.

    If the value type is `Annotated[..., Field(default_factory=...)]`, that factory is used. Otherwise
    a factory is inferred from the value type (or its origin) via a fixed table of supported types.

    Raises:
        PydanticSchemaGenerationError: if no explicit factory is given and the value type is not in the
            table of supported types.
    """
    FieldInfo = import_cached_field_info()

    values_type_origin = get_origin(values_source_type)

    def infer_default() -> Callable[[], Any]:
        # a somewhat subjective set of types that have reasonable default values
        allowed_default_types: dict[Any, Any] = {
            tuple: tuple,
            collections.abc.Sequence: tuple,
            collections.abc.MutableSequence: list,
            list: list,
            typing.Sequence: list,
            set: set,
            typing.MutableSet: set,
            collections.abc.MutableSet: set,
            collections.abc.Set: frozenset,
            typing.MutableMapping: dict,
            typing.Mapping: dict,
            collections.abc.Mapping: dict,
            collections.abc.MutableMapping: dict,
            float: float,
            int: int,
            str: str,
            bool: bool,
        }
        values_type = values_type_origin or values_source_type
        instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
        if typing_objects.is_typevar(values_type):

            def type_var_default_factory() -> None:
                raise RuntimeError(
                    'Generic defaultdict cannot be used without a concrete value type or an'
                    ' explicit default factory, ' + instructions
                )

            return type_var_default_factory
        elif values_type not in allowed_default_types:
            # `dict.fromkeys` de-duplicates while preserving insertion order. Iterating a `set` of
            # classes follows their (address-based) hashes, which made this message vary between runs.
            allowed_msg = ', '.join([t.__name__ for t in dict.fromkeys(allowed_default_types.values())])
            raise PydanticSchemaGenerationError(
                f'Unable to infer a default factory for keys of type {values_source_type}.'
                f' Only {allowed_msg} are supported, other types require an explicit default factory'
                ' ' + instructions
            )
        return allowed_default_types[values_type]

    # Assume Annotated[..., Field(...)]
    if typing_objects.is_annotated(values_type_origin):
        field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None)
    else:
        field_info = None
    if field_info and field_info.default_factory:
        # Assume the default factory does not take any argument:
        default_default_factory = cast(Callable[[], Any], field_info.default_factory)
    else:
        default_default_factory = infer_default()
    return default_default_factory
def validate_str_is_valid_iana_tz(value: Any, /) -> ZoneInfo:
    """Resolve `value` to a `ZoneInfo`; existing `ZoneInfo` instances are returned unchanged.

    Raises:
        PydanticCustomError: `zoneinfo_str` if `ZoneInfo(value)` raises `ZoneInfoNotFoundError`,
            `ValueError` or `TypeError`.
    """
    if not isinstance(value, ZoneInfo):
        try:
            value = ZoneInfo(value)
        except (ZoneInfoNotFoundError, ValueError, TypeError):
            raise PydanticCustomError('zoneinfo_str', 'invalid timezone: {value}', {'value': value})
    return value
# Maps a constraint name to the validator function that enforces it.
NUMERIC_VALIDATOR_LOOKUP: dict[str, Callable] = {
    # ordering bounds
    'gt': greater_than_validator,
    'ge': greater_than_or_equal_validator,
    'lt': less_than_validator,
    'le': less_than_or_equal_validator,
    # divisibility
    'multiple_of': multiple_of_validator,
    # length bounds
    'min_length': min_length_validator,
    'max_length': max_length_validator,
    # Decimal precision
    'max_digits': max_digits_validator,
    'decimal_places': decimal_places_validator,
}
# Union of the concrete `ipaddress` classes that have a dedicated validator below.
IpType = Union[
    IPv4Address,
    IPv6Address,
    IPv4Network,
    IPv6Network,
    IPv4Interface,
    IPv6Interface,
]

# Maps each `ipaddress` class to the validator that coerces input into it.
IP_VALIDATOR_LOOKUP: dict[type[IpType], Callable] = {
    # addresses
    IPv4Address: ip_v4_address_validator,
    IPv6Address: ip_v6_address_validator,
    # networks
    IPv4Network: ip_v4_network_validator,
    IPv6Network: ip_v6_network_validator,
    # interfaces
    IPv4Interface: ip_v4_interface_validator,
    IPv6Interface: ip_v6_interface_validator,
}
# Maps a mapping-like origin (typing alias or runtime class) to the concrete class used to build values.
# Entries are grouped by target; `dict.fromkeys` keeps the keys in the order listed.
MAPPING_ORIGIN_MAP: dict[Any, Any] = {
    **dict.fromkeys(
        (
            typing.DefaultDict,  # noqa: UP006
            collections.defaultdict,
        ),
        collections.defaultdict,
    ),
    **dict.fromkeys(
        (
            typing.OrderedDict,  # noqa: UP006
            collections.OrderedDict,
            typing_extensions.OrderedDict,
        ),
        collections.OrderedDict,
    ),
    **dict.fromkeys((typing.Counter, collections.Counter), collections.Counter),
    # this doesn't handle subclasses of these;
    # parametrized typing.{Mutable}Mapping creates one of the collections.abc entries
    **dict.fromkeys(
        (
            typing.Mapping,
            typing.MutableMapping,
            collections.abc.Mapping,
            collections.abc.MutableMapping,
        ),
        dict,
    ),
}
|