_base_nodes.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
  4. """This module contains some base nodes that can be inherited for the different nodes.
  5. Previously these were called Mixin nodes.
  6. """
  7. from __future__ import annotations
  8. import itertools
  9. from collections.abc import Callable, Generator, Iterator
  10. from functools import cached_property, lru_cache, partial
  11. from typing import TYPE_CHECKING, Any, ClassVar
  12. from astroid import bases, nodes, util
  13. from astroid.context import (
  14. CallContext,
  15. InferenceContext,
  16. bind_context_to_node,
  17. )
  18. from astroid.exceptions import (
  19. AttributeInferenceError,
  20. InferenceError,
  21. )
  22. from astroid.interpreter import dunder_lookup
  23. from astroid.nodes.node_ng import NodeNG
  24. from astroid.typing import InferenceResult
  if TYPE_CHECKING:
      from astroid.nodes.node_classes import LocalsDictNodeNG

      # Signature of a "flow factory" (OperatorNode._get_aug_flow and
      # OperatorNode._get_binop_flow): it receives the operands, their types,
      # the operator node and the two call contexts, and returns the ordered
      # list of inference callables to try.
      GetFlowFactory = Callable[
          [
              InferenceResult,  # left operand
              InferenceResult | None,  # type of the left operand
              nodes.AugAssign | nodes.BinOp,  # operator node
              InferenceResult,  # right operand
              InferenceResult | None,  # type of the right operand
              InferenceContext,  # context for left.__op__(right)
              InferenceContext,  # context for right.__rop__(left)
          ],
          list[partial[Generator[InferenceResult]]],
      ]
  class Statement(NodeNG):
      """Statement node adding a few attributes.

      NOTE: This class is part of the public API of 'astroid.nodes'.
      """

      is_statement = True
      """Whether this node indicates a statement."""

      def next_sibling(self):
          """Return the statement that directly follows this one.

          The siblings are taken from the parent's child sequence that
          contains this node.

          :returns: The next sibling statement node, or ``None`` when this
              statement is the last one of its sequence.
          :rtype: NodeNG or None
          """
          siblings = self.parent.child_sequence(self)
          following = siblings.index(self) + 1
          return siblings[following] if following < len(siblings) else None

      def previous_sibling(self):
          """Return the statement that directly precedes this one.

          :returns: The previous sibling statement node, or ``None`` when this
              statement is the first one of its sequence.
          :rtype: NodeNG or None
          """
          siblings = self.parent.child_sequence(self)
          position = siblings.index(self)
          if position == 0:
              return None
          return siblings[position - 1]
  class NoChildrenNode(NodeNG):
      """Base class for leaf nodes that never have children, e.g. ``Pass``."""

      def get_children(self) -> Iterator[NodeNG]:
          """Return a generator that yields nothing: leaf nodes have no children."""
          no_children: tuple[NodeNG, ...] = ()
          yield from no_children
  class FilterStmtsBaseNode(NodeNG):
      """Base node for statement filtering and assignment type."""

      def _get_filtered_stmts(self, _, node, _stmts, mystmt: Statement | None):
          """Decide which statements ``_filter_stmts`` keeps for this node.

          :returns: A ``(statements, stop)`` pair. When this node belongs to
              *mystmt*, only *node* is kept and filtering stops; otherwise
              *_stmts* is returned unchanged and filtering continues.
          """
          belongs_to_mystmt = self.statement() is mystmt
          if not belongs_to_mystmt:
              return _stmts, False
          # original node's statement is the assignment, only keep
          # current node (gen exp, list comp)
          return [node], True

      def assign_type(self):
          """Return the node responsible for the assignment: this node itself."""
          return self
  class AssignTypeNode(NodeNG):
      """Base node for nodes that can 'assign' such as AnnAssign."""

      def assign_type(self):
          """Return the node responsible for the assignment: this node itself."""
          return self

      def _get_filtered_stmts(self, lookup_node, node, _stmts, mystmt: Statement | None):
          """Method used in filter_stmts.

          :returns: A ``(statements, stop)`` pair:

              * this node *is* ``mystmt``: keep ``_stmts`` and stop;
              * this node's statement is ``mystmt``: keep only ``node`` and stop;
              * otherwise: keep ``_stmts`` and continue.
          """
          if self is mystmt:
              return _stmts, True
          # original node's statement is the assignment, only keep
          # current node (gen exp, list comp)
          inside_mystmt = self.statement() is mystmt
          return ([node], True) if inside_mystmt else (_stmts, False)
  class ParentAssignNode(AssignTypeNode):
      """Base node for nodes whose assign_type is determined by the parent node."""

      def assign_type(self):
          """Delegate the assignment type to the parent node."""
          parent = self.parent
          return parent.assign_type()
  class ImportNode(FilterStmtsBaseNode, NoChildrenNode, Statement):
      """Base node for From and Import Nodes."""

      modname: str | None
      """The module that is being imported from.

      This is ``None`` for relative imports.
      """

      names: list[tuple[str, str | None]]
      """What is being imported from the module.

      Each entry is a :class:`tuple` of the name being imported,
      and the alias that the name is assigned to (if any).
      """

      def _infer_name(self, frame, name):
          # The name is returned unchanged; *frame* is not used here.
          return name

      def do_import_module(self, modname: str | None = None) -> nodes.Module:
          """Return the ast for a module whose name is <modname> imported by <self>.

          :param modname: Module to import; defaults to ``self.modname``.
          """
          mymodule = self.root()
          # Import has no level; only nodes that define ``level`` provide one.
          level: int | None = getattr(self, "level", None)
          if modname is None:
              modname = self.modname

          # If the module ImportNode is importing is a module with the same name
          # as the file that contains the ImportNode we don't want to use the cache
          # to make sure we use the import system to get the correct module.
          use_cache = not (
              modname
              # pylint: disable-next=no-member # pylint doesn't recognize type of mymodule
              and mymodule.relative_to_absolute_name(modname, level) == mymodule.name
          )
          relative_only = bool(level and level >= 1)

          # pylint: disable-next=no-member # pylint doesn't recognize type of mymodule
          return mymodule.import_module(
              modname,
              level=level,
              relative_only=relative_only,
              use_cache=use_cache,
          )

      def real_name(self, asname: str) -> str:
          """Get name from 'as' name.

          For a wildcard import, *asname* is returned as-is. For an entry
          without an alias, only the first dotted component is bound (and
          returned).

          :raises AttributeInferenceError: if no entry binds *asname*.
          """
          for name, alias in self.names:
              if name == "*":
                  return asname
              if alias:
                  bound, real = alias, name
              else:
                  real = name.split(".", 1)[0]
                  bound = real
              if bound == asname:
                  return real
          raise AttributeInferenceError(
              "Could not find original name for {attribute} in {target!r}",
              target=self,
              attribute=asname,
          )
  class MultiLineBlockNode(NodeNG):
      """Base node for multi-line blocks, e.g. For and FunctionDef.

      Note that this does not apply to every node with a `body` field.
      For instance, an If node has a multi-line body, but the body of an
      IfExpr is not multi-line, and hence cannot contain Return nodes,
      Assign nodes, etc.
      """

      # Names of the attributes holding this node's statement blocks; empty
      # by default and meant to be overridden by subclasses.
      _multi_line_block_fields: ClassVar[tuple[str, ...]] = ()

      @cached_property
      def _multi_line_blocks(self):
          """Tuple of the block values named by ``_multi_line_block_fields``."""
          return tuple(
              getattr(self, block_field) for block_field in self._multi_line_block_fields
          )

      def _get_return_nodes_skip_functions(self):
          """Yield return nodes from the blocks, not descending into functions."""
          candidates = (
              child
              for block in self._multi_line_blocks
              for child in block
              if not child.is_function
          )
          for child in candidates:
              yield from child._get_return_nodes_skip_functions()

      def _get_yield_nodes_skip_functions(self):
          """Yield yield nodes from the blocks, not descending into functions."""
          candidates = (
              child
              for block in self._multi_line_blocks
              for child in block
              if not child.is_function
          )
          for child in candidates:
              yield from child._get_yield_nodes_skip_functions()

      def _get_yield_nodes_skip_lambdas(self):
          """Yield yield nodes from the blocks, not descending into lambdas."""
          candidates = (
              child
              for block in self._multi_line_blocks
              for child in block
              if not child.is_lambda
          )
          for child in candidates:
              yield from child._get_yield_nodes_skip_lambdas()

      @cached_property
      def _assign_nodes_in_scope(self) -> list[nodes.Assign]:
          """Concatenation of every child's ``_assign_nodes_in_scope``, in block order."""
          return [
              assign_node
              for block in self._multi_line_blocks
              for child in block
              for assign_node in child._assign_nodes_in_scope
          ]
  class MultiLineWithElseBlockNode(MultiLineBlockNode):
      """Base node for multi-line blocks that can have else statements."""

      @cached_property
      def blockstart_tolineno(self):
          """Line on which the block header ends: the node's own ``lineno``."""
          return self.lineno

      def _elsed_block_range(
          self, lineno: int, orelse: list[nodes.NodeNG], last: int | None = None
      ) -> tuple[int, int]:
          """Handle block line numbers range for try/finally, for, if and while
          statements.

          :param lineno: Line for which the enclosing range is requested.
          :param orelse: The ``else`` statements, possibly empty.
          :param last: Fallback end line used when there is no ``else`` block;
              a falsy value means ``self.tolineno``.
          :returns: A ``(start, end)`` pair of line numbers.
          """
          if lineno == self.fromlineno:
              return lineno, lineno
          if not orelse:
              return lineno, last or self.tolineno
          else_start = orelse[0].fromlineno
          if lineno >= else_start:
              # Inside the else block: the range runs to its last line.
              return lineno, orelse[-1].tolineno
          # Before the else block: the range stops just above it.
          return lineno, else_start - 1
  class LookupMixIn(NodeNG):
      """Mixin to look up a name in the right scope."""

      # Results are memoized by lru_cache, keyed on (self, name).
      @lru_cache  # noqa
      def lookup(self, name: str) -> tuple[LocalsDictNodeNG, list[NodeNG]]:
          """Lookup where the given variable is assigned.

          The lookup starts from self's scope. If self is not a frame itself
          and the name is found in the inner frame locals, statements will be
          filtered to remove ignorable statements according to self's location.

          :param name: The name of the variable to find assignments for.

          :returns: The scope node and the list of assignments associated to the
              given name according to the scope where it has been found (locals,
              globals or builtin).
          """
          scope = self.scope()
          return scope.scope_lookup(self, name)

      def ilookup(self, name):
          """Lookup the inferred values of the given variable.

          A fresh :class:`InferenceContext` is used for every call.

          :param name: The variable name to find values for.
          :type name: str

          :returns: The inferred values of the statements returned from
              :meth:`lookup`.
          :rtype: iterable
          """
          scope, assignments = self.lookup(name)
          return bases._infer_stmts(assignments, InferenceContext(), scope)
  230. def _reflected_name(name) -> str:
  231. return "__r" + name[2:]
  232. def _augmented_name(name) -> str:
  233. return "__i" + name[2:]
  # Binary operator symbol -> dunder method implementing it on the left operand.
  BIN_OP_METHOD = {
      "+": "__add__",
      "-": "__sub__",
      "/": "__truediv__",
      "//": "__floordiv__",
      "*": "__mul__",
      "**": "__pow__",
      "%": "__mod__",
      "&": "__and__",
      "|": "__or__",
      "^": "__xor__",
      "<<": "__lshift__",
      ">>": "__rshift__",
      "@": "__matmul__",
  }

  # Same operators mapped to their reflected dunders (``__radd__``, ...).
  REFLECTED_BIN_OP_METHOD = {
      symbol: _reflected_name(dunder) for symbol, dunder in BIN_OP_METHOD.items()
  }

  # Augmented-assignment operators (``+=``, ...) mapped to in-place dunders
  # (``__iadd__``, ...).
  AUGMENTED_OP_METHOD = {
      f"{symbol}=": _augmented_name(dunder) for symbol, dunder in BIN_OP_METHOD.items()
  }
  class OperatorNode(NodeNG):
      """Base node holding the shared inference logic for binary operators.

      The static helpers below implement inference for binary operations on
      ``BinOp`` and ``AugAssign`` nodes (the ``opnode`` types they accept).
      They choose the order in which dunder methods are tried (the "flow"),
      invoke them, and convert failures into ``Uninferable`` or a
      ``BadBinaryOperationMessage``.
      """

      @staticmethod
      def _filter_operation_errors(
          infer_callable: Callable[
              [InferenceContext | None],
              Generator[InferenceResult | util.BadOperationMessage],
          ],
          context: InferenceContext | None,
          error: type[util.BadOperationMessage],
      ) -> Generator[InferenceResult]:
          """Run *infer_callable* and hide operation errors from the caller.

          Every result that is an instance of *error* is replaced by
          ``util.Uninferable``; other results are yielded unchanged.
          """
          for result in infer_callable(context):
              if isinstance(result, error):
                  # For the sake of .infer(), we don't care about operation
                  # errors, which is the job of a linter. So return something
                  # which shows that we can't infer the result.
                  yield util.Uninferable
              else:
                  yield result

      @staticmethod
      def _is_not_implemented(const) -> bool:
          """Check if the given const node is NotImplemented."""
          return isinstance(const, nodes.Const) and const.value is NotImplemented

      @staticmethod
      def _infer_old_style_string_formatting(
          instance: nodes.Const, other: nodes.NodeNG, context: InferenceContext
      ) -> tuple[util.UninferableBase | nodes.Const]:
          """Infer the result of '"string" % ...'.

          The right operand may be a tuple (positional values), a dict (mapping
          values) or a single constant. Every value must infer to a ``Const``;
          otherwise, or when the formatting itself fails, the result is
          ``Uninferable``.

          TODO: Instead of returning Uninferable we should rely
          on the call to '%' to see if the result is actually uninferable.
          """
          values: Any
          if isinstance(other, nodes.Tuple):
              if util.Uninferable in other.elts:
                  return (util.Uninferable,)
              inferred_positional = [util.safe_infer(i, context) for i in other.elts]
              if not all(isinstance(i, nodes.Const) for i in inferred_positional):
                  # An element we cannot resolve makes the result unknown.
                  # Formatting with a placeholder such as ``None`` instead would
                  # produce bogus constants, e.g. "%s" % None == "None".
                  return (util.Uninferable,)
              values = tuple(i.value for i in inferred_positional)
          elif isinstance(other, nodes.Dict):
              values = {}
              for pair in other.items:
                  key = util.safe_infer(pair[0], context)
                  if not isinstance(key, nodes.Const):
                      return (util.Uninferable,)
                  value = util.safe_infer(pair[1], context)
                  if not isinstance(value, nodes.Const):
                      return (util.Uninferable,)
                  values[key.value] = value.value
          elif isinstance(other, nodes.Const):
              values = other.value
          else:
              return (util.Uninferable,)

          try:
              return (nodes.const_factory(instance.value % values),)
          except (TypeError, KeyError, ValueError):
              return (util.Uninferable,)

      @staticmethod
      def _invoke_binop_inference(
          instance: InferenceResult,
          opnode: nodes.AugAssign | nodes.BinOp,
          op: str,
          other: InferenceResult,
          context: InferenceContext,
          method_name: str,
      ) -> Generator[InferenceResult]:
          """Invoke binary operation inference on the given instance.

          Looks up *method_name* on *instance*, binds *context* to it and lets
          ``instance.infer_binary_op`` compute the result. ``str % ...`` is
          short-circuited to ``_infer_old_style_string_formatting``. Errors from
          the method lookup are not caught here and propagate to the caller.

          :raises InferenceError: if the dunder method cannot be inferred, or
              infers to Uninferable.
          """
          methods = dunder_lookup.lookup(instance, method_name)
          context = bind_context_to_node(context, instance)
          method = methods[0]
          context.callcontext.callee = method

          if (
              isinstance(instance, nodes.Const)
              and isinstance(instance.value, str)
              and op == "%"
          ):
              return iter(
                  OperatorNode._infer_old_style_string_formatting(
                      instance, other, context
                  )
              )

          try:
              inferred = next(method.infer(context=context))
          except StopIteration as e:
              raise InferenceError(node=method, context=context) from e
          if isinstance(inferred, util.UninferableBase):
              raise InferenceError
          if not isinstance(
              instance,
              (nodes.Const, nodes.Tuple, nodes.List, nodes.ClassDef, bases.Instance),
          ):
              raise InferenceError  # pragma: no cover # Used as a failsafe
          return instance.infer_binary_op(opnode, op, other, context, inferred)

      @staticmethod
      def _aug_op(
          instance: InferenceResult,
          opnode: nodes.AugAssign,
          op: str,
          other: InferenceResult,
          context: InferenceContext,
          reverse: bool = False,
      ) -> partial[Generator[InferenceResult]]:
          """Get an inference callable for an augmented binary operation.

          *reverse* is accepted for signature parity with ``_bin_op`` but is
          not used: augmented operators have no reflected variant.
          """
          method_name = AUGMENTED_OP_METHOD[op]
          return partial(
              OperatorNode._invoke_binop_inference,
              instance=instance,
              op=op,
              opnode=opnode,
              other=other,
              context=context,
              method_name=method_name,
          )

      @staticmethod
      def _bin_op(
          instance: InferenceResult,
          opnode: nodes.AugAssign | nodes.BinOp,
          op: str,
          other: InferenceResult,
          context: InferenceContext,
          reverse: bool = False,
      ) -> partial[Generator[InferenceResult]]:
          """Get an inference callable for a normal binary operation.

          If *reverse* is True, then the reflected method will be used instead.
          """
          if reverse:
              method_name = REFLECTED_BIN_OP_METHOD[op]
          else:
              method_name = BIN_OP_METHOD[op]
          return partial(
              OperatorNode._invoke_binop_inference,
              instance=instance,
              op=op,
              opnode=opnode,
              other=other,
              context=context,
              method_name=method_name,
          )

      @staticmethod
      def _bin_op_or_union_type(
          left: bases.UnionType | nodes.ClassDef | nodes.Const,
          right: bases.UnionType | nodes.ClassDef | nodes.Const,
      ) -> Generator[InferenceResult]:
          """Create a new UnionType instance for binary or, e.g. int | str."""
          yield bases.UnionType(left, right)

      @staticmethod
      def _get_binop_contexts(context, left, right):
          """Get contexts for binary operations.

          This will yield two inference contexts, the first one
          for x.__op__(y), the other one for y.__rop__(x), where
          only the arguments are inversed.
          """
          # The order is important, since the first one should be
          # left.__op__(right).
          for arg in (right, left):
              new_context = context.clone()
              new_context.callcontext = CallContext(args=[arg])
              new_context.boundnode = None
              yield new_context

      @staticmethod
      def _same_type(type1, type2) -> bool:
          """Check if type1 is the same as type2 (by qualified name)."""
          return type1.qname() == type2.qname()

      @staticmethod
      def _get_aug_flow(
          left: InferenceResult,
          left_type: InferenceResult | None,
          aug_opnode: nodes.AugAssign,
          right: InferenceResult,
          right_type: InferenceResult | None,
          context: InferenceContext,
          reverse_context: InferenceContext,
      ) -> list[partial[Generator[InferenceResult]]]:
          """Get the flow for augmented binary operations.

          The rules are a bit messy:

              * if left and right have the same type, then left.__augop__(right)
                is first tried and then left.__op__(right).
              * if left and right are unrelated typewise, then
                left.__augop__(right) is tried, then left.__op__(right)
                is tried and then right.__rop__(left) is tried.
              * if left is a subtype of right, then left.__augop__(right)
                is tried and then left.__op__(right).
              * if left is a supertype of right, then left.__augop__(right)
                is tried, then right.__rop__(left) and then
                left.__op__(right)
          """
          from astroid import helpers  # pylint: disable=import-outside-toplevel

          # "+=" -> "+", "**=" -> "**", ...
          bin_op = aug_opnode.op.strip("=")
          aug_op = aug_opnode.op
          if OperatorNode._same_type(left_type, right_type):
              methods = [
                  OperatorNode._aug_op(left, aug_opnode, aug_op, right, context),
                  OperatorNode._bin_op(left, aug_opnode, bin_op, right, context),
              ]
          elif helpers.is_subtype(left_type, right_type):
              methods = [
                  OperatorNode._aug_op(left, aug_opnode, aug_op, right, context),
                  OperatorNode._bin_op(left, aug_opnode, bin_op, right, context),
              ]
          elif helpers.is_supertype(left_type, right_type):
              methods = [
                  OperatorNode._aug_op(left, aug_opnode, aug_op, right, context),
                  OperatorNode._bin_op(
                      right, aug_opnode, bin_op, left, reverse_context, reverse=True
                  ),
                  OperatorNode._bin_op(left, aug_opnode, bin_op, right, context),
              ]
          else:
              methods = [
                  OperatorNode._aug_op(left, aug_opnode, aug_op, right, context),
                  OperatorNode._bin_op(left, aug_opnode, bin_op, right, context),
                  OperatorNode._bin_op(
                      right, aug_opnode, bin_op, left, reverse_context, reverse=True
                  ),
              ]
          return methods

      @staticmethod
      def _get_binop_flow(
          left: InferenceResult,
          left_type: InferenceResult | None,
          binary_opnode: nodes.AugAssign | nodes.BinOp,
          right: InferenceResult,
          right_type: InferenceResult | None,
          context: InferenceContext,
          reverse_context: InferenceContext,
      ) -> list[partial[Generator[InferenceResult]]]:
          """Get the flow for binary operations.

          The rules are a bit messy:

              * if left and right have the same type, then only one
                method will be called, left.__op__(right)
              * if left and right are unrelated typewise, then first
                left.__op__(right) is tried and if this does not exist
                or returns NotImplemented, then right.__rop__(left) is tried.
              * if left is a subtype of right, then only left.__op__(right)
                is tried.
              * if left is a supertype of right, then right.__rop__(left)
                is first tried and then left.__op__(right)

          For ``|`` between classes, union types or ``None``, building a
          ``UnionType`` is appended as a last resort.
          """
          from astroid import helpers  # pylint: disable=import-outside-toplevel

          op = binary_opnode.op
          if OperatorNode._same_type(left_type, right_type):
              methods = [OperatorNode._bin_op(left, binary_opnode, op, right, context)]
          elif helpers.is_subtype(left_type, right_type):
              methods = [OperatorNode._bin_op(left, binary_opnode, op, right, context)]
          elif helpers.is_supertype(left_type, right_type):
              methods = [
                  OperatorNode._bin_op(
                      right, binary_opnode, op, left, reverse_context, reverse=True
                  ),
                  OperatorNode._bin_op(left, binary_opnode, op, right, context),
              ]
          else:
              methods = [
                  OperatorNode._bin_op(left, binary_opnode, op, right, context),
                  OperatorNode._bin_op(
                      right, binary_opnode, op, left, reverse_context, reverse=True
                  ),
              ]

          # pylint: disable = too-many-boolean-expressions
          if (
              op == "|"
              and (
                  isinstance(left, (bases.UnionType, nodes.ClassDef))
                  or (isinstance(left, nodes.Const) and left.value is None)
              )
              and (
                  isinstance(right, (bases.UnionType, nodes.ClassDef))
                  or (isinstance(right, nodes.Const) and right.value is None)
              )
          ):
              methods.append(partial(OperatorNode._bin_op_or_union_type, left, right))
          return methods

      @staticmethod
      def _infer_binary_operation(
          left: InferenceResult,
          right: InferenceResult,
          binary_opnode: nodes.AugAssign | nodes.BinOp,
          context: InferenceContext,
          flow_factory: GetFlowFactory,
      ) -> Generator[InferenceResult | util.BadBinaryOperationMessage]:
          """Infer a binary operation between a left operand and a right operand.

          This is used by both normal binary operations and augmented binary
          operations, the only difference is the flow factory used.

          The callables produced by *flow_factory* are tried in order:

              * one raising ``AttributeError`` or ``AttributeInferenceError``
                is skipped;
              * one raising ``InferenceError``, or producing any Uninferable
                result, ends inference with ``Uninferable``;
              * one whose results are all ``NotImplemented`` defers to the
                next callable;
              * a mix of ``NotImplemented`` and other results gives
                ``Uninferable``;
              * otherwise its results are yielded and inference stops.

          If every callable is skipped or defers, a
          ``BadBinaryOperationMessage`` is yielded.
          """
          from astroid import helpers  # pylint: disable=import-outside-toplevel

          context, reverse_context = OperatorNode._get_binop_contexts(
              context, left, right
          )
          left_type = helpers.object_type(left)
          right_type = helpers.object_type(right)
          methods = flow_factory(
              left, left_type, binary_opnode, right, right_type, context, reverse_context
          )
          for method in methods:
              try:
                  results = list(method())
              except (AttributeError, AttributeInferenceError):
                  continue
              except InferenceError:
                  yield util.Uninferable
                  return

              if any(isinstance(result, util.UninferableBase) for result in results):
                  yield util.Uninferable
                  return
              if all(map(OperatorNode._is_not_implemented, results)):
                  continue
              not_implemented = sum(
                  1 for result in results if OperatorNode._is_not_implemented(result)
              )
              if not_implemented and not_implemented != len(results):
                  # Can't infer yet what this is.
                  yield util.Uninferable
                  return

              yield from results
              return

          # The operation doesn't seem to be supported so let the caller know about it
          yield util.BadBinaryOperationMessage(left_type, binary_opnode.op, right_type)