| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684 |
- import collections.abc
- import inspect
- import types
- from collections.abc import Callable
- from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields, make_dataclass
- from functools import lru_cache, wraps
- from typing import (
- Annotated,
- Any,
- ForwardRef,
- Literal,
- Type,
- TypeVar,
- Union,
- get_args,
- get_origin,
- overload,
- )
- try:
- # Python 3.11+
- from typing import NotRequired, Required # type: ignore
- except ImportError:
- try:
- # In case typing_extensions is installed
- from typing_extensions import NotRequired, Required # type: ignore
- except ImportError:
- # Fallback: create dummy types that will never match
- Required = type("Required", (), {}) # type: ignore
- NotRequired = type("NotRequired", (), {}) # type: ignore
- from .errors import (
- StrictDataclassClassValidationError,
- StrictDataclassDefinitionError,
- StrictDataclassFieldValidationError,
- )
- Validator_T = Callable[[Any], None]
- T = TypeVar("T")
- TypedDictType = TypeVar("TypedDictType", bound=dict[str, Any])
- _TYPED_DICT_DEFAULT_VALUE = object() # used as default value in TypedDict fields (to distinguish from None)
- # The overload decorator helps type checkers understand the different return types
- @overload
- def strict(cls: Type[T]) -> Type[T]: ...
- @overload
- def strict(*, accept_kwargs: bool = False) -> Callable[[Type[T]], Type[T]]: ...
- def strict(cls: Type[T] | None = None, *, accept_kwargs: bool = False) -> Type[T] | Callable[[Type[T]], Type[T]]:
- """
- Decorator to add strict validation to a dataclass.
- This decorator must be used on top of `@dataclass` to ensure IDEs and static typing tools
- recognize the class as a dataclass.
- Can be used with or without arguments:
- - `@strict`
- - `@strict(accept_kwargs=True)`
- Args:
- cls:
- The class to convert to a strict dataclass.
- accept_kwargs (`bool`, *optional*):
- If True, allows arbitrary keyword arguments in `__init__`. Defaults to False.
- Returns:
- The enhanced dataclass with strict validation on field assignment.
- Example:
- ```py
- >>> from dataclasses import dataclass
- >>> from huggingface_hub.dataclasses import as_validated_field, strict, validated_field
- >>> @as_validated_field
- >>> def positive_int(value: int):
- ... if not value >= 0:
- ... raise ValueError(f"Value must be positive, got {value}")
- >>> @strict(accept_kwargs=True)
- ... @dataclass
- ... class User:
- ... name: str
- ... age: int = positive_int(default=10)
- # Initialize
- >>> User(name="John")
- User(name='John', age=10)
- # Extra kwargs are accepted
- >>> User(name="John", age=30, lastname="Doe")
- User(name='John', age=30, *lastname='Doe')
- # Invalid type => raises
- >>> User(name="John", age="30")
- huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
- TypeError: Field 'age' expected int, got str (value: '30')
- # Invalid value => raises
- >>> User(name="John", age=-1)
- huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
- ValueError: Value must be positive, got -1
- ```
- """
- def wrap(cls: Type[T]) -> Type[T]:
- if not hasattr(cls, "__dataclass_fields__"):
- raise StrictDataclassDefinitionError(
- f"Class '{cls.__name__}' must be a dataclass before applying @strict."
- )
- # List and store validators
- field_validators: dict[str, list[Validator_T]] = {}
- for f in fields(cls): # type: ignore [arg-type]
- validators = []
- validators.append(_create_type_validator(f))
- custom_validator = f.metadata.get("validator")
- if custom_validator is not None:
- if not isinstance(custom_validator, list):
- custom_validator = [custom_validator]
- for validator in custom_validator:
- if not _is_validator(validator):
- raise StrictDataclassDefinitionError(
- f"Invalid validator for field '{f.name}': {validator}. Must be a callable taking a single argument."
- )
- validators.extend(custom_validator)
- field_validators[f.name] = validators
- cls.__validators__ = field_validators # type: ignore
- # Override __setattr__ to validate fields on assignment
- original_setattr = cls.__setattr__
- def __strict_setattr__(self: Any, name: str, value: Any) -> None:
- """Custom __setattr__ method for strict dataclasses."""
- # Run all validators
- for validator in self.__validators__.get(name, []):
- try:
- validator(value)
- except (ValueError, TypeError) as e:
- raise StrictDataclassFieldValidationError(field=name, cause=e) from e
- # If validation passed, set the attribute
- original_setattr(self, name, value)
- cls.__setattr__ = __strict_setattr__ # type: ignore
- if accept_kwargs:
- # (optional) Override __init__ to accept arbitrary keyword arguments
- original_init = cls.__init__
- @wraps(original_init)
- def __init__(self, *args, **kwargs: Any) -> None:
- # Extract only the fields that are part of the dataclass
- dataclass_fields = {f.name for f in fields(cls)} # type: ignore [arg-type]
- standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields}
- # User shouldn't define custom `__init__` when `accepts_kwargs`, and instead
- # are advised to move field manipulation to `__post_init__` (e.g., derive new field from existing ones)
- # We need to call bare `__init__` here without `__post_init__` but the``original_init`` would call
- # post-init right away with no kwargs.
- if len(args) > 0:
- raise ValueError(
- f"When `accept_kwargs=True`, {cls.__name__} accepts only keyword arguments, "
- f"but found `{len(args)}` positional args."
- )
- for f in fields(cls): # type: ignore
- if f.name in standard_kwargs:
- setattr(self, f.name, standard_kwargs[f.name])
- elif f.default is not MISSING:
- setattr(self, f.name, f.default)
- elif f.default_factory is not MISSING:
- setattr(self, f.name, f.default_factory())
- else:
- raise TypeError(f"Missing required field - '{f.name}'")
- # Pass any additional kwargs to `__post_init__` and let the object
- # decide whether to set the attr or use for different purposes (e.g. BC checks)
- additional_kwargs = {}
- for name, value in kwargs.items():
- if name not in dataclass_fields:
- additional_kwargs[name] = value
- self.__post_init__(**additional_kwargs)
- cls.__init__ = __init__ # type: ignore
- # Define a default __post_init__ if not defined
- if not hasattr(cls, "__post_init__"):
- def __post_init__(self, **kwargs: Any) -> None:
- """Default __post_init__ to accept additional kwargs."""
- for name, value in kwargs.items():
- setattr(self, name, value)
- cls.__post_init__ = __post_init__ # type: ignore
- # (optional) Override __repr__ to include additional kwargs
- original_repr = cls.__repr__
- @wraps(original_repr)
- def __repr__(self) -> str:
- # Call the original __repr__ to get the standard fields
- standard_repr = original_repr(self)
- # Get additional kwargs
- additional_kwargs = [
- # add a '*' in front of additional kwargs to let the user know they are not part of the dataclass
- f"*{k}={v!r}"
- for k, v in self.__dict__.items()
- if k not in cls.__dataclass_fields__ # type: ignore [attr-defined]
- ]
- additional_repr = ", ".join(additional_kwargs)
- # Combine both representations
- return f"{standard_repr[:-1]}, {additional_repr})" if additional_kwargs else standard_repr
- if cls.__dataclass_params__.repr is True: # type: ignore [attr-defined]
- cls.__repr__ = __repr__ # type: ignore
- # List all public methods starting with `validate_` => class validators.
- class_validators = []
- for name in dir(cls):
- if not name.startswith("validate_"):
- continue
- method = getattr(cls, name)
- if not callable(method):
- continue
- if len(inspect.signature(method).parameters) != 1:
- raise StrictDataclassDefinitionError(
- f"Class '{cls.__name__}' has a class validator '{name}' that takes more than one argument."
- " Class validators must take only 'self' as an argument. Methods starting with 'validate_'"
- " are considered to be class validators."
- )
- class_validators.append(method)
- cls.__class_validators__ = class_validators # type: ignore
- # Add `validate` method to the class, but first check if it already exists
- def validate(self: T) -> None:
- """Run class validators on the instance."""
- for validator in cls.__class_validators__: # type: ignore [attr-defined]
- try:
- validator(self)
- except (ValueError, TypeError) as e:
- raise StrictDataclassClassValidationError(validator=validator.__name__, cause=e) from e
- # Hack to be able to raise if `.validate()` already exists except if it was created by this decorator on a parent class
- # (in which case we just override it)
- validate.__is_defined_by_strict_decorator__ = True # type: ignore [attr-defined]
- if hasattr(cls, "validate"):
- if not getattr(cls.validate, "__is_defined_by_strict_decorator__", False): # type: ignore [attr-defined]
- raise StrictDataclassDefinitionError(
- f"Class '{cls.__name__}' already implements a method called 'validate'."
- " This method name is reserved when using the @strict decorator on a dataclass."
- " If you want to keep your own method, please rename it."
- )
- cls.validate = validate # type: ignore
- # Run class validators after initialization
- initial_init = cls.__init__
- @wraps(initial_init)
- def init_with_validate(self, *args, **kwargs) -> None:
- """Run class validators after initialization."""
- initial_init(self, *args, **kwargs) # type: ignore [call-arg]
- cls.validate(self) # type: ignore [attr-defined]
- setattr(cls, "__init__", init_with_validate)
- return cls
- # Return wrapped class or the decorator itself
- return wrap(cls) if cls is not None else wrap
- def validate_typed_dict(schema: type[TypedDictType], data: dict) -> None:
- """
- Validate that a dictionary conforms to the types defined in a TypedDict class.
- Under the hood, the typed dict is converted to a strict dataclass and validated using the `@strict` decorator.
- Args:
- schema (`type[TypedDictType]`):
- The TypedDict class defining the expected structure and types.
- data (`dict`):
- The dictionary to validate.
- Raises:
- `StrictDataclassFieldValidationError`:
- If any field in the dictionary does not conform to the expected type.
- Example:
- ```py
- >>> from typing import Annotated, TypedDict
- >>> from huggingface_hub.dataclasses import validate_typed_dict
- >>> def positive_int(value: int):
- ... if not value >= 0:
- ... raise ValueError(f"Value must be positive, got {value}")
- >>> class User(TypedDict):
- ... name: str
- ... age: Annotated[int, positive_int]
- >>> # Valid data
- >>> validate_typed_dict(User, {"name": "John", "age": 30})
- >>> # Invalid type for age
- >>> validate_typed_dict(User, {"name": "John", "age": "30"})
- huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
- TypeError: Field 'age' expected int, got str (value: '30')
- >>> # Invalid value for age
- >>> validate_typed_dict(User, {"name": "John", "age": -1})
- huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
- ValueError: Value must be positive, got -1
- ```
- """
- # Convert typed dict to dataclass
- strict_cls = _build_strict_cls_from_typed_dict(schema)
- # Validate the data by instantiating the strict dataclass
- strict_cls(**data) # will raise if validation fails
- @lru_cache
- def _build_strict_cls_from_typed_dict(schema: type[TypedDictType]) -> Type:
- # Extract type hints from the TypedDict class
- type_hints = _get_typed_dict_annotations(schema)
- # If the TypedDict is not total, wrap fields as NotRequired (unless explicitly Required or NotRequired)
- if not getattr(schema, "__total__", True):
- for key, value in type_hints.items():
- origin = get_origin(value)
- if origin is Annotated:
- base, *meta = get_args(value)
- if not _is_required_or_notrequired(base):
- base = NotRequired[base]
- type_hints[key] = Annotated[tuple([base] + list(meta))] # type: ignore
- elif not _is_required_or_notrequired(value):
- type_hints[key] = NotRequired[value]
- # Convert type hints to dataclass fields
- fields = []
- for key, value in type_hints.items():
- if get_origin(value) is Annotated:
- base, *meta = get_args(value)
- fields.append((key, base, field(default=_TYPED_DICT_DEFAULT_VALUE, metadata={"validator": meta[0]})))
- else:
- fields.append((key, value, field(default=_TYPED_DICT_DEFAULT_VALUE)))
- # Create a strict dataclass from the TypedDict fields
- return strict(make_dataclass(schema.__name__, fields))
- def _get_typed_dict_annotations(schema: type[TypedDictType]) -> dict[str, Any]:
- """Extract type annotations from a TypedDict class."""
- try:
- # Available in Python 3.14+
- import annotationlib
- return annotationlib.get_annotations(schema)
- except ImportError:
- return {
- # We do not use `get_type_hints` here to avoid evaluating ForwardRefs (which might fail).
- # ForwardRefs are not validated by @strict anyway.
- name: value if value is not None else type(None)
- for name, value in schema.__dict__.get("__annotations__", {}).items()
- }
- def validated_field(
- validator: list[Validator_T] | Validator_T,
- default: Any | _MISSING_TYPE = MISSING,
- default_factory: Callable[[], Any] | _MISSING_TYPE = MISSING,
- init: bool = True,
- repr: bool = True,
- hash: bool | None = None,
- compare: bool = True,
- metadata: dict | None = None,
- **kwargs: Any,
- ) -> Any:
- """
- Create a dataclass field with a custom validator.
- Useful to apply several checks to a field. If only applying one rule, check out the [`as_validated_field`] decorator.
- Args:
- validator (`Callable` or `list[Callable]`):
- A method that takes a value as input and raises ValueError/TypeError if the value is invalid.
- Can be a list of validators to apply multiple checks.
- **kwargs:
- Additional arguments to pass to `dataclasses.field()`.
- Returns:
- A field with the validator attached in metadata
- """
- if not isinstance(validator, list):
- validator = [validator]
- if metadata is None:
- metadata = {}
- metadata["validator"] = validator
- return field( # type: ignore
- default=default, # type: ignore [arg-type]
- default_factory=default_factory, # type: ignore [arg-type]
- init=init,
- repr=repr,
- hash=hash,
- compare=compare,
- metadata=metadata,
- **kwargs,
- )
- def as_validated_field(validator: Validator_T):
- """
- Decorates a validator function as a [`validated_field`] (i.e. a dataclass field with a custom validator).
- Args:
- validator (`Callable`):
- A method that takes a value as input and raises ValueError/TypeError if the value is invalid.
- """
- def _inner(
- default: Any | _MISSING_TYPE = MISSING,
- default_factory: Callable[[], Any] | _MISSING_TYPE = MISSING,
- init: bool = True,
- repr: bool = True,
- hash: bool | None = None,
- compare: bool = True,
- metadata: dict | None = None,
- **kwargs: Any,
- ):
- return validated_field(
- validator,
- default=default,
- default_factory=default_factory,
- init=init,
- repr=repr,
- hash=hash,
- compare=compare,
- metadata=metadata,
- **kwargs,
- )
- return _inner
- def type_validator(name: str, value: Any, expected_type: Any) -> None:
- """Validate that 'value' matches 'expected_type'."""
- origin = get_origin(expected_type)
- args = get_args(expected_type)
- if expected_type is Any:
- return
- elif expected_type is None:
- _validate_none(name, value)
- elif validator := _BASIC_TYPE_VALIDATORS.get(origin):
- validator(name, value, args)
- elif isinstance(expected_type, type): # simple types
- _validate_simple_type(name, value, expected_type)
- elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str):
- return
- elif origin is Required:
- if value is _TYPED_DICT_DEFAULT_VALUE:
- raise TypeError(f"Field '{name}' is required but missing.")
- type_validator(name, value, args[0])
- elif origin is NotRequired:
- if value is _TYPED_DICT_DEFAULT_VALUE:
- return
- type_validator(name, value, args[0])
- else:
- raise TypeError(f"Unsupported type for field '{name}': {expected_type}")
- def _validate_none(name: str, value: Any) -> None:
- """Validate None type.
- 'None' is not a type, it's a special value. Type should be `NoneType` instead.
- But in type annotations 'None' is accepted so we must support it.
- """
- if value is not None:
- raise TypeError(f"Field '{name}' expected None, got {type(value).__name__}")
- def _validate_union(name: str, value: Any, args: tuple[Any, ...]) -> None:
- """Validate that value matches one of the types in a Union."""
- errors = []
- for t in args:
- try:
- type_validator(name, value, t)
- return # Valid if any type matches
- except TypeError as e:
- errors.append(str(e))
- raise TypeError(
- f"Field '{name}' with value {repr(value)} doesn't match any type in {args}. Errors: {'; '.join(errors)}"
- )
- def _validate_literal(name: str, value: Any, args: tuple[Any, ...]) -> None:
- """Validate Literal type."""
- if isinstance(value, bool):
- if value not in [arg for arg in args if isinstance(arg, bool)]:
- raise TypeError(f"Field '{name}' expected one of {args}, got {value}")
- elif isinstance(value, int):
- if value not in [arg for arg in args if isinstance(arg, int) and not isinstance(arg, bool)]:
- raise TypeError(f"Field '{name}' expected one of {args}, got {value}")
- elif value not in args:
- raise TypeError(f"Field '{name}' expected one of {args}, got {value}")
- def _validate_list(name: str, value: Any, args: tuple[Any, ...]) -> None:
- """Validate list[T] type."""
- if not isinstance(value, list):
- raise TypeError(f"Field '{name}' expected a list, got {type(value).__name__}")
- # Validate each item in the list
- item_type = args[0]
- for i, item in enumerate(value):
- try:
- type_validator(f"{name}[{i}]", item, item_type)
- except TypeError as e:
- raise TypeError(f"Invalid item at index {i} in list '{name}'") from e
- def _validate_dict(name: str, value: Any, args: tuple[Any, ...]) -> None:
- """Validate dict[K, V] type."""
- if not isinstance(value, dict):
- raise TypeError(f"Field '{name}' expected a dict, got {type(value).__name__}")
- # Validate keys and values
- key_type, value_type = args
- for k, v in value.items():
- try:
- type_validator(f"{name}.key", k, key_type)
- type_validator(f"{name}[{k!r}]", v, value_type)
- except TypeError as e:
- raise TypeError(f"Invalid key or value in dict '{name}'") from e
- def _validate_tuple(name: str, value: Any, args: tuple[Any, ...]) -> None:
- """Validate Tuple type."""
- if not isinstance(value, tuple):
- raise TypeError(f"Field '{name}' expected a tuple, got {type(value).__name__}")
- # Handle variable-length tuples: tuple[T, ...]
- if len(args) == 2 and args[1] is Ellipsis:
- for i, item in enumerate(value):
- try:
- type_validator(f"{name}[{i}]", item, args[0])
- except TypeError as e:
- raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e
- # Handle fixed-length tuples: tuple[T1, T2, ...]
- elif len(args) != len(value):
- raise TypeError(f"Field '{name}' expected a tuple of length {len(args)}, got {len(value)}")
- else:
- for i, (item, expected) in enumerate(zip(value, args)):
- try:
- type_validator(f"{name}[{i}]", item, expected)
- except TypeError as e:
- raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e
- def _validate_set(name: str, value: Any, args: tuple[Any, ...]) -> None:
- """Validate set[T] type."""
- if not isinstance(value, set):
- raise TypeError(f"Field '{name}' expected a set, got {type(value).__name__}")
- # Validate each item in the set
- item_type = args[0]
- for i, item in enumerate(value):
- try:
- type_validator(f"{name} item", item, item_type)
- except TypeError as e:
- raise TypeError(f"Invalid item in set '{name}'") from e
- def _validate_sequence(name: str, value: Any, args: tuple[Any, ...]) -> None:
- """Validate Sequence or Sequence[T] type."""
- if not isinstance(value, collections.abc.Sequence):
- raise TypeError(f"Field '{name}' expected a Sequence, got {type(value).__name__}")
- # If no type argument is provided (i.e., just `Sequence`), skip item validation
- if not args:
- return
- # Validate each item in the sequence
- item_type = args[0]
- for i, item in enumerate(value):
- try:
- type_validator(f"{name}[{i}]", item, item_type)
- except TypeError as e:
- raise TypeError(f"Invalid item at index {i} in sequence '{name}'") from e
- def _validate_simple_type(name: str, value: Any, expected_type: type) -> None:
- """Validate simple type (int, str, etc.)."""
- if expected_type is int and isinstance(value, bool):
- raise TypeError(
- f"Field '{name}' expected {expected_type.__name__}, got {type(value).__name__} (value: {repr(value)})"
- )
- if not isinstance(value, expected_type):
- raise TypeError(
- f"Field '{name}' expected {expected_type.__name__}, got {type(value).__name__} (value: {repr(value)})"
- )
- def _create_type_validator(field: Field) -> Validator_T:
- """Create a type validator function for a field."""
- # Hacky: we cannot use a lambda here because of reference issues
- def validator(value: Any) -> None:
- type_validator(field.name, value, field.type)
- return validator
- def _is_validator(validator: Any) -> bool:
- """Check if a function is a validator.
- A validator is a Callable that can be called with a single positional argument.
- The validator can have more arguments with default values.
- Basically, returns True if `validator(value)` is possible.
- """
- if not callable(validator):
- return False
- signature = inspect.signature(validator)
- parameters = list(signature.parameters.values())
- if len(parameters) == 0:
- return False
- if parameters[0].kind not in (
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
- inspect.Parameter.POSITIONAL_ONLY,
- inspect.Parameter.VAR_POSITIONAL,
- ):
- return False
- for parameter in parameters[1:]:
- if parameter.default == inspect.Parameter.empty:
- return False
- return True
- def _is_required_or_notrequired(type_hint: Any) -> bool:
- """Helper to check if a type is Required/NotRequired."""
- return type_hint in (Required, NotRequired) or (get_origin(type_hint) in (Required, NotRequired))
- _BASIC_TYPE_VALIDATORS: dict[Any, Callable[[str, Any, tuple[Any, ...]], None]] = {
- Union: _validate_union,
- Literal: _validate_literal,
- list: _validate_list,
- dict: _validate_dict,
- tuple: _validate_tuple,
- set: _validate_set,
- collections.abc.Sequence: _validate_sequence,
- }
- # TODO: make it first class citizen when bumping to Python 3.10+
- _BASIC_TYPE_VALIDATORS[types.UnionType] = _validate_union # x | y syntax, available only Python 3.10+
- __all__ = [
- "strict",
- "validate_typed_dict",
- "validated_field",
- "Validator_T",
- "StrictDataclassClassValidationError",
- "StrictDataclassDefinitionError",
- "StrictDataclassFieldValidationError",
- ]
|