brain_dataclasses.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
  4. """
  5. Astroid hook for the dataclasses library.
  6. Support built-in dataclasses, pydantic.dataclasses, and marshmallow_dataclass-annotated
  7. dataclasses. References:
  8. - https://docs.python.org/3/library/dataclasses.html
  9. - https://pydantic-docs.helpmanual.io/usage/dataclasses/
  10. - https://lovasoa.github.io/marshmallow_dataclass/
  11. """
  12. from __future__ import annotations
  13. from collections.abc import Iterator
  14. from typing import Literal
  15. from astroid import bases, context, nodes
  16. from astroid.brain.helpers import is_class_var
  17. from astroid.builder import parse
  18. from astroid.const import PY313_PLUS
  19. from astroid.exceptions import AstroidSyntaxError, InferenceError, UseInferenceDefault
  20. from astroid.inference_tip import inference_tip
  21. from astroid.manager import AstroidManager
  22. from astroid.typing import InferenceResult
  23. from astroid.util import Uninferable, UninferableBase, safe_infer
# Return type of `_get_field_default`: either None, or a pair made of the
# keyword name ("default" / "default_factory") and the node carrying the value.
_FieldDefaultReturn = (
    None
    | tuple[Literal["default"], nodes.NodeNG]
    | tuple[Literal["default_factory"], nodes.Call]
)

