util.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. # Copyright 2016 Grist Labs, Inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # https://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import ast
  15. import collections
  16. import io
  17. import sys
  18. import token
  19. import tokenize
  20. from abc import ABCMeta
  21. from ast import Module, expr, AST
  22. from functools import lru_cache
  23. from typing import (
  24. Callable,
  25. Dict,
  26. Iterable,
  27. Iterator,
  28. List,
  29. Optional,
  30. Tuple,
  31. Union,
  32. cast,
  33. Any,
  34. TYPE_CHECKING,
  35. Type,
  36. )
  37. if TYPE_CHECKING: # pragma: no cover
  38. from .astroid_compat import NodeNG
  39. # Type class used to expand out the definition of AST to include fields added by this library
  40. # It's not actually used for anything other than type checking though!
  41. class EnhancedAST(AST):
  42. # Additional attributes set by mark_tokens
  43. first_token = None # type: Token
  44. last_token = None # type: Token
  45. lineno = 0 # type: int
  46. end_lineno = 0 # type : int
  47. end_col_offset = 0 # type : int
  48. AstNode = Union[EnhancedAST, NodeNG]
  49. TokenInfo = tokenize.TokenInfo
  50. def token_repr(tok_type, string):
  51. # type: (int, Optional[str]) -> str
  52. """Returns a human-friendly representation of a token with the given type and string."""
  53. # repr() prefixes unicode with 'u' on Python2 but not Python3; strip it out for consistency.
  54. return '%s:%s' % (token.tok_name[tok_type], repr(string).lstrip('u'))
  55. class Token(collections.namedtuple('Token', 'type string start end line index startpos endpos')):
  56. """
  57. TokenInfo is an 8-tuple containing the same 5 fields as the tokens produced by the tokenize
  58. module, and 3 additional ones useful for this module:
  59. - [0] .type Token type (see token.py)
  60. - [1] .string Token (a string)
  61. - [2] .start Starting (row, column) indices of the token (a 2-tuple of ints)
  62. - [3] .end Ending (row, column) indices of the token (a 2-tuple of ints)
  63. - [4] .line Original line (string)
  64. - [5] .index Index of the token in the list of tokens that it belongs to.
  65. - [6] .startpos Starting character offset into the input text.
  66. - [7] .endpos Ending character offset into the input text.
  67. """
  68. def __str__(self):
  69. # type: () -> str
  70. return token_repr(self.type, self.string)
  71. def match_token(token, tok_type, tok_str=None):
  72. # type: (Token, int, Optional[str]) -> bool
  73. """Returns true if token is of the given type and, if a string is given, has that string."""
  74. return token.type == tok_type and (tok_str is None or token.string == tok_str)
  75. def expect_token(token, tok_type, tok_str=None):
  76. # type: (Token, int, Optional[str]) -> None
  77. """
  78. Verifies that the given token is of the expected type. If tok_str is given, the token string
  79. is verified too. If the token doesn't match, raises an informative ValueError.
  80. """
  81. if not match_token(token, tok_type, tok_str):
  82. raise ValueError("Expected token %s, got %s on line %s col %s" % (
  83. token_repr(tok_type, tok_str), str(token),
  84. token.start[0], token.start[1] + 1))
  85. def is_non_coding_token(token_type):
  86. # type: (int) -> bool
  87. """
  88. These are considered non-coding tokens, as they don't affect the syntax tree.
  89. """
  90. return token_type in (token.NL, token.COMMENT, token.ENCODING)
  91. def generate_tokens(text):
  92. # type: (str) -> Iterator[TokenInfo]
  93. """
  94. Generates standard library tokens for the given code.
  95. """
  96. # tokenize.generate_tokens is technically an undocumented API for Python3, but allows us to use the same API as for
  97. # Python2. See https://stackoverflow.com/a/4952291/328565.
  98. # FIXME: Remove cast once https://github.com/python/typeshed/issues/7003 gets fixed
  99. return tokenize.generate_tokens(cast(Callable[[], str], io.StringIO(text).readline))
  100. def iter_children_func(node):
  101. # type: (AST) -> Callable
  102. """
  103. Returns a function which yields all direct children of a AST node,
  104. skipping children that are singleton nodes.
  105. The function depends on whether ``node`` is from ``ast`` or from the ``astroid`` module.
  106. """
  107. return iter_children_astroid if hasattr(node, 'get_children') else iter_children_ast
  108. def iter_children_astroid(node, include_joined_str=False):
  109. # type: (NodeNG, bool) -> Union[Iterator, List]
  110. if not include_joined_str and is_joined_str(node):
  111. return []
  112. return node.get_children()
  113. SINGLETONS = {c for n, c in ast.__dict__.items() if isinstance(c, type) and
  114. issubclass(c, (ast.expr_context, ast.boolop, ast.operator, ast.unaryop, ast.cmpop))}
  115. def iter_children_ast(node, include_joined_str=False):
  116. # type: (AST, bool) -> Iterator[Union[AST, expr]]
  117. if not include_joined_str and is_joined_str(node):
  118. return
  119. if isinstance(node, ast.Dict):
  120. # override the iteration order: instead of <all keys>, <all values>,
  121. # yield keys and values in source order (key1, value1, key2, value2, ...)
  122. for (key, value) in zip(node.keys, node.values):
  123. if key is not None:
  124. yield key
  125. yield value
  126. return
  127. for child in ast.iter_child_nodes(node):
  128. # Skip singleton children; they don't reflect particular positions in the code and break the
  129. # assumptions about the tree consisting of distinct nodes. Note that collecting classes
  130. # beforehand and checking them in a set is faster than using isinstance each time.
  131. if child.__class__ not in SINGLETONS:
  132. yield child
  133. stmt_class_names = {n for n, c in ast.__dict__.items()
  134. if isinstance(c, type) and issubclass(c, ast.stmt)}
  135. expr_class_names = ({n for n, c in ast.__dict__.items()
  136. if isinstance(c, type) and issubclass(c, ast.expr)} |
  137. {'AssignName', 'DelName', 'Const', 'AssignAttr', 'DelAttr'})
  138. # These feel hacky compared to isinstance() but allow us to work with both ast and astroid nodes
  139. # in the same way, and without even importing astroid.
  140. def is_expr(node):
  141. # type: (AstNode) -> bool
  142. """Returns whether node is an expression node."""
  143. return node.__class__.__name__ in expr_class_names
  144. def is_stmt(node):
  145. # type: (AstNode) -> bool
  146. """Returns whether node is a statement node."""
  147. return node.__class__.__name__ in stmt_class_names
  148. def is_module(node):
  149. # type: (AstNode) -> bool
  150. """Returns whether node is a module node."""
  151. return node.__class__.__name__ == 'Module'
  152. def is_joined_str(node):
  153. # type: (AstNode) -> bool
  154. """Returns whether node is a JoinedStr node, used to represent f-strings."""
  155. # At the moment, nodes below JoinedStr have wrong line/col info, and trying to process them only
  156. # leads to errors.
  157. return node.__class__.__name__ == 'JoinedStr'
  158. def is_expr_stmt(node):
  159. # type: (AstNode) -> bool
  160. """Returns whether node is an `Expr` node, which is a statement that is an expression."""
  161. return node.__class__.__name__ == 'Expr'
  162. CONSTANT_CLASSES: Tuple[Type, ...] = (ast.Constant,)
  163. try:
  164. from astroid.nodes import Const
  165. CONSTANT_CLASSES += (Const,)
  166. except ImportError: # pragma: no cover
  167. # astroid is not available
  168. pass
  169. def is_constant(node):
  170. # type: (AstNode) -> bool
  171. """Returns whether node is a Constant node."""
  172. return isinstance(node, CONSTANT_CLASSES)
  173. def is_ellipsis(node):
  174. # type: (AstNode) -> bool
  175. """Returns whether node is an Ellipsis node."""
  176. return is_constant(node) and node.value is Ellipsis # type: ignore
  177. def is_starred(node):
  178. # type: (AstNode) -> bool
  179. """Returns whether node is a starred expression node."""
  180. return node.__class__.__name__ == 'Starred'
  181. def is_slice(node):
  182. # type: (AstNode) -> bool
  183. """Returns whether node represents a slice, e.g. `1:2` in `x[1:2]`"""
  184. # Before 3.9, a tuple containing a slice is an ExtSlice,
  185. # but this was removed in https://bugs.python.org/issue34822
  186. return (
  187. node.__class__.__name__ in ('Slice', 'ExtSlice')
  188. or (
  189. node.__class__.__name__ == 'Tuple'
  190. and any(map(is_slice, cast(ast.Tuple, node).elts))
  191. )
  192. )
  193. def is_empty_astroid_slice(node):
  194. # type: (AstNode) -> bool
  195. return (
  196. node.__class__.__name__ == "Slice"
  197. and not isinstance(node, ast.AST)
  198. and node.lower is node.upper is node.step is None
  199. )
  200. # Sentinel value used by visit_tree().
  201. _PREVISIT = object()
  202. def visit_tree(node, previsit, postvisit):
  203. # type: (Module, Callable[[AstNode, Optional[Token]], Tuple[Optional[Token], Optional[Token]]], Optional[Callable[[AstNode, Optional[Token], Optional[Token]], None]]) -> None
  204. """
  205. Scans the tree under the node depth-first using an explicit stack. It avoids implicit recursion
  206. via the function call stack to avoid hitting 'maximum recursion depth exceeded' error.
  207. It calls ``previsit()`` and ``postvisit()`` as follows:
  208. * ``previsit(node, par_value)`` - should return ``(par_value, value)``
  209. ``par_value`` is as returned from ``previsit()`` of the parent.
  210. * ``postvisit(node, par_value, value)`` - should return ``value``
  211. ``par_value`` is as returned from ``previsit()`` of the parent, and ``value`` is as
  212. returned from ``previsit()`` of this node itself. The return ``value`` is ignored except
  213. the one for the root node, which is returned from the overall ``visit_tree()`` call.
  214. For the initial node, ``par_value`` is None. ``postvisit`` may be None.
  215. """
  216. if not postvisit:
  217. postvisit = lambda node, pvalue, value: None
  218. iter_children = iter_children_func(node)
  219. done = set()
  220. ret = None
  221. stack = [(node, None, _PREVISIT)] # type: List[Tuple[AstNode, Optional[Token], Union[Optional[Token], object]]]
  222. while stack:
  223. current, par_value, value = stack.pop()
  224. if value is _PREVISIT:
  225. assert current not in done # protect againt infinite loop in case of a bad tree.
  226. done.add(current)
  227. pvalue, post_value = previsit(current, par_value)
  228. stack.append((current, par_value, post_value))
  229. # Insert all children in reverse order (so that first child ends up on top of the stack).
  230. ins = len(stack)
  231. for n in iter_children(current):
  232. stack.insert(ins, (n, pvalue, _PREVISIT))
  233. else:
  234. ret = postvisit(current, par_value, cast(Optional[Token], value))
  235. return ret
  236. def walk(node, include_joined_str=False):
  237. # type: (AST, bool) -> Iterator[Union[Module, AstNode]]
  238. """
  239. Recursively yield all descendant nodes in the tree starting at ``node`` (including ``node``
  240. itself), using depth-first pre-order traversal (yieling parents before their children).
  241. This is similar to ``ast.walk()``, but with a different order, and it works for both ``ast`` and
  242. ``astroid`` trees. Also, as ``iter_children()``, it skips singleton nodes generated by ``ast``.
  243. By default, ``JoinedStr`` (f-string) nodes and their contents are skipped
  244. because they previously couldn't be handled. Set ``include_joined_str`` to True to include them.
  245. """
  246. iter_children = iter_children_func(node)
  247. done = set()
  248. stack = [node]
  249. while stack:
  250. current = stack.pop()
  251. assert current not in done # protect againt infinite loop in case of a bad tree.
  252. done.add(current)
  253. yield current
  254. # Insert all children in reverse order (so that first child ends up on top of the stack).
  255. # This is faster than building a list and reversing it.
  256. ins = len(stack)
  257. for c in iter_children(current, include_joined_str):
  258. stack.insert(ins, c)
  259. def replace(text, replacements):
  260. # type: (str, List[Tuple[int, int, str]]) -> str
  261. """
  262. Replaces multiple slices of text with new values. This is a convenience method for making code
  263. modifications of ranges e.g. as identified by ``ASTTokens.get_text_range(node)``. Replacements is
  264. an iterable of ``(start, end, new_text)`` tuples.
  265. For example, ``replace("this is a test", [(0, 4, "X"), (8, 9, "THE")])`` produces
  266. ``"X is THE test"``.
  267. """
  268. p = 0
  269. parts = []
  270. for (start, end, new_text) in sorted(replacements):
  271. parts.append(text[p:start])
  272. parts.append(new_text)
  273. p = end
  274. parts.append(text[p:])
  275. return ''.join(parts)
  276. class NodeMethods:
  277. """
  278. Helper to get `visit_{node_type}` methods given a node's class and cache the results.
  279. """
  280. def __init__(self):
  281. # type: () -> None
  282. self._cache = {} # type: Dict[Union[ABCMeta, type], Callable[[AstNode, Token, Token], Tuple[Token, Token]]]
  283. def get(self, obj, cls):
  284. # type: (Any, Union[ABCMeta, type]) -> Callable
  285. """
  286. Using the lowercase name of the class as node_type, returns `obj.visit_{node_type}`,
  287. or `obj.visit_default` if the type-specific method is not found.
  288. """
  289. method = self._cache.get(cls)
  290. if not method:
  291. name = "visit_" + cls.__name__.lower()
  292. method = getattr(obj, name, obj.visit_default)
  293. self._cache[cls] = method
  294. return method
  295. def patched_generate_tokens(original_tokens):
  296. # type: (Iterable[TokenInfo]) -> Iterator[TokenInfo]
  297. """
  298. Fixes tokens yielded by `tokenize.generate_tokens` to handle more non-ASCII characters in identifiers.
  299. Workaround for https://github.com/python/cpython/issues/68382.
  300. Should only be used when tokenizing a string that is known to be valid syntax,
  301. because it assumes that error tokens are not actually errors.
  302. Combines groups of consecutive NAME, NUMBER, and/or ERRORTOKEN tokens into a single NAME token.
  303. """
  304. group = [] # type: List[tokenize.TokenInfo]
  305. for tok in original_tokens:
  306. if (
  307. tok.type in (tokenize.NAME, tokenize.ERRORTOKEN, tokenize.NUMBER)
  308. # Only combine tokens if they have no whitespace in between
  309. and (not group or group[-1].end == tok.start)
  310. ):
  311. group.append(tok)
  312. else:
  313. for combined_token in combine_tokens(group):
  314. yield combined_token
  315. group = []
  316. yield tok
  317. for combined_token in combine_tokens(group):
  318. yield combined_token
  319. def combine_tokens(group):
  320. # type: (List[tokenize.TokenInfo]) -> List[tokenize.TokenInfo]
  321. if not any(tok.type == tokenize.ERRORTOKEN for tok in group) or len({tok.line for tok in group}) != 1:
  322. return group
  323. return [
  324. tokenize.TokenInfo(
  325. type=tokenize.NAME,
  326. string="".join(t.string for t in group),
  327. start=group[0].start,
  328. end=group[-1].end,
  329. line=group[0].line,
  330. )
  331. ]
  332. def last_stmt(node):
  333. # type: (AstNode) -> AstNode
  334. """
  335. If the given AST node contains multiple statements, return the last one.
  336. Otherwise, just return the node.
  337. """
  338. child_stmts = [
  339. child for child in iter_children_func(node)(node)
  340. if is_stmt(child) or type(child).__name__ in (
  341. "excepthandler",
  342. "ExceptHandler",
  343. "match_case",
  344. "MatchCase",
  345. "TryExcept",
  346. "TryFinally",
  347. )
  348. ]
  349. if child_stmts:
  350. return last_stmt(child_stmts[-1])
  351. return node
  352. @lru_cache(maxsize=None)
  353. def fstring_positions_work():
  354. # type: () -> bool
  355. """
  356. The positions attached to nodes inside f-string FormattedValues have some bugs
  357. that were fixed in Python 3.9.7 in https://github.com/python/cpython/pull/27729.
  358. This checks for those bugs more concretely without relying on the Python version.
  359. Specifically this checks:
  360. - Values with a format spec or conversion
  361. - Repeated (i.e. identical-looking) expressions
  362. - f-strings implicitly concatenated over multiple lines.
  363. - Multiline, triple-quoted f-strings.
  364. """
  365. source = """(
  366. f"a {b}{b} c {d!r} e {f:g} h {i:{j}} k {l:{m:n}}"
  367. f"a {b}{b} c {d!r} e {f:g} h {i:{j}} k {l:{m:n}}"
  368. f"{x + y + z} {x} {y} {z} {z} {z!a} {z:z}"
  369. f'''
  370. {s} {t}
  371. {u} {v}
  372. '''
  373. )"""
  374. tree = ast.parse(source)
  375. name_nodes = [node for node in ast.walk(tree) if isinstance(node, ast.Name)]
  376. name_positions = [(node.lineno, node.col_offset) for node in name_nodes]
  377. positions_are_unique = len(set(name_positions)) == len(name_positions)
  378. correct_source_segments = all(
  379. ast.get_source_segment(source, node) == node.id
  380. for node in name_nodes
  381. )
  382. return positions_are_unique and correct_source_segments
  383. def annotate_fstring_nodes(tree):
  384. # type: (ast.AST) -> None
  385. """
  386. Add a special attribute `_broken_positions` to nodes inside f-strings
  387. if the lineno/col_offset cannot be trusted.
  388. """
  389. if sys.version_info >= (3, 12):
  390. # f-strings were weirdly implemented until https://peps.python.org/pep-0701/
  391. # In Python 3.12, inner nodes have sensible positions.
  392. return
  393. for joinedstr in walk(tree, include_joined_str=True):
  394. if not isinstance(joinedstr, ast.JoinedStr):
  395. continue
  396. for part in joinedstr.values:
  397. # The ast positions of the FormattedValues/Constant nodes span the full f-string, which is weird.
  398. setattr(part, '_broken_positions', True) # use setattr for mypy
  399. if isinstance(part, ast.FormattedValue):
  400. if not fstring_positions_work():
  401. for child in walk(part.value):
  402. setattr(child, '_broken_positions', True)
  403. if part.format_spec: # this is another JoinedStr
  404. # Again, the standard positions span the full f-string.
  405. setattr(part.format_spec, '_broken_positions', True)