_discriminated_union.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. from __future__ import annotations as _annotations
  2. from collections.abc import Hashable, Sequence
  3. from typing import TYPE_CHECKING, Any, cast
  4. from pydantic_core import CoreSchema, core_schema
  5. from ..errors import PydanticUserError
  6. from . import _core_utils
  7. from ._core_utils import (
  8. CoreSchemaField,
  9. )
  10. if TYPE_CHECKING:
  11. from ..types import Discriminator
  12. from ._core_metadata import CoreMetadata
  13. class MissingDefinitionForUnionRef(Exception):
  14. """Raised when applying a discriminated union discriminator to a schema
  15. requires a definition that is not yet defined
  16. """
  17. def __init__(self, ref: str) -> None:
  18. self.ref = ref
  19. super().__init__(f'Missing definition for ref {self.ref!r}')
  20. def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
  21. metadata = cast('CoreMetadata', schema.setdefault('metadata', {}))
  22. metadata['pydantic_internal_union_discriminator'] = discriminator
  23. def apply_discriminator(
  24. schema: core_schema.CoreSchema,
  25. discriminator: str | Discriminator,
  26. definitions: dict[str, core_schema.CoreSchema] | None = None,
  27. ) -> core_schema.CoreSchema:
  28. """Applies the discriminator and returns a new core schema.
  29. Args:
  30. schema: The input schema.
  31. discriminator: The name of the field which will serve as the discriminator.
  32. definitions: A mapping of schema ref to schema.
  33. Returns:
  34. The new core schema.
  35. Raises:
  36. TypeError:
  37. - If `discriminator` is used with invalid union variant.
  38. - If `discriminator` is used with `Union` type with one variant.
  39. - If `discriminator` value mapped to multiple choices.
  40. MissingDefinitionForUnionRef:
  41. If the definition for ref is missing.
  42. PydanticUserError:
  43. - If a model in union doesn't have a discriminator field.
  44. - If discriminator field has a non-string alias.
  45. - If discriminator fields have different aliases.
  46. - If discriminator field not of type `Literal`.
  47. """
  48. from ..types import Discriminator
  49. if isinstance(discriminator, Discriminator):
  50. if isinstance(discriminator.discriminator, str):
  51. discriminator = discriminator.discriminator
  52. else:
  53. return discriminator._convert_schema(schema)
  54. return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)
  55. class _ApplyInferredDiscriminator:
  56. """This class is used to convert an input schema containing a union schema into one where that union is
  57. replaced with a tagged-union, with all the associated debugging and performance benefits.
  58. This is done by:
  59. * Validating that the input schema is compatible with the provided discriminator
  60. * Introspecting the schema to determine which discriminator values should map to which union choices
  61. * Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more
  62. I have chosen to implement the conversion algorithm in this class, rather than a function,
  63. to make it easier to maintain state while recursively walking the provided CoreSchema.
  64. """
  65. def __init__(self, discriminator: str, definitions: dict[str, core_schema.CoreSchema]):
  66. # `discriminator` should be the name of the field which will serve as the discriminator.
  67. # It must be the python name of the field, and *not* the field's alias. Note that as of now,
  68. # all members of a discriminated union _must_ use a field with the same name as the discriminator.
  69. # This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices.
  70. self.discriminator = discriminator
  71. # `definitions` should contain a mapping of schema ref to schema for all schemas which might
  72. # be referenced by some choice
  73. self.definitions = definitions
  74. # `_discriminator_alias` will hold the value, if present, of the alias for the discriminator
  75. #
  76. # Note: following the v1 implementation, we currently disallow the use of different aliases
  77. # for different choices. This is not a limitation of pydantic_core, but if we try to handle
  78. # this, the inference logic gets complicated very quickly, and could result in confusing
  79. # debugging challenges for users making subtle mistakes.
  80. #
  81. # Rather than trying to do the most powerful inference possible, I think we should eventually
  82. # expose a way to more-manually control the way the TaggedUnionSchema is constructed through
  83. # the use of a new type which would be placed as an Annotation on the Union type. This would
  84. # provide the full flexibility/power of pydantic_core's TaggedUnionSchema where necessary for
  85. # more complex cases, without over-complicating the inference logic for the common cases.
  86. self._discriminator_alias: str | None = None
  87. # `_should_be_nullable` indicates whether the converted union has `None` as an allowed value.
  88. # If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while
  89. # constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True.
  90. # Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure
  91. # that the final schema gets wrapped as a NullableSchema. This has the same semantics on the
  92. # python side, but resolves the issue that `None` cannot correspond to any discriminator values.
  93. self._should_be_nullable = False
  94. # `_is_nullable` is used to track if the final produced schema will definitely be nullable;
  95. # we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved
  96. # as an indication that, even if None is discovered as one of the union choices, we will not need to wrap
  97. # the final value in another nullable schema.
  98. #
  99. # This is more complicated than just checking for the final outermost schema having type 'nullable' thanks
  100. # to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc.
  101. self._is_nullable = False
  102. # `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices
  103. # from the union in the wrapped schema will be appended to this list, and the recursive choice-handling
  104. # algorithm may add more choices to this stack as (nested) unions are encountered.
  105. self._choices_to_handle: list[core_schema.CoreSchema] = []
  106. # `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included
  107. # in the output TaggedUnionSchema that will replace the union from the input schema
  108. self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}
  109. # `_used` is changed to True after applying the discriminator to prevent accidental reuse
  110. self._used = False
  111. def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
  112. """Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided
  113. to this class.
  114. Args:
  115. schema: The input schema.
  116. Returns:
  117. The new core schema.
  118. Raises:
  119. TypeError:
  120. - If `discriminator` is used with invalid union variant.
  121. - If `discriminator` is used with `Union` type with one variant.
  122. - If `discriminator` value mapped to multiple choices.
  123. ValueError:
  124. If the definition for ref is missing.
  125. PydanticUserError:
  126. - If a model in union doesn't have a discriminator field.
  127. - If discriminator field has a non-string alias.
  128. - If discriminator fields have different aliases.
  129. - If discriminator field not of type `Literal`.
  130. """
  131. assert not self._used
  132. schema = self._apply_to_root(schema)
  133. if self._should_be_nullable and not self._is_nullable:
  134. schema = core_schema.nullable_schema(schema)
  135. self._used = True
  136. return schema
  137. def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
  138. """This method handles the outer-most stage of recursion over the input schema:
  139. unwrapping nullable or definitions schemas, and calling the `_handle_choice`
  140. method iteratively on the choices extracted (recursively) from the possibly-wrapped union.
  141. """
  142. if schema['type'] == 'nullable':
  143. self._is_nullable = True
  144. wrapped = self._apply_to_root(schema['schema'])
  145. nullable_wrapper = schema.copy()
  146. nullable_wrapper['schema'] = wrapped
  147. return nullable_wrapper
  148. if schema['type'] == 'definitions':
  149. wrapped = self._apply_to_root(schema['schema'])
  150. definitions_wrapper = schema.copy()
  151. definitions_wrapper['schema'] = wrapped
  152. return definitions_wrapper
  153. if schema['type'] != 'union':
  154. # If the schema is not a union, it probably means it just had a single member and
  155. # was flattened by pydantic_core.
  156. # However, it still may make sense to apply the discriminator to this schema,
  157. # as a way to get discriminated-union-style error messages, so we allow this here.
  158. schema = core_schema.union_schema([schema])
  159. # Reverse the choices list before extending the stack so that they get handled in the order they occur
  160. choices_schemas = [v[0] if isinstance(v, tuple) else v for v in schema['choices'][::-1]]
  161. self._choices_to_handle.extend(choices_schemas)
  162. while self._choices_to_handle:
  163. choice = self._choices_to_handle.pop()
  164. self._handle_choice(choice)
  165. if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator:
  166. # * We need to annotate `discriminator` as a union here to handle both branches of this conditional
  167. # * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the
  168. # invariance of list, and because list[list[str | int]] is the type of the discriminator argument
  169. # to tagged_union_schema below
  170. # * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to
  171. # interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here
  172. # is the appropriate way to provide a list of fallback attributes to check for a discriminator value.)
  173. discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]]
  174. else:
  175. discriminator = self.discriminator
  176. return core_schema.tagged_union_schema(
  177. choices=self._tagged_union_choices,
  178. discriminator=discriminator,
  179. custom_error_type=schema.get('custom_error_type'),
  180. custom_error_message=schema.get('custom_error_message'),
  181. custom_error_context=schema.get('custom_error_context'),
  182. strict=False,
  183. from_attributes=True,
  184. ref=schema.get('ref'),
  185. metadata=schema.get('metadata'),
  186. serialization=schema.get('serialization'),
  187. )
  188. def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
  189. """This method handles the "middle" stage of recursion over the input schema.
  190. Specifically, it is responsible for handling each choice of the outermost union
  191. (and any "coalesced" choices obtained from inner unions).
  192. Here, "handling" entails:
  193. * Coalescing nested unions and compatible tagged-unions
  194. * Tracking the presence of 'none' and 'nullable' schemas occurring as choices
  195. * Validating that each allowed discriminator value maps to a unique choice
  196. * Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
  197. """
  198. if choice['type'] == 'definition-ref':
  199. if choice['schema_ref'] not in self.definitions:
  200. raise MissingDefinitionForUnionRef(choice['schema_ref'])
  201. if choice['type'] == 'none':
  202. self._should_be_nullable = True
  203. elif choice['type'] == 'definitions':
  204. self._handle_choice(choice['schema'])
  205. elif choice['type'] == 'nullable':
  206. self._should_be_nullable = True
  207. self._handle_choice(choice['schema']) # unwrap the nullable schema
  208. elif choice['type'] == 'union':
  209. # Reverse the choices list before extending the stack so that they get handled in the order they occur
  210. choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
  211. self._choices_to_handle.extend(choices_schemas)
  212. elif choice['type'] not in {
  213. 'model',
  214. 'typed-dict',
  215. 'tagged-union',
  216. 'lax-or-strict',
  217. 'dataclass',
  218. 'dataclass-args',
  219. 'definition-ref',
  220. } and not _core_utils.is_function_with_inner_schema(choice):
  221. # We should eventually handle 'definition-ref' as well
  222. err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
  223. if choice['type'] == 'list':
  224. err_str += (
  225. ' If you are making use of a list of union types, make sure the discriminator is applied to the '
  226. 'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
  227. )
  228. raise TypeError(err_str)
  229. else:
  230. if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
  231. # In this case, this inner tagged-union is compatible with the outer tagged-union,
  232. # and its choices can be coalesced into the outer TaggedUnionSchema.
  233. subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
  234. # Reverse the choices list before extending the stack so that they get handled in the order they occur
  235. self._choices_to_handle.extend(subchoices[::-1])
  236. return
  237. inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None)
  238. self._set_unique_choice_for_values(choice, inferred_discriminator_values)
  239. def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool:
  240. """This method returns a boolean indicating whether the discriminator for the `choice`
  241. is the same as that being used for the outermost tagged union. This is used to
  242. determine whether this TaggedUnionSchema choice should be "coalesced" into the top level,
  243. or whether it should be treated as a separate (nested) choice.
  244. """
  245. inner_discriminator = choice['discriminator']
  246. return inner_discriminator == self.discriminator or (
  247. isinstance(inner_discriminator, list)
  248. and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator)
  249. )
  250. def _infer_discriminator_values_for_choice( # noqa C901
  251. self, choice: core_schema.CoreSchema, source_name: str | None
  252. ) -> list[str | int]:
  253. """This function recurses over `choice`, extracting all discriminator values that should map to this choice.
  254. `model_name` is accepted for the purpose of producing useful error messages.
  255. """
  256. if choice['type'] == 'definitions':
  257. return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
  258. elif _core_utils.is_function_with_inner_schema(choice):
  259. return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
  260. elif choice['type'] == 'lax-or-strict':
  261. return sorted(
  262. set(
  263. self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None)
  264. + self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None)
  265. )
  266. )
  267. elif choice['type'] == 'tagged-union':
  268. values: list[str | int] = []
  269. # Ignore str/int "choices" since these are just references to other choices
  270. subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
  271. for subchoice in subchoices:
  272. subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None)
  273. values.extend(subchoice_values)
  274. return values
  275. elif choice['type'] == 'union':
  276. values = []
  277. for subchoice in choice['choices']:
  278. subchoice_schema = subchoice[0] if isinstance(subchoice, tuple) else subchoice
  279. subchoice_values = self._infer_discriminator_values_for_choice(subchoice_schema, source_name=None)
  280. values.extend(subchoice_values)
  281. return values
  282. elif choice['type'] == 'nullable':
  283. self._should_be_nullable = True
  284. return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None)
  285. elif choice['type'] == 'model':
  286. return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
  287. elif choice['type'] == 'dataclass':
  288. return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
  289. elif choice['type'] == 'model-fields':
  290. return self._infer_discriminator_values_for_model_choice(choice, source_name=source_name)
  291. elif choice['type'] == 'dataclass-args':
  292. return self._infer_discriminator_values_for_dataclass_choice(choice, source_name=source_name)
  293. elif choice['type'] == 'typed-dict':
  294. return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name)
  295. elif choice['type'] == 'definition-ref':
  296. schema_ref = choice['schema_ref']
  297. if schema_ref not in self.definitions:
  298. raise MissingDefinitionForUnionRef(schema_ref)
  299. return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
  300. else:
  301. err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
  302. if choice['type'] == 'list':
  303. err_str += (
  304. ' If you are making use of a list of union types, make sure the discriminator is applied to the '
  305. 'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
  306. )
  307. raise TypeError(err_str)
  308. def _infer_discriminator_values_for_typed_dict_choice(
  309. self, choice: core_schema.TypedDictSchema, source_name: str | None = None
  310. ) -> list[str | int]:
  311. """This method just extracts the _infer_discriminator_values_for_choice logic specific to TypedDictSchema
  312. for the sake of readability.
  313. """
  314. source = 'TypedDict' if source_name is None else f'TypedDict {source_name!r}'
  315. field = choice['fields'].get(self.discriminator)
  316. if field is None:
  317. raise PydanticUserError(
  318. f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
  319. )
  320. return self._infer_discriminator_values_for_field(field, source)
  321. def _infer_discriminator_values_for_model_choice(
  322. self, choice: core_schema.ModelFieldsSchema, source_name: str | None = None
  323. ) -> list[str | int]:
  324. source = 'ModelFields' if source_name is None else f'Model {source_name!r}'
  325. field = choice['fields'].get(self.discriminator)
  326. if field is None:
  327. raise PydanticUserError(
  328. f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
  329. )
  330. return self._infer_discriminator_values_for_field(field, source)
  331. def _infer_discriminator_values_for_dataclass_choice(
  332. self, choice: core_schema.DataclassArgsSchema, source_name: str | None = None
  333. ) -> list[str | int]:
  334. source = 'DataclassArgs' if source_name is None else f'Dataclass {source_name!r}'
  335. for field in choice['fields']:
  336. if field['name'] == self.discriminator:
  337. break
  338. else:
  339. raise PydanticUserError(
  340. f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
  341. )
  342. return self._infer_discriminator_values_for_field(field, source)
  343. def _infer_discriminator_values_for_field(self, field: CoreSchemaField, source: str) -> list[str | int]:
  344. if field['type'] == 'computed-field':
  345. # This should never occur as a discriminator, as it is only relevant to serialization
  346. return []
  347. alias = field.get('validation_alias', self.discriminator)
  348. if not isinstance(alias, str):
  349. raise PydanticUserError(
  350. f'Alias {alias!r} is not supported in a discriminated union', code='discriminator-alias-type'
  351. )
  352. if self._discriminator_alias is None:
  353. self._discriminator_alias = alias
  354. elif self._discriminator_alias != alias:
  355. raise PydanticUserError(
  356. f'Aliases for discriminator {self.discriminator!r} must be the same '
  357. f'(got {alias}, {self._discriminator_alias})',
  358. code='discriminator-alias',
  359. )
  360. return self._infer_discriminator_values_for_inner_schema(field['schema'], source)
  361. def _infer_discriminator_values_for_inner_schema(
  362. self, schema: core_schema.CoreSchema, source: str
  363. ) -> list[str | int]:
  364. """When inferring discriminator values for a field, we typically extract the expected values from a literal
  365. schema. This function does that, but also handles nested unions and defaults.
  366. """
  367. if schema['type'] == 'literal':
  368. return schema['expected']
  369. elif schema['type'] == 'union':
  370. # Generally when multiple values are allowed they should be placed in a single `Literal`, but
  371. # we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s.
  372. # For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]`
  373. values: list[Any] = []
  374. for choice in schema['choices']:
  375. choice_schema = choice[0] if isinstance(choice, tuple) else choice
  376. choice_values = self._infer_discriminator_values_for_inner_schema(choice_schema, source)
  377. values.extend(choice_values)
  378. return values
  379. elif schema['type'] == 'default':
  380. # This will happen if the field has a default value; we ignore it while extracting the discriminator values
  381. return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
  382. elif schema['type'] == 'function-after':
  383. # After validators don't affect the discriminator values
  384. return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
  385. elif schema['type'] in {'function-before', 'function-wrap', 'function-plain'}:
  386. validator_type = repr(schema['type'].split('-')[1])
  387. raise PydanticUserError(
  388. f'Cannot use a mode={validator_type} validator in the'
  389. f' discriminator field {self.discriminator!r} of {source}',
  390. code='discriminator-validator',
  391. )
  392. else:
  393. raise PydanticUserError(
  394. f'{source} needs field {self.discriminator!r} to be of type `Literal`',
  395. code='discriminator-needs-literal',
  396. )
  397. def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None:
  398. """This method updates `self.tagged_union_choices` so that all provided (discriminator) `values` map to the
  399. provided `choice`, validating that none of these values already map to another (different) choice.
  400. """
  401. for discriminator_value in values:
  402. if discriminator_value in self._tagged_union_choices:
  403. # It is okay if `value` is already in tagged_union_choices as long as it maps to the same value.
  404. # Because tagged_union_choices may map values to other values, we need to walk the choices dict
  405. # until we get to a "real" choice, and confirm that is equal to the one assigned.
  406. existing_choice = self._tagged_union_choices[discriminator_value]
  407. if existing_choice != choice:
  408. raise TypeError(
  409. f'Value {discriminator_value!r} for discriminator '
  410. f'{self.discriminator!r} mapped to multiple choices'
  411. )
  412. else:
  413. self._tagged_union_choices[discriminator_value] = choice