# Decorator names that are accepted as dataclass decorators by default.
DATACLASSES_DECORATORS = frozenset(("dataclass",))
# Name the inferred callee of a field call is compared against.
FIELD_NAME = "field"
# Module names the inferred decorator / field function must originate from.
DATACLASS_MODULES = frozenset(
    ("dataclasses", "marshmallow_dataclass", "pydantic.dataclasses")
)
# Name of the module-level sentinel (assigned `object()` by the transform) that
# generated __init__ signatures use as the default of default_factory fields.
DEFAULT_FACTORY = "_HAS_DEFAULT_FACTORY"  # based on typing.py
  35. def is_decorated_with_dataclass(
  36. node: nodes.ClassDef, decorator_names: frozenset[str] = DATACLASSES_DECORATORS
  37. ) -> bool:
  38. """Return True if a decorated node has a `dataclass` decorator applied."""
  39. if not (isinstance(node, nodes.ClassDef) and node.decorators):
  40. return False
  41. return any(
  42. _looks_like_dataclass_decorator(decorator_attribute, decorator_names)
  43. for decorator_attribute in node.decorators.nodes
  44. )
  45. def dataclass_transform(node: nodes.ClassDef) -> nodes.ClassDef | None:
  46. """Rewrite a dataclass to be easily understood by pylint."""
  47. node.is_dataclass = True
  48. for assign_node in _get_dataclass_attributes(node):
  49. name = assign_node.target.name
  50. rhs_node = nodes.Unknown(
  51. lineno=assign_node.lineno,
  52. col_offset=assign_node.col_offset,
  53. parent=assign_node,
  54. )
  55. rhs_node = AstroidManager().visit_transforms(rhs_node)
  56. node.instance_attrs[name] = [rhs_node]
  57. if not _check_generate_dataclass_init(node):
  58. return None
  59. kw_only_decorated = False
  60. if node.decorators.nodes:
  61. for decorator in node.decorators.nodes:
  62. if not isinstance(decorator, nodes.Call):
  63. kw_only_decorated = False
  64. break
  65. for keyword in decorator.keywords:
  66. if keyword.arg == "kw_only":
  67. kw_only_decorated = keyword.value.bool_value() is True
  68. init_str = _generate_dataclass_init(
  69. node,
  70. list(_get_dataclass_attributes(node, init=True)),
  71. kw_only_decorated,
  72. )
  73. try:
  74. init_node = parse(init_str)["__init__"]
  75. except AstroidSyntaxError:
  76. pass
  77. else:
  78. init_node.parent = node
  79. init_node.lineno, init_node.col_offset = None, None
  80. node.locals["__init__"] = [init_node]
  81. root = node.root()
  82. if DEFAULT_FACTORY not in root.locals:
  83. new_assign = parse(f"{DEFAULT_FACTORY} = object()").body[0]
  84. new_assign.parent = root
  85. root.locals[DEFAULT_FACTORY] = [new_assign.targets[0]]
  86. return node
  87. def _get_dataclass_attributes(
  88. node: nodes.ClassDef, init: bool = False
  89. ) -> Iterator[nodes.AnnAssign]:
  90. """Yield the AnnAssign nodes of dataclass attributes for the node.
  91. If init is True, also include InitVars.
  92. """
  93. for assign_node in node.body:
  94. if not (
  95. isinstance(assign_node, nodes.AnnAssign)
  96. and isinstance(assign_node.target, nodes.AssignName)
  97. ):
  98. continue
  99. # Annotation is never None
  100. if is_class_var(assign_node.annotation): # type: ignore[arg-type]
  101. continue
  102. if _is_keyword_only_sentinel(assign_node.annotation):
  103. continue
  104. # Annotation is never None
  105. if not init and _is_init_var(assign_node.annotation): # type: ignore[arg-type]
  106. continue
  107. yield assign_node
  108. def _check_generate_dataclass_init(node: nodes.ClassDef) -> bool:
  109. """Return True if we should generate an __init__ method for node.
  110. This is True when:
  111. - node doesn't define its own __init__ method
  112. - the dataclass decorator was called *without* the keyword argument init=False
  113. """
  114. if "__init__" in node.locals:
  115. return False
  116. found = None
  117. for decorator_attribute in node.decorators.nodes:
  118. if not isinstance(decorator_attribute, nodes.Call):
  119. continue
  120. if _looks_like_dataclass_decorator(decorator_attribute):
  121. found = decorator_attribute
  122. if found is None:
  123. return True
  124. # Check for keyword arguments of the form init=False
  125. return not any(
  126. keyword.arg == "init"
  127. and keyword.value.bool_value() is False # type: ignore[union-attr] # value is never None
  128. for keyword in found.keywords
  129. )
  130. def _find_arguments_from_base_classes(
  131. node: nodes.ClassDef,
  132. ) -> tuple[
  133. dict[str, tuple[str | None, str | None]], dict[str, tuple[str | None, str | None]]
  134. ]:
  135. """Iterate through all bases and get their typing and defaults."""
  136. pos_only_store: dict[str, tuple[str | None, str | None]] = {}
  137. kw_only_store: dict[str, tuple[str | None, str | None]] = {}
  138. # See TODO down below
  139. # all_have_defaults = True
  140. for base in reversed(node.mro()):
  141. if not base.is_dataclass:
  142. continue
  143. try:
  144. base_init: nodes.FunctionDef = base.locals["__init__"][0]
  145. except KeyError:
  146. continue
  147. pos_only, kw_only = base_init.args._get_arguments_data()
  148. for posarg, data in pos_only.items():
  149. # if data[1] is None:
  150. # if all_have_defaults and pos_only_store:
  151. # # TODO: This should return an Uninferable as this would raise
  152. # # a TypeError at runtime. However, transforms can't return
  153. # # Uninferables currently.
  154. # pass
  155. # all_have_defaults = False
  156. pos_only_store[posarg] = data
  157. for kwarg, data in kw_only.items():
  158. kw_only_store[kwarg] = data
  159. return pos_only_store, kw_only_store
  160. def _parse_arguments_into_strings(
  161. pos_only_store: dict[str, tuple[str | None, str | None]],
  162. kw_only_store: dict[str, tuple[str | None, str | None]],
  163. ) -> tuple[str, str]:
  164. """Parse positional and keyword arguments into strings for an __init__ method."""
  165. pos_only, kw_only = "", ""
  166. for pos_arg, data in pos_only_store.items():
  167. pos_only += pos_arg
  168. if data[0]:
  169. pos_only += ": " + data[0]
  170. if data[1]:
  171. pos_only += " = " + data[1]
  172. pos_only += ", "
  173. for kw_arg, data in kw_only_store.items():
  174. kw_only += kw_arg
  175. if data[0]:
  176. kw_only += ": " + data[0]
  177. if data[1]:
  178. kw_only += " = " + data[1]
  179. kw_only += ", "
  180. return pos_only, kw_only
  181. def _get_previous_field_default(node: nodes.ClassDef, name: str) -> nodes.NodeNG | None:
  182. """Get the default value of a previously defined field."""
  183. for base in reversed(node.mro()):
  184. if not base.is_dataclass:
  185. continue
  186. if name in base.locals:
  187. for assign in base.locals[name]:
  188. if (
  189. isinstance(assign.parent, nodes.AnnAssign)
  190. and assign.parent.value
  191. and isinstance(assign.parent.value, nodes.Call)
  192. and _looks_like_dataclass_field_call(assign.parent.value)
  193. ):
  194. default = _get_field_default(assign.parent.value)
  195. if default:
  196. return default[1]
  197. return None
