validation.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
"""Validation mechanisms for transform parameters.

This module provides a metaclass that validates transform parameters using
Pydantic models. It intercepts the initialization of transform classes to
validate their parameters against schema definitions, raising errors for
invalid values and applying the type conversions performed by the schema.
This validation layer helps prevent runtime errors by catching configuration
issues at initialization time.
"""
from __future__ import annotations

from collections.abc import Mapping
from functools import wraps
from inspect import Parameter, signature
from typing import Any, Callable
from warnings import warn

from pydantic import BaseModel, ValidationError
  13. class ValidatedTransformMeta(type):
  14. """Metaclass that validates transform parameters during instantiation.
  15. This metaclass enables automatic validation of transform parameters using Pydantic models,
  16. ensuring proper typing and constraints are enforced before object creation.
  17. Args:
  18. original_init (Callable[..., Any]): Original __init__ method of the class.
  19. args (tuple[Any, ...]): Positional arguments passed to the __init__ method.
  20. kwargs (dict[str, Any]): Keyword arguments passed to the __init__ method.
  21. """
  22. @staticmethod
  23. def _process_init_parameters(
  24. original_init: Callable[..., Any],
  25. args: tuple[Any, ...],
  26. kwargs: dict[str, Any],
  27. ) -> tuple[dict[str, Any], list[str], bool]:
  28. init_params = signature(original_init).parameters
  29. param_names = list(init_params.keys())[1:] # Exclude 'self'
  30. full_kwargs: dict[str, Any] = dict(zip(param_names, args)) | kwargs
  31. # Get strict value before validation
  32. strict = full_kwargs.pop("strict", False)
  33. # Add default values if not provided
  34. for parameter_name, parameter in init_params.items():
  35. if (
  36. parameter_name != "self"
  37. and parameter_name not in full_kwargs
  38. and parameter.default is not Parameter.empty
  39. ):
  40. full_kwargs[parameter_name] = parameter.default
  41. return full_kwargs, param_names, strict
  42. @staticmethod
  43. def _validate_parameters(
  44. schema_cls: type[BaseModel],
  45. full_kwargs: dict[str, Any],
  46. param_names: list[str],
  47. strict: bool,
  48. ) -> dict[str, Any]:
  49. try:
  50. # Include strict parameter for schema validation
  51. schema_kwargs = {k: v for k, v in full_kwargs.items() if k in param_names}
  52. schema_kwargs["strict"] = strict
  53. config = schema_cls(**schema_kwargs)
  54. validated_kwargs = config.model_dump()
  55. validated_kwargs.pop("strict", None)
  56. except ValidationError as e:
  57. raise ValueError(str(e)) from e
  58. except Exception as e:
  59. if strict:
  60. raise ValueError(str(e)) from e
  61. warn(str(e), stacklevel=2)
  62. return {}
  63. else:
  64. return validated_kwargs
  65. @staticmethod
  66. def _get_default_values(init_params: dict[str, Parameter]) -> dict[str, Any]:
  67. validated_kwargs = {}
  68. for param_name, param in init_params.items():
  69. if param_name in {"self", "strict"}:
  70. continue
  71. if param.default is not Parameter.empty:
  72. validated_kwargs[param_name] = param.default
  73. return validated_kwargs
  74. def __new__(cls: type[Any], name: str, bases: tuple[type, ...], dct: dict[str, Any]) -> type[Any]:
  75. """This is a custom metaclass that validates the parameters of the class during instantiation.
  76. It is used to ensure that the parameters of the class are valid and that they are of the correct type.
  77. """
  78. if "InitSchema" in dct and issubclass(dct["InitSchema"], BaseModel):
  79. original_init: Callable[..., Any] | None = dct.get("__init__")
  80. if original_init is None:
  81. msg = "__init__ not found in class definition"
  82. raise ValueError(msg)
  83. original_sig = signature(original_init)
  84. def custom_init(self: Any, *args: Any, **kwargs: Any) -> None:
  85. full_kwargs, param_names, strict = cls._process_init_parameters(original_init, args, kwargs)
  86. validated_kwargs = cls._validate_parameters(
  87. dct["InitSchema"],
  88. full_kwargs,
  89. param_names,
  90. strict,
  91. ) or cls._get_default_values(signature(original_init).parameters)
  92. # Store and check invalid args
  93. invalid_args = [name_arg for name_arg in kwargs if name_arg not in param_names and name_arg != "strict"]
  94. original_init(self, **validated_kwargs)
  95. self.invalid_args = invalid_args
  96. if invalid_args:
  97. message = f"Argument(s) '{', '.join(invalid_args)}' are not valid for transform {name}"
  98. if strict:
  99. raise ValueError(message)
  100. warn(message, stacklevel=2)
  101. # Preserve the original signature and docstring
  102. custom_init.__signature__ = original_sig # type: ignore[attr-defined]
  103. custom_init.__doc__ = original_init.__doc__
  104. dct["__init__"] = custom_init
  105. return super().__new__(cls, name, bases, dct)