| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- """Module containing validation mechanisms for transform parameters.
- This module provides a metaclass that enables parameter validation for transforms using
- Pydantic models. It intercepts the initialization of transform classes to validate their
- parameters against schema definitions, raising appropriate errors for invalid values and
- providing type conversion capabilities. This validation layer helps prevent runtime errors
- by catching configuration issues at initialization time.
- """
- from __future__ import annotations
- from inspect import Parameter, signature
- from typing import Any, Callable
- from warnings import warn
- from pydantic import BaseModel, ValidationError
- class ValidatedTransformMeta(type):
- """Metaclass that validates transform parameters during instantiation.
- This metaclass enables automatic validation of transform parameters using Pydantic models,
- ensuring proper typing and constraints are enforced before object creation.
- Args:
- original_init (Callable[..., Any]): Original __init__ method of the class.
- args (tuple[Any, ...]): Positional arguments passed to the __init__ method.
- kwargs (dict[str, Any]): Keyword arguments passed to the __init__ method.
- """
- @staticmethod
- def _process_init_parameters(
- original_init: Callable[..., Any],
- args: tuple[Any, ...],
- kwargs: dict[str, Any],
- ) -> tuple[dict[str, Any], list[str], bool]:
- init_params = signature(original_init).parameters
- param_names = list(init_params.keys())[1:] # Exclude 'self'
- full_kwargs: dict[str, Any] = dict(zip(param_names, args)) | kwargs
- # Get strict value before validation
- strict = full_kwargs.pop("strict", False)
- # Add default values if not provided
- for parameter_name, parameter in init_params.items():
- if (
- parameter_name != "self"
- and parameter_name not in full_kwargs
- and parameter.default is not Parameter.empty
- ):
- full_kwargs[parameter_name] = parameter.default
- return full_kwargs, param_names, strict
- @staticmethod
- def _validate_parameters(
- schema_cls: type[BaseModel],
- full_kwargs: dict[str, Any],
- param_names: list[str],
- strict: bool,
- ) -> dict[str, Any]:
- try:
- # Include strict parameter for schema validation
- schema_kwargs = {k: v for k, v in full_kwargs.items() if k in param_names}
- schema_kwargs["strict"] = strict
- config = schema_cls(**schema_kwargs)
- validated_kwargs = config.model_dump()
- validated_kwargs.pop("strict", None)
- except ValidationError as e:
- raise ValueError(str(e)) from e
- except Exception as e:
- if strict:
- raise ValueError(str(e)) from e
- warn(str(e), stacklevel=2)
- return {}
- else:
- return validated_kwargs
- @staticmethod
- def _get_default_values(init_params: dict[str, Parameter]) -> dict[str, Any]:
- validated_kwargs = {}
- for param_name, param in init_params.items():
- if param_name in {"self", "strict"}:
- continue
- if param.default is not Parameter.empty:
- validated_kwargs[param_name] = param.default
- return validated_kwargs
- def __new__(cls: type[Any], name: str, bases: tuple[type, ...], dct: dict[str, Any]) -> type[Any]:
- """This is a custom metaclass that validates the parameters of the class during instantiation.
- It is used to ensure that the parameters of the class are valid and that they are of the correct type.
- """
- if "InitSchema" in dct and issubclass(dct["InitSchema"], BaseModel):
- original_init: Callable[..., Any] | None = dct.get("__init__")
- if original_init is None:
- msg = "__init__ not found in class definition"
- raise ValueError(msg)
- original_sig = signature(original_init)
- def custom_init(self: Any, *args: Any, **kwargs: Any) -> None:
- full_kwargs, param_names, strict = cls._process_init_parameters(original_init, args, kwargs)
- validated_kwargs = cls._validate_parameters(
- dct["InitSchema"],
- full_kwargs,
- param_names,
- strict,
- ) or cls._get_default_values(signature(original_init).parameters)
- # Store and check invalid args
- invalid_args = [name_arg for name_arg in kwargs if name_arg not in param_names and name_arg != "strict"]
- original_init(self, **validated_kwargs)
- self.invalid_args = invalid_args
- if invalid_args:
- message = f"Argument(s) '{', '.join(invalid_args)}' are not valid for transform {name}"
- if strict:
- raise ValueError(message)
- warn(message, stacklevel=2)
- # Preserve the original signature and docstring
- custom_init.__signature__ = original_sig # type: ignore[attr-defined]
- custom_init.__doc__ = original_init.__doc__
- dct["__init__"] = custom_init
- return super().__new__(cls, name, bases, dct)
|