def _generate_dataclass_init(
    node: nodes.ClassDef, assigns: list[nodes.AnnAssign], kw_only_decorated: bool
) -> str:
    """Return the source code of an init method for a dataclass given the targets.

    ``node`` is the dataclass being transformed, ``assigns`` the annotated
    assignments to turn into parameters and ``kw_only_decorated`` tells whether
    parameters go to the keyword-only section by default.

    Arguments collected from the dataclass bases are emitted first; entries of
    the same name found in ``assigns`` update or replace them. The result is
    the text of a ``def __init__(...) -> None:`` function whose body assigns
    the parameters to ``self`` (or is ``pass`` when there is nothing to assign).
    """
    # pylint: disable = too-many-locals, too-many-branches, too-many-statements

    params: list[str] = []
    kw_only_params: list[str] = []
    assignments: list[str] = []

    prev_pos_only_store, prev_kw_only_store = _find_arguments_from_base_classes(node)

    for assign in assigns:
        name, annotation, value = assign.target.name, assign.annotation, assign.value

        # Check whether this assign is overridden by a property assignment
        property_node: nodes.FunctionDef | None = None
        for additional_assign in node.locals[name]:
            if not isinstance(additional_assign, nodes.FunctionDef):
                continue
            if not additional_assign.decorators:
                continue
            if "builtins.property" in additional_assign.decoratornames():
                property_node = additional_assign
                break

        is_field = isinstance(value, nodes.Call) and _looks_like_dataclass_field_call(
            value, check_scope=False
        )

        if is_field:
            # Skip any fields that have `init=False`
            if any(
                keyword.arg == "init" and (keyword.value.bool_value() is False)
                for keyword in value.keywords  # type: ignore[union-attr] # value is never None
            ):
                # Also remove the name from the previous arguments to be inserted later
                prev_pos_only_store.pop(name, None)
                prev_kw_only_store.pop(name, None)
                continue

        if _is_init_var(annotation):  # type: ignore[arg-type] # annotation is never None
            init_var = True
            if isinstance(annotation, nodes.Subscript):
                # InitVar[T]: the parameter is annotated with T
                annotation = annotation.slice
            else:
                # Cannot determine type annotation for parameter from InitVar
                annotation = None
            assignment_str = ""
        else:
            init_var = False
            assignment_str = f"self.{name} = {name}"

        ann_str, default_str = None, None
        if annotation is not None:
            ann_str = annotation.as_string()

        if value:
            if is_field:
                result = _get_field_default(value)  # type: ignore[arg-type]
                if result:
                    default_type, default_node = result
                    if default_type == "default":
                        default_str = default_node.as_string()
                    elif default_type == "default_factory":
                        # The parameter defaults to the sentinel; the factory
                        # call is only used when the sentinel is received.
                        default_str = DEFAULT_FACTORY
                        assignment_str = (
                            f"self.{name} = {default_node.as_string()} "
                            f"if {name} is {DEFAULT_FACTORY} else {name}"
                        )
            else:
                default_str = value.as_string()
        elif property_node:
            # We set the result of the property call as default
            # This hides the fact that this would normally be a 'property object'
            # But we can't represent those as string
            try:
                # Call str to make sure also Uninferable gets stringified
                default_str = str(
                    next(property_node.infer_call_result(None)).as_string()
                )
            except (InferenceError, StopIteration):
                pass
        else:
            # Even with `init=False` the default value still can be propagated to
            # later assignments. Creating weird signatures like:
            # (self, a: str = 1) -> None
            previous_default = _get_previous_field_default(node, name)
            if previous_default:
                default_str = previous_default.as_string()

        # Construct the param string to add to the init if necessary
        param_str = name
        if ann_str is not None:
            param_str += f": {ann_str}"
        if default_str is not None:
            param_str += f" = {default_str}"

        # If the field is a kw_only field, we need to add it to the kw_only_params
        # This overwrites whether or not the class is kw_only decorated
        if is_field:
            kw_only = [k for k in value.keywords if k.arg == "kw_only"]  # type: ignore[union-attr]
            if kw_only:
                if kw_only[0].value.bool_value() is True:
                    kw_only_params.append(param_str)
                else:
                    params.append(param_str)
                # NOTE(review): this `continue` also skips the
                # `assignments.append(...)` at the end of the loop body, so a
                # field with an explicit `kw_only=` gets a parameter but no
                # assignment line - confirm this is intended.
                continue
        # If kw_only decorated, we need to add all parameters to the kw_only_params
        if kw_only_decorated:
            if name in prev_kw_only_store:
                prev_kw_only_store[name] = (ann_str, default_str)
            else:
                kw_only_params.append(param_str)
        else:
            # If the name was previously seen, overwrite that data
            # pylint: disable-next=else-if-used
            if name in prev_pos_only_store:
                prev_pos_only_store[name] = (ann_str, default_str)
            elif name in prev_kw_only_store:
                # NOTE(review): only the bare name is prepended here, the
                # annotation and default held by `param_str` are not used.
                params = [name, *params]
                prev_kw_only_store.pop(name)
            else:
                params.append(param_str)

        if not init_var:
            assignments.append(assignment_str)

    prev_pos_only, prev_kw_only = _parse_arguments_into_strings(
        prev_pos_only_store, prev_kw_only_store
    )

    # Construct the new init method parameter string
    # First we do the positional only parameters, making sure to add the
    # self parameter and the comma to allow adding keyword only parameters
    params_string = "" if "self" in prev_pos_only else "self, "
    params_string += prev_pos_only + ", ".join(params)
    if not params_string.endswith(", "):
        params_string += ", "

    # Then we add the keyword only parameters
    if prev_kw_only or kw_only_params:
        params_string += "*, "
    params_string += f"{prev_kw_only}{', '.join(kw_only_params)}"

    assignments_string = "\n ".join(assignments) if assignments else "pass"
    return f"def __init__({params_string}) -> None:\n {assignments_string}"
  329. def infer_dataclass_attribute(
  330. node: nodes.Unknown, ctx: context.InferenceContext | None = None
  331. ) -> Iterator[InferenceResult]:
  332. """Inference tip for an Unknown node that was dynamically generated to
  333. represent a dataclass attribute.
  334. In the case that a default value is provided, that is inferred first.
  335. Then, an Instance of the annotated class is yielded.
  336. """
  337. assign = node.parent
  338. if not isinstance(assign, nodes.AnnAssign):
  339. yield Uninferable
  340. return
  341. annotation, value = assign.annotation, assign.value
  342. if value is not None:
  343. yield from value.infer(context=ctx)
  344. if annotation is not None:
  345. yield from _infer_instance_from_annotation(annotation, ctx=ctx)
  346. else:
  347. yield Uninferable
  348. def infer_dataclass_field_call(
  349. node: nodes.Call, ctx: context.InferenceContext | None = None
  350. ) -> Iterator[InferenceResult]:
  351. """Inference tip for dataclass field calls."""
  352. if not isinstance(node.parent, (nodes.AnnAssign, nodes.Assign)):
  353. raise UseInferenceDefault
  354. result = _get_field_default(node)
  355. if not result:
  356. yield Uninferable
  357. else:
  358. default_type, default = result
  359. if default_type == "default":
  360. yield from default.infer(context=ctx)
  361. else:
  362. new_call = parse(default.as_string()).body[0].value
  363. new_call.parent = node.parent
  364. yield from new_call.infer(context=ctx)
  365. def _looks_like_dataclass_decorator(
  366. node: nodes.NodeNG, decorator_names: frozenset[str] = DATACLASSES_DECORATORS
  367. ) -> bool:
  368. """Return True if node looks like a dataclass decorator.
  369. Uses inference to lookup the value of the node, and if that fails,
  370. matches against specific names.
  371. """
  372. if isinstance(node, nodes.Call): # decorator with arguments
  373. node = node.func
  374. try:
  375. inferred = next(node.infer())
  376. except (InferenceError, StopIteration):
  377. inferred = Uninferable
  378. if isinstance(inferred, UninferableBase):
  379. if isinstance(node, nodes.Name):
  380. return node.name in decorator_names
  381. if isinstance(node, nodes.Attribute):
  382. return node.attrname in decorator_names
  383. return False
  384. return (
  385. isinstance(inferred, nodes.FunctionDef)
  386. and inferred.name in decorator_names
  387. and inferred.root().name in DATACLASS_MODULES
  388. )
  389. def _looks_like_dataclass_attribute(node: nodes.Unknown) -> bool:
  390. """Return True if node was dynamically generated as the child of an AnnAssign
  391. statement.
  392. """
  393. parent = node.parent
  394. if not parent:
  395. return False
  396. scope = parent.scope()
  397. return (
  398. isinstance(parent, nodes.AnnAssign)
  399. and isinstance(scope, nodes.ClassDef)
  400. and is_decorated_with_dataclass(scope)
  401. )
  402. def _looks_like_dataclass_field_call(
  403. node: nodes.Call, check_scope: bool = True
  404. ) -> bool:
  405. """Return True if node is calling dataclasses field or Field
  406. from an AnnAssign statement directly in the body of a ClassDef.
  407. If check_scope is False, skips checking the statement and body.
  408. """
  409. if check_scope:
  410. stmt = node.statement()
  411. scope = stmt.scope()
  412. if not (
  413. isinstance(stmt, nodes.AnnAssign)
  414. and stmt.value is not None
  415. and isinstance(scope, nodes.ClassDef)
  416. and is_decorated_with_dataclass(scope)
  417. ):
  418. return False
  419. try:
  420. inferred = next(node.func.infer())
  421. except (InferenceError, StopIteration):
  422. return False
  423. if not isinstance(inferred, nodes.FunctionDef):
  424. return False
  425. return inferred.name == FIELD_NAME and inferred.root().name in DATACLASS_MODULES
  426. def _looks_like_dataclasses(node: nodes.Module) -> bool:
  427. return node.qname() == "dataclasses"
  428. def _resolve_private_replace_to_public(node: nodes.Module) -> None:
  429. """In python/cpython@6f3c138, a _replace() method was extracted from
  430. replace(), and this indirection made replace() uninferable."""
  431. if "_replace" in node.locals:
  432. node.locals["replace"] = node.locals["_replace"]
  433. def _get_field_default(field_call: nodes.Call) -> _FieldDefaultReturn:
  434. """Return a the default value of a field call, and the corresponding keyword
  435. argument name.
  436. field(default=...) results in the ... node
  437. field(default_factory=...) results in a Call node with func ... and no arguments
  438. If neither or both arguments are present, return ("", None) instead,
  439. indicating that there is not a valid default value.
  440. """
  441. default, default_factory = None, None
  442. for keyword in field_call.keywords:
  443. if keyword.arg == "default":
  444. default = keyword.value
  445. elif keyword.arg == "default_factory":
  446. default_factory = keyword.value
  447. if default is not None and default_factory is None:
  448. return "default", default
  449. if default is None and default_factory is not None:
  450. new_call = nodes.Call(
  451. lineno=field_call.lineno,
  452. col_offset=field_call.col_offset,
  453. parent=field_call.parent,
  454. end_lineno=field_call.end_lineno,
  455. end_col_offset=field_call.end_col_offset,
  456. )
  457. new_call.postinit(func=default_factory, args=[], keywords=[])
  458. return "default_factory", new_call
  459. return None
  460. def _is_keyword_only_sentinel(node: nodes.NodeNG) -> bool:
  461. """Return True if node is the KW_ONLY sentinel."""
  462. inferred = safe_infer(node)
  463. return (
  464. isinstance(inferred, bases.Instance)
  465. and inferred.qname() == "dataclasses._KW_ONLY_TYPE"
  466. )
  467. def _is_init_var(node: nodes.NodeNG) -> bool:
  468. """Return True if node is an InitVar, with or without subscripting."""
  469. try:
  470. inferred = next(node.infer())
  471. except (InferenceError, StopIteration):
  472. return False
  473. return getattr(inferred, "name", "") == "InitVar"
  474. # Allowed typing classes for which we support inferring instances
  475. _INFERABLE_TYPING_TYPES = frozenset(
  476. (
  477. "Dict",
  478. "FrozenSet",
  479. "List",
  480. "Set",
  481. "Tuple",
  482. )
  483. )
  484. def _infer_instance_from_annotation(
  485. node: nodes.NodeNG, ctx: context.InferenceContext | None = None
  486. ) -> Iterator[UninferableBase | bases.Instance]:
  487. """Infer an instance corresponding to the type annotation represented by node.
  488. Currently has limited support for the typing module.
  489. """
  490. klass = None
  491. try:
  492. klass = next(node.infer(context=ctx))
  493. except (InferenceError, StopIteration):
  494. yield Uninferable
  495. if not isinstance(klass, nodes.ClassDef):
  496. yield Uninferable
  497. elif klass.root().name in {
  498. "typing",
  499. "_collections_abc",
  500. "",
  501. }: # "" because of synthetic nodes in brain_typing.py
  502. if klass.name in _INFERABLE_TYPING_TYPES:
  503. yield klass.instantiate_class()
  504. else:
  505. yield Uninferable
  506. else:
  507. yield klass.instantiate_class()
  508. def register(manager: AstroidManager) -> None:
  509. if PY313_PLUS:
  510. manager.register_transform(
  511. nodes.Module,
  512. _resolve_private_replace_to_public,
  513. _looks_like_dataclasses,
  514. )
  515. manager.register_transform(
  516. nodes.ClassDef, dataclass_transform, is_decorated_with_dataclass
  517. )
  518. manager.register_transform(
  519. nodes.Call,
  520. inference_tip(infer_dataclass_field_call, raise_on_overwrite=True),
  521. _looks_like_dataclass_field_call,
  522. )
  523. manager.register_transform(
  524. nodes.Unknown,
  525. inference_tip(infer_dataclass_attribute, raise_on_overwrite=True),
  526. _looks_like_dataclass_attribute,
  527. )