dataclasses.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684
  1. import collections.abc
  2. import inspect
  3. import types
  4. from collections.abc import Callable
  5. from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields, make_dataclass
  6. from functools import lru_cache, wraps
  7. from typing import (
  8. Annotated,
  9. Any,
  10. ForwardRef,
  11. Literal,
  12. Type,
  13. TypeVar,
  14. Union,
  15. get_args,
  16. get_origin,
  17. overload,
  18. )
  19. try:
  20. # Python 3.11+
  21. from typing import NotRequired, Required # type: ignore
  22. except ImportError:
  23. try:
  24. # In case typing_extensions is installed
  25. from typing_extensions import NotRequired, Required # type: ignore
  26. except ImportError:
  27. # Fallback: create dummy types that will never match
  28. Required = type("Required", (), {}) # type: ignore
  29. NotRequired = type("NotRequired", (), {}) # type: ignore
  30. from .errors import (
  31. StrictDataclassClassValidationError,
  32. StrictDataclassDefinitionError,
  33. StrictDataclassFieldValidationError,
  34. )
  35. Validator_T = Callable[[Any], None]
  36. T = TypeVar("T")
  37. TypedDictType = TypeVar("TypedDictType", bound=dict[str, Any])
  38. _TYPED_DICT_DEFAULT_VALUE = object() # used as default value in TypedDict fields (to distinguish from None)
  39. # The overload decorator helps type checkers understand the different return types
  40. @overload
  41. def strict(cls: Type[T]) -> Type[T]: ...
  42. @overload
  43. def strict(*, accept_kwargs: bool = False) -> Callable[[Type[T]], Type[T]]: ...
  44. def strict(cls: Type[T] | None = None, *, accept_kwargs: bool = False) -> Type[T] | Callable[[Type[T]], Type[T]]:
  45. """
  46. Decorator to add strict validation to a dataclass.
  47. This decorator must be used on top of `@dataclass` to ensure IDEs and static typing tools
  48. recognize the class as a dataclass.
  49. Can be used with or without arguments:
  50. - `@strict`
  51. - `@strict(accept_kwargs=True)`
  52. Args:
  53. cls:
  54. The class to convert to a strict dataclass.
  55. accept_kwargs (`bool`, *optional*):
  56. If True, allows arbitrary keyword arguments in `__init__`. Defaults to False.
  57. Returns:
  58. The enhanced dataclass with strict validation on field assignment.
  59. Example:
  60. ```py
  61. >>> from dataclasses import dataclass
  62. >>> from huggingface_hub.dataclasses import as_validated_field, strict, validated_field
  63. >>> @as_validated_field
  64. >>> def positive_int(value: int):
  65. ... if not value >= 0:
  66. ... raise ValueError(f"Value must be positive, got {value}")
  67. >>> @strict(accept_kwargs=True)
  68. ... @dataclass
  69. ... class User:
  70. ... name: str
  71. ... age: int = positive_int(default=10)
  72. # Initialize
  73. >>> User(name="John")
  74. User(name='John', age=10)
  75. # Extra kwargs are accepted
  76. >>> User(name="John", age=30, lastname="Doe")
  77. User(name='John', age=30, *lastname='Doe')
  78. # Invalid type => raises
  79. >>> User(name="John", age="30")
  80. huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
  81. TypeError: Field 'age' expected int, got str (value: '30')
  82. # Invalid value => raises
  83. >>> User(name="John", age=-1)
  84. huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
  85. ValueError: Value must be positive, got -1
  86. ```
  87. """
  88. def wrap(cls: Type[T]) -> Type[T]:
  89. if not hasattr(cls, "__dataclass_fields__"):
  90. raise StrictDataclassDefinitionError(
  91. f"Class '{cls.__name__}' must be a dataclass before applying @strict."
  92. )
  93. # List and store validators
  94. field_validators: dict[str, list[Validator_T]] = {}
  95. for f in fields(cls): # type: ignore [arg-type]
  96. validators = []
  97. validators.append(_create_type_validator(f))
  98. custom_validator = f.metadata.get("validator")
  99. if custom_validator is not None:
  100. if not isinstance(custom_validator, list):
  101. custom_validator = [custom_validator]
  102. for validator in custom_validator:
  103. if not _is_validator(validator):
  104. raise StrictDataclassDefinitionError(
  105. f"Invalid validator for field '{f.name}': {validator}. Must be a callable taking a single argument."
  106. )
  107. validators.extend(custom_validator)
  108. field_validators[f.name] = validators
  109. cls.__validators__ = field_validators # type: ignore
  110. # Override __setattr__ to validate fields on assignment
  111. original_setattr = cls.__setattr__
  112. def __strict_setattr__(self: Any, name: str, value: Any) -> None:
  113. """Custom __setattr__ method for strict dataclasses."""
  114. # Run all validators
  115. for validator in self.__validators__.get(name, []):
  116. try:
  117. validator(value)
  118. except (ValueError, TypeError) as e:
  119. raise StrictDataclassFieldValidationError(field=name, cause=e) from e
  120. # If validation passed, set the attribute
  121. original_setattr(self, name, value)
  122. cls.__setattr__ = __strict_setattr__ # type: ignore
  123. if accept_kwargs:
  124. # (optional) Override __init__ to accept arbitrary keyword arguments
  125. original_init = cls.__init__
  126. @wraps(original_init)
  127. def __init__(self, *args, **kwargs: Any) -> None:
  128. # Extract only the fields that are part of the dataclass
  129. dataclass_fields = {f.name for f in fields(cls)} # type: ignore [arg-type]
  130. standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields}
  131. # User shouldn't define custom `__init__` when `accepts_kwargs`, and instead
  132. # are advised to move field manipulation to `__post_init__` (e.g., derive new field from existing ones)
  133. # We need to call bare `__init__` here without `__post_init__` but the``original_init`` would call
  134. # post-init right away with no kwargs.
  135. if len(args) > 0:
  136. raise ValueError(
  137. f"When `accept_kwargs=True`, {cls.__name__} accepts only keyword arguments, "
  138. f"but found `{len(args)}` positional args."
  139. )
  140. for f in fields(cls): # type: ignore
  141. if f.name in standard_kwargs:
  142. setattr(self, f.name, standard_kwargs[f.name])
  143. elif f.default is not MISSING:
  144. setattr(self, f.name, f.default)
  145. elif f.default_factory is not MISSING:
  146. setattr(self, f.name, f.default_factory())
  147. else:
  148. raise TypeError(f"Missing required field - '{f.name}'")
  149. # Pass any additional kwargs to `__post_init__` and let the object
  150. # decide whether to set the attr or use for different purposes (e.g. BC checks)
  151. additional_kwargs = {}
  152. for name, value in kwargs.items():
  153. if name not in dataclass_fields:
  154. additional_kwargs[name] = value
  155. self.__post_init__(**additional_kwargs)
  156. cls.__init__ = __init__ # type: ignore
  157. # Define a default __post_init__ if not defined
  158. if not hasattr(cls, "__post_init__"):
  159. def __post_init__(self, **kwargs: Any) -> None:
  160. """Default __post_init__ to accept additional kwargs."""
  161. for name, value in kwargs.items():
  162. setattr(self, name, value)
  163. cls.__post_init__ = __post_init__ # type: ignore
  164. # (optional) Override __repr__ to include additional kwargs
  165. original_repr = cls.__repr__
  166. @wraps(original_repr)
  167. def __repr__(self) -> str:
  168. # Call the original __repr__ to get the standard fields
  169. standard_repr = original_repr(self)
  170. # Get additional kwargs
  171. additional_kwargs = [
  172. # add a '*' in front of additional kwargs to let the user know they are not part of the dataclass
  173. f"*{k}={v!r}"
  174. for k, v in self.__dict__.items()
  175. if k not in cls.__dataclass_fields__ # type: ignore [attr-defined]
  176. ]
  177. additional_repr = ", ".join(additional_kwargs)
  178. # Combine both representations
  179. return f"{standard_repr[:-1]}, {additional_repr})" if additional_kwargs else standard_repr
  180. if cls.__dataclass_params__.repr is True: # type: ignore [attr-defined]
  181. cls.__repr__ = __repr__ # type: ignore
  182. # List all public methods starting with `validate_` => class validators.
  183. class_validators = []
  184. for name in dir(cls):
  185. if not name.startswith("validate_"):
  186. continue
  187. method = getattr(cls, name)
  188. if not callable(method):
  189. continue
  190. if len(inspect.signature(method).parameters) != 1:
  191. raise StrictDataclassDefinitionError(
  192. f"Class '{cls.__name__}' has a class validator '{name}' that takes more than one argument."
  193. " Class validators must take only 'self' as an argument. Methods starting with 'validate_'"
  194. " are considered to be class validators."
  195. )
  196. class_validators.append(method)
  197. cls.__class_validators__ = class_validators # type: ignore
  198. # Add `validate` method to the class, but first check if it already exists
  199. def validate(self: T) -> None:
  200. """Run class validators on the instance."""
  201. for validator in cls.__class_validators__: # type: ignore [attr-defined]
  202. try:
  203. validator(self)
  204. except (ValueError, TypeError) as e:
  205. raise StrictDataclassClassValidationError(validator=validator.__name__, cause=e) from e
  206. # Hack to be able to raise if `.validate()` already exists except if it was created by this decorator on a parent class
  207. # (in which case we just override it)
  208. validate.__is_defined_by_strict_decorator__ = True # type: ignore [attr-defined]
  209. if hasattr(cls, "validate"):
  210. if not getattr(cls.validate, "__is_defined_by_strict_decorator__", False): # type: ignore [attr-defined]
  211. raise StrictDataclassDefinitionError(
  212. f"Class '{cls.__name__}' already implements a method called 'validate'."
  213. " This method name is reserved when using the @strict decorator on a dataclass."
  214. " If you want to keep your own method, please rename it."
  215. )
  216. cls.validate = validate # type: ignore
  217. # Run class validators after initialization
  218. initial_init = cls.__init__
  219. @wraps(initial_init)
  220. def init_with_validate(self, *args, **kwargs) -> None:
  221. """Run class validators after initialization."""
  222. initial_init(self, *args, **kwargs) # type: ignore [call-arg]
  223. cls.validate(self) # type: ignore [attr-defined]
  224. setattr(cls, "__init__", init_with_validate)
  225. return cls
  226. # Return wrapped class or the decorator itself
  227. return wrap(cls) if cls is not None else wrap
  228. def validate_typed_dict(schema: type[TypedDictType], data: dict) -> None:
  229. """
  230. Validate that a dictionary conforms to the types defined in a TypedDict class.
  231. Under the hood, the typed dict is converted to a strict dataclass and validated using the `@strict` decorator.
  232. Args:
  233. schema (`type[TypedDictType]`):
  234. The TypedDict class defining the expected structure and types.
  235. data (`dict`):
  236. The dictionary to validate.
  237. Raises:
  238. `StrictDataclassFieldValidationError`:
  239. If any field in the dictionary does not conform to the expected type.
  240. Example:
  241. ```py
  242. >>> from typing import Annotated, TypedDict
  243. >>> from huggingface_hub.dataclasses import validate_typed_dict
  244. >>> def positive_int(value: int):
  245. ... if not value >= 0:
  246. ... raise ValueError(f"Value must be positive, got {value}")
  247. >>> class User(TypedDict):
  248. ... name: str
  249. ... age: Annotated[int, positive_int]
  250. >>> # Valid data
  251. >>> validate_typed_dict(User, {"name": "John", "age": 30})
  252. >>> # Invalid type for age
  253. >>> validate_typed_dict(User, {"name": "John", "age": "30"})
  254. huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
  255. TypeError: Field 'age' expected int, got str (value: '30')
  256. >>> # Invalid value for age
  257. >>> validate_typed_dict(User, {"name": "John", "age": -1})
  258. huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
  259. ValueError: Value must be positive, got -1
  260. ```
  261. """
  262. # Convert typed dict to dataclass
  263. strict_cls = _build_strict_cls_from_typed_dict(schema)
  264. # Validate the data by instantiating the strict dataclass
  265. strict_cls(**data) # will raise if validation fails
  266. @lru_cache
  267. def _build_strict_cls_from_typed_dict(schema: type[TypedDictType]) -> Type:
  268. # Extract type hints from the TypedDict class
  269. type_hints = _get_typed_dict_annotations(schema)
  270. # If the TypedDict is not total, wrap fields as NotRequired (unless explicitly Required or NotRequired)
  271. if not getattr(schema, "__total__", True):
  272. for key, value in type_hints.items():
  273. origin = get_origin(value)
  274. if origin is Annotated:
  275. base, *meta = get_args(value)
  276. if not _is_required_or_notrequired(base):
  277. base = NotRequired[base]
  278. type_hints[key] = Annotated[tuple([base] + list(meta))] # type: ignore
  279. elif not _is_required_or_notrequired(value):
  280. type_hints[key] = NotRequired[value]
  281. # Convert type hints to dataclass fields
  282. fields = []
  283. for key, value in type_hints.items():
  284. if get_origin(value) is Annotated:
  285. base, *meta = get_args(value)
  286. fields.append((key, base, field(default=_TYPED_DICT_DEFAULT_VALUE, metadata={"validator": meta[0]})))
  287. else:
  288. fields.append((key, value, field(default=_TYPED_DICT_DEFAULT_VALUE)))
  289. # Create a strict dataclass from the TypedDict fields
  290. return strict(make_dataclass(schema.__name__, fields))
  291. def _get_typed_dict_annotations(schema: type[TypedDictType]) -> dict[str, Any]:
  292. """Extract type annotations from a TypedDict class."""
  293. try:
  294. # Available in Python 3.14+
  295. import annotationlib
  296. return annotationlib.get_annotations(schema)
  297. except ImportError:
  298. return {
  299. # We do not use `get_type_hints` here to avoid evaluating ForwardRefs (which might fail).
  300. # ForwardRefs are not validated by @strict anyway.
  301. name: value if value is not None else type(None)
  302. for name, value in schema.__dict__.get("__annotations__", {}).items()
  303. }
  304. def validated_field(
  305. validator: list[Validator_T] | Validator_T,
  306. default: Any | _MISSING_TYPE = MISSING,
  307. default_factory: Callable[[], Any] | _MISSING_TYPE = MISSING,
  308. init: bool = True,
  309. repr: bool = True,
  310. hash: bool | None = None,
  311. compare: bool = True,
  312. metadata: dict | None = None,
  313. **kwargs: Any,
  314. ) -> Any:
  315. """
  316. Create a dataclass field with a custom validator.
  317. Useful to apply several checks to a field. If only applying one rule, check out the [`as_validated_field`] decorator.
  318. Args:
  319. validator (`Callable` or `list[Callable]`):
  320. A method that takes a value as input and raises ValueError/TypeError if the value is invalid.
  321. Can be a list of validators to apply multiple checks.
  322. **kwargs:
  323. Additional arguments to pass to `dataclasses.field()`.
  324. Returns:
  325. A field with the validator attached in metadata
  326. """
  327. if not isinstance(validator, list):
  328. validator = [validator]
  329. if metadata is None:
  330. metadata = {}
  331. metadata["validator"] = validator
  332. return field( # type: ignore
  333. default=default, # type: ignore [arg-type]
  334. default_factory=default_factory, # type: ignore [arg-type]
  335. init=init,
  336. repr=repr,
  337. hash=hash,
  338. compare=compare,
  339. metadata=metadata,
  340. **kwargs,
  341. )
  342. def as_validated_field(validator: Validator_T):
  343. """
  344. Decorates a validator function as a [`validated_field`] (i.e. a dataclass field with a custom validator).
  345. Args:
  346. validator (`Callable`):
  347. A method that takes a value as input and raises ValueError/TypeError if the value is invalid.
  348. """
  349. def _inner(
  350. default: Any | _MISSING_TYPE = MISSING,
  351. default_factory: Callable[[], Any] | _MISSING_TYPE = MISSING,
  352. init: bool = True,
  353. repr: bool = True,
  354. hash: bool | None = None,
  355. compare: bool = True,
  356. metadata: dict | None = None,
  357. **kwargs: Any,
  358. ):
  359. return validated_field(
  360. validator,
  361. default=default,
  362. default_factory=default_factory,
  363. init=init,
  364. repr=repr,
  365. hash=hash,
  366. compare=compare,
  367. metadata=metadata,
  368. **kwargs,
  369. )
  370. return _inner
  371. def type_validator(name: str, value: Any, expected_type: Any) -> None:
  372. """Validate that 'value' matches 'expected_type'."""
  373. origin = get_origin(expected_type)
  374. args = get_args(expected_type)
  375. if expected_type is Any:
  376. return
  377. elif expected_type is None:
  378. _validate_none(name, value)
  379. elif validator := _BASIC_TYPE_VALIDATORS.get(origin):
  380. validator(name, value, args)
  381. elif isinstance(expected_type, type): # simple types
  382. _validate_simple_type(name, value, expected_type)
  383. elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str):
  384. return
  385. elif origin is Required:
  386. if value is _TYPED_DICT_DEFAULT_VALUE:
  387. raise TypeError(f"Field '{name}' is required but missing.")
  388. type_validator(name, value, args[0])
  389. elif origin is NotRequired:
  390. if value is _TYPED_DICT_DEFAULT_VALUE:
  391. return
  392. type_validator(name, value, args[0])
  393. else:
  394. raise TypeError(f"Unsupported type for field '{name}': {expected_type}")
  395. def _validate_none(name: str, value: Any) -> None:
  396. """Validate None type.
  397. 'None' is not a type, it's a special value. Type should be `NoneType` instead.
  398. But in type annotations 'None' is accepted so we must support it.
  399. """
  400. if value is not None:
  401. raise TypeError(f"Field '{name}' expected None, got {type(value).__name__}")
  402. def _validate_union(name: str, value: Any, args: tuple[Any, ...]) -> None:
  403. """Validate that value matches one of the types in a Union."""
  404. errors = []
  405. for t in args:
  406. try:
  407. type_validator(name, value, t)
  408. return # Valid if any type matches
  409. except TypeError as e:
  410. errors.append(str(e))
  411. raise TypeError(
  412. f"Field '{name}' with value {repr(value)} doesn't match any type in {args}. Errors: {'; '.join(errors)}"
  413. )
  414. def _validate_literal(name: str, value: Any, args: tuple[Any, ...]) -> None:
  415. """Validate Literal type."""
  416. if isinstance(value, bool):
  417. if value not in [arg for arg in args if isinstance(arg, bool)]:
  418. raise TypeError(f"Field '{name}' expected one of {args}, got {value}")
  419. elif isinstance(value, int):
  420. if value not in [arg for arg in args if isinstance(arg, int) and not isinstance(arg, bool)]:
  421. raise TypeError(f"Field '{name}' expected one of {args}, got {value}")
  422. elif value not in args:
  423. raise TypeError(f"Field '{name}' expected one of {args}, got {value}")
  424. def _validate_list(name: str, value: Any, args: tuple[Any, ...]) -> None:
  425. """Validate list[T] type."""
  426. if not isinstance(value, list):
  427. raise TypeError(f"Field '{name}' expected a list, got {type(value).__name__}")
  428. # Validate each item in the list
  429. item_type = args[0]
  430. for i, item in enumerate(value):
  431. try:
  432. type_validator(f"{name}[{i}]", item, item_type)
  433. except TypeError as e:
  434. raise TypeError(f"Invalid item at index {i} in list '{name}'") from e
  435. def _validate_dict(name: str, value: Any, args: tuple[Any, ...]) -> None:
  436. """Validate dict[K, V] type."""
  437. if not isinstance(value, dict):
  438. raise TypeError(f"Field '{name}' expected a dict, got {type(value).__name__}")
  439. # Validate keys and values
  440. key_type, value_type = args
  441. for k, v in value.items():
  442. try:
  443. type_validator(f"{name}.key", k, key_type)
  444. type_validator(f"{name}[{k!r}]", v, value_type)
  445. except TypeError as e:
  446. raise TypeError(f"Invalid key or value in dict '{name}'") from e
  447. def _validate_tuple(name: str, value: Any, args: tuple[Any, ...]) -> None:
  448. """Validate Tuple type."""
  449. if not isinstance(value, tuple):
  450. raise TypeError(f"Field '{name}' expected a tuple, got {type(value).__name__}")
  451. # Handle variable-length tuples: tuple[T, ...]
  452. if len(args) == 2 and args[1] is Ellipsis:
  453. for i, item in enumerate(value):
  454. try:
  455. type_validator(f"{name}[{i}]", item, args[0])
  456. except TypeError as e:
  457. raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e
  458. # Handle fixed-length tuples: tuple[T1, T2, ...]
  459. elif len(args) != len(value):
  460. raise TypeError(f"Field '{name}' expected a tuple of length {len(args)}, got {len(value)}")
  461. else:
  462. for i, (item, expected) in enumerate(zip(value, args)):
  463. try:
  464. type_validator(f"{name}[{i}]", item, expected)
  465. except TypeError as e:
  466. raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e
  467. def _validate_set(name: str, value: Any, args: tuple[Any, ...]) -> None:
  468. """Validate set[T] type."""
  469. if not isinstance(value, set):
  470. raise TypeError(f"Field '{name}' expected a set, got {type(value).__name__}")
  471. # Validate each item in the set
  472. item_type = args[0]
  473. for i, item in enumerate(value):
  474. try:
  475. type_validator(f"{name} item", item, item_type)
  476. except TypeError as e:
  477. raise TypeError(f"Invalid item in set '{name}'") from e
  478. def _validate_sequence(name: str, value: Any, args: tuple[Any, ...]) -> None:
  479. """Validate Sequence or Sequence[T] type."""
  480. if not isinstance(value, collections.abc.Sequence):
  481. raise TypeError(f"Field '{name}' expected a Sequence, got {type(value).__name__}")
  482. # If no type argument is provided (i.e., just `Sequence`), skip item validation
  483. if not args:
  484. return
  485. # Validate each item in the sequence
  486. item_type = args[0]
  487. for i, item in enumerate(value):
  488. try:
  489. type_validator(f"{name}[{i}]", item, item_type)
  490. except TypeError as e:
  491. raise TypeError(f"Invalid item at index {i} in sequence '{name}'") from e
  492. def _validate_simple_type(name: str, value: Any, expected_type: type) -> None:
  493. """Validate simple type (int, str, etc.)."""
  494. if expected_type is int and isinstance(value, bool):
  495. raise TypeError(
  496. f"Field '{name}' expected {expected_type.__name__}, got {type(value).__name__} (value: {repr(value)})"
  497. )
  498. if not isinstance(value, expected_type):
  499. raise TypeError(
  500. f"Field '{name}' expected {expected_type.__name__}, got {type(value).__name__} (value: {repr(value)})"
  501. )
  502. def _create_type_validator(field: Field) -> Validator_T:
  503. """Create a type validator function for a field."""
  504. # Hacky: we cannot use a lambda here because of reference issues
  505. def validator(value: Any) -> None:
  506. type_validator(field.name, value, field.type)
  507. return validator
  508. def _is_validator(validator: Any) -> bool:
  509. """Check if a function is a validator.
  510. A validator is a Callable that can be called with a single positional argument.
  511. The validator can have more arguments with default values.
  512. Basically, returns True if `validator(value)` is possible.
  513. """
  514. if not callable(validator):
  515. return False
  516. signature = inspect.signature(validator)
  517. parameters = list(signature.parameters.values())
  518. if len(parameters) == 0:
  519. return False
  520. if parameters[0].kind not in (
  521. inspect.Parameter.POSITIONAL_OR_KEYWORD,
  522. inspect.Parameter.POSITIONAL_ONLY,
  523. inspect.Parameter.VAR_POSITIONAL,
  524. ):
  525. return False
  526. for parameter in parameters[1:]:
  527. if parameter.default == inspect.Parameter.empty:
  528. return False
  529. return True
  530. def _is_required_or_notrequired(type_hint: Any) -> bool:
  531. """Helper to check if a type is Required/NotRequired."""
  532. return type_hint in (Required, NotRequired) or (get_origin(type_hint) in (Required, NotRequired))
  533. _BASIC_TYPE_VALIDATORS: dict[Any, Callable[[str, Any, tuple[Any, ...]], None]] = {
  534. Union: _validate_union,
  535. Literal: _validate_literal,
  536. list: _validate_list,
  537. dict: _validate_dict,
  538. tuple: _validate_tuple,
  539. set: _validate_set,
  540. collections.abc.Sequence: _validate_sequence,
  541. }
  542. # TODO: make it first class citizen when bumping to Python 3.10+
  543. _BASIC_TYPE_VALIDATORS[types.UnionType] = _validate_union # x | y syntax, available only Python 3.10+
  544. __all__ = [
  545. "strict",
  546. "validate_typed_dict",
  547. "validated_field",
  548. "Validator_T",
  549. "StrictDataclassClassValidationError",
  550. "StrictDataclassDefinitionError",
  551. "StrictDataclassFieldValidationError",
  552. ]