_schema_validator.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. """Pluggable schema validator for pydantic."""
  2. from __future__ import annotations
  3. import functools
  4. from collections.abc import Iterable
  5. from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
  6. from pydantic_core import CoreConfig, CoreSchema, SchemaValidator, ValidationError
  7. from typing_extensions import ParamSpec
  8. if TYPE_CHECKING:
  9. from . import BaseValidateHandlerProtocol, PydanticPluginProtocol, SchemaKind, SchemaTypePath
  10. P = ParamSpec('P')
  11. R = TypeVar('R')
  12. Event = Literal['on_validate_python', 'on_validate_json', 'on_validate_strings']
  13. events: list[Event] = list(Event.__args__) # type: ignore
  14. def create_schema_validator(
  15. schema: CoreSchema,
  16. schema_type: Any,
  17. schema_type_module: str,
  18. schema_type_name: str,
  19. schema_kind: SchemaKind,
  20. config: CoreConfig | None = None,
  21. plugin_settings: dict[str, Any] | None = None,
  22. ) -> SchemaValidator | PluggableSchemaValidator:
  23. """Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed.
  24. Returns:
  25. If plugins are installed then return `PluggableSchemaValidator`, otherwise return `SchemaValidator`.
  26. """
  27. from . import SchemaTypePath
  28. from ._loader import get_plugins
  29. plugins = get_plugins()
  30. if plugins:
  31. return PluggableSchemaValidator(
  32. schema,
  33. schema_type,
  34. SchemaTypePath(schema_type_module, schema_type_name),
  35. schema_kind,
  36. config,
  37. plugins,
  38. plugin_settings or {},
  39. )
  40. else:
  41. return SchemaValidator(schema, config)
  42. class PluggableSchemaValidator:
  43. """Pluggable schema validator."""
  44. __slots__ = '_schema_validator', 'validate_json', 'validate_python', 'validate_strings'
  45. def __init__(
  46. self,
  47. schema: CoreSchema,
  48. schema_type: Any,
  49. schema_type_path: SchemaTypePath,
  50. schema_kind: SchemaKind,
  51. config: CoreConfig | None,
  52. plugins: Iterable[PydanticPluginProtocol],
  53. plugin_settings: dict[str, Any],
  54. ) -> None:
  55. self._schema_validator = SchemaValidator(schema, config)
  56. python_event_handlers: list[BaseValidateHandlerProtocol] = []
  57. json_event_handlers: list[BaseValidateHandlerProtocol] = []
  58. strings_event_handlers: list[BaseValidateHandlerProtocol] = []
  59. for plugin in plugins:
  60. try:
  61. p, j, s = plugin.new_schema_validator(
  62. schema, schema_type, schema_type_path, schema_kind, config, plugin_settings
  63. )
  64. except TypeError as e: # pragma: no cover
  65. raise TypeError(f'Error using plugin `{plugin.__module__}:{plugin.__class__.__name__}`: {e}') from e
  66. if p is not None:
  67. python_event_handlers.append(p)
  68. if j is not None:
  69. json_event_handlers.append(j)
  70. if s is not None:
  71. strings_event_handlers.append(s)
  72. self.validate_python = build_wrapper(self._schema_validator.validate_python, python_event_handlers)
  73. self.validate_json = build_wrapper(self._schema_validator.validate_json, json_event_handlers)
  74. self.validate_strings = build_wrapper(self._schema_validator.validate_strings, strings_event_handlers)
  75. def __getattr__(self, name: str) -> Any:
  76. return getattr(self._schema_validator, name)
  77. def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandlerProtocol]) -> Callable[P, R]:
  78. if not event_handlers:
  79. return func
  80. else:
  81. on_enters = tuple(h.on_enter for h in event_handlers if filter_handlers(h, 'on_enter'))
  82. on_successes = tuple(h.on_success for h in event_handlers if filter_handlers(h, 'on_success'))
  83. on_errors = tuple(h.on_error for h in event_handlers if filter_handlers(h, 'on_error'))
  84. on_exceptions = tuple(h.on_exception for h in event_handlers if filter_handlers(h, 'on_exception'))
  85. @functools.wraps(func)
  86. def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
  87. for on_enter_handler in on_enters:
  88. on_enter_handler(*args, **kwargs)
  89. try:
  90. result = func(*args, **kwargs)
  91. except ValidationError as error:
  92. for on_error_handler in on_errors:
  93. on_error_handler(error)
  94. raise
  95. except Exception as exception:
  96. for on_exception_handler in on_exceptions:
  97. on_exception_handler(exception)
  98. raise
  99. else:
  100. for on_success_handler in on_successes:
  101. on_success_handler(result)
  102. return result
  103. return wrapper
  104. def filter_handlers(handler_cls: BaseValidateHandlerProtocol, method_name: str) -> bool:
  105. """Filter out handler methods which are not implemented by the plugin directly - e.g. are missing
  106. or are inherited from the protocol.
  107. """
  108. handler = getattr(handler_cls, method_name, None)
  109. if handler is None:
  110. return False
  111. elif handler.__module__ == 'pydantic.plugin':
  112. # this is the original handler, from the protocol due to runtime inheritance
  113. # we don't want to call it
  114. return False
  115. else:
  116. return True