_validate_call.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. from __future__ import annotations as _annotations
  2. import functools
  3. import inspect
  4. from collections.abc import Awaitable
  5. from functools import partial
  6. from typing import Any, Callable
  7. import pydantic_core
  8. from ..config import ConfigDict
  9. from ..plugin._schema_validator import create_schema_validator
  10. from ._config import ConfigWrapper
  11. from ._generate_schema import GenerateSchema, ValidateCallSupportedTypes
  12. from ._namespace_utils import MappingNamespace, NsResolver, ns_for_function
  13. def extract_function_name(func: ValidateCallSupportedTypes) -> str:
  14. """Extract the name of a `ValidateCallSupportedTypes` object."""
  15. return f'partial({func.func.__name__})' if isinstance(func, functools.partial) else func.__name__
  16. def extract_function_qualname(func: ValidateCallSupportedTypes) -> str:
  17. """Extract the qualname of a `ValidateCallSupportedTypes` object."""
  18. return f'partial({func.func.__qualname__})' if isinstance(func, functools.partial) else func.__qualname__
  19. def update_wrapper_attributes(wrapped: ValidateCallSupportedTypes, wrapper: Callable[..., Any]):
  20. """Update the `wrapper` function with the attributes of the `wrapped` function. Return the updated function."""
  21. if inspect.iscoroutinefunction(wrapped):
  22. @functools.wraps(wrapped)
  23. async def wrapper_function(*args, **kwargs): # type: ignore
  24. return await wrapper(*args, **kwargs)
  25. else:
  26. @functools.wraps(wrapped)
  27. def wrapper_function(*args, **kwargs):
  28. return wrapper(*args, **kwargs)
  29. # We need to manually update this because `partial` object has no `__name__` and `__qualname__`.
  30. wrapper_function.__name__ = extract_function_name(wrapped)
  31. wrapper_function.__qualname__ = extract_function_qualname(wrapped)
  32. wrapper_function.raw_function = wrapped # type: ignore
  33. return wrapper_function
  34. class ValidateCallWrapper:
  35. """This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""
  36. __slots__ = (
  37. 'function',
  38. 'validate_return',
  39. 'schema_type',
  40. 'module',
  41. 'qualname',
  42. 'ns_resolver',
  43. 'config_wrapper',
  44. '__pydantic_complete__',
  45. '__pydantic_validator__',
  46. '__return_pydantic_validator__',
  47. )
  48. def __init__(
  49. self,
  50. function: ValidateCallSupportedTypes,
  51. config: ConfigDict | None,
  52. validate_return: bool,
  53. parent_namespace: MappingNamespace | None,
  54. ) -> None:
  55. self.function = function
  56. self.validate_return = validate_return
  57. if isinstance(function, partial):
  58. self.schema_type = function.func
  59. self.module = function.func.__module__
  60. else:
  61. self.schema_type = function
  62. self.module = function.__module__
  63. self.qualname = extract_function_qualname(function)
  64. self.ns_resolver = NsResolver(
  65. namespaces_tuple=ns_for_function(self.schema_type, parent_namespace=parent_namespace)
  66. )
  67. self.config_wrapper = ConfigWrapper(config)
  68. if not self.config_wrapper.defer_build:
  69. self._create_validators()
  70. else:
  71. self.__pydantic_complete__ = False
  72. def _create_validators(self) -> None:
  73. gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver)
  74. schema = gen_schema.clean_schema(gen_schema.generate_schema(self.function))
  75. core_config = self.config_wrapper.core_config(title=self.qualname)
  76. self.__pydantic_validator__ = create_schema_validator(
  77. schema,
  78. self.schema_type,
  79. self.module,
  80. self.qualname,
  81. 'validate_call',
  82. core_config,
  83. self.config_wrapper.plugin_settings,
  84. )
  85. if self.validate_return:
  86. signature = inspect.signature(self.function)
  87. return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
  88. gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver)
  89. schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
  90. validator = create_schema_validator(
  91. schema,
  92. self.schema_type,
  93. self.module,
  94. self.qualname,
  95. 'validate_call',
  96. core_config,
  97. self.config_wrapper.plugin_settings,
  98. )
  99. if inspect.iscoroutinefunction(self.function):
  100. async def return_val_wrapper(aw: Awaitable[Any]) -> None:
  101. return validator.validate_python(await aw)
  102. self.__return_pydantic_validator__ = return_val_wrapper
  103. else:
  104. self.__return_pydantic_validator__ = validator.validate_python
  105. else:
  106. self.__return_pydantic_validator__ = None
  107. self.__pydantic_complete__ = True
  108. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  109. if not self.__pydantic_complete__:
  110. self._create_validators()
  111. res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
  112. if self.__return_pydantic_validator__:
  113. return self.__return_pydantic_validator__(res)
  114. else:
  115. return res