| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487 |
- # Copyright 2016 Grist Labs, Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # https://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import ast
- import collections
- import io
- import sys
- import token
- import tokenize
- from abc import ABCMeta
- from ast import Module, expr, AST
- from functools import lru_cache
- from typing import (
- Callable,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Tuple,
- Union,
- cast,
- Any,
- TYPE_CHECKING,
- Type,
- )
if TYPE_CHECKING:  # pragma: no cover
    from .astroid_compat import NodeNG

    # Type class used to expand out the definition of AST to include fields added by this library.
    # It's not actually used for anything other than type checking though!
    class EnhancedAST(AST):
        # Additional attributes set by mark_tokens.
        # Type comments must be spelled exactly "# type: ..." (no space before the colon) or type
        # checkers silently treat them as ordinary comments.
        first_token = None  # type: Token
        last_token = None  # type: Token
        lineno = 0  # type: int
        end_lineno = 0  # type: int
        end_col_offset = 0  # type: int

    # Either a stdlib ast node (with the extra attributes above) or an astroid node.
    AstNode = Union[EnhancedAST, NodeNG]

# Alias for the standard library's token 5-tuple type.
TokenInfo = tokenize.TokenInfo
def token_repr(tok_type, string):
    # type: (int, Optional[str]) -> str
    """
    Build a readable ``TYPENAME:repr(string)`` label for a token, e.g. ``NAME:'x'``.
    """
    type_name = token.tok_name[tok_type]
    # Python 2's repr() put a 'u' in front of unicode strings; strip it so output is uniform.
    shown = repr(string).lstrip('u')
    return f"{type_name}:{shown}"
class Token(collections.namedtuple('Token', 'type string start end line index startpos endpos')):
    """
    Token is an 8-tuple containing the same 5 fields as the tokens produced by the tokenize
    module, and 3 additional ones useful for this module:

    - [0] .type     Token type (see token.py)
    - [1] .string   Token (a string)
    - [2] .start    Starting (row, column) indices of the token (a 2-tuple of ints)
    - [3] .end      Ending (row, column) indices of the token (a 2-tuple of ints)
    - [4] .line     Original line (string)
    - [5] .index    Index of the token in the list of tokens that it belongs to.
    - [6] .startpos Starting character offset into the input text.
    - [7] .endpos   Ending character offset into the input text.

    ``str(token)`` gives the same human-friendly ``TYPE:'string'`` form as ``token_repr()``.
    """

    def __str__(self):
        # type: () -> str
        return token_repr(self.type, self.string)
def match_token(token, tok_type, tok_str=None):
    # type: (Token, int, Optional[str]) -> bool
    """
    Check ``token`` against an expected type and, optionally, an expected string.

    ``tok_str=None`` means any string is accepted.
    """
    if token.type != tok_type:
        return False
    if tok_str is None:
        return True
    return token.string == tok_str
def expect_token(token, tok_type, tok_str=None):
    # type: (Token, int, Optional[str]) -> None
    """
    Require that ``token`` has type ``tok_type`` (and string ``tok_str``, when one is given).

    Returns None on success. On mismatch raises ``ValueError`` describing the expected token,
    the actual token, and the actual token's line and column. The column is reported as the
    start column + 1.
    """
    if match_token(token, tok_type, tok_str):
        return
    expected = token_repr(tok_type, tok_str)
    actual = str(token)
    raise ValueError("Expected token %s, got %s on line %s col %s" % (
        expected, actual, token.start[0], token.start[1] + 1))
def is_non_coding_token(token_type):
    # type: (int) -> bool
    """
    Report whether a token type is one that does not affect the syntax tree.

    Those types are non-logical newlines (NL), comments and the ENCODING marker.
    """
    return (
        token_type == token.NL
        or token_type == token.COMMENT
        or token_type == token.ENCODING
    )
def generate_tokens(text):
    # type: (str) -> Iterator[TokenInfo]
    """
    Tokenize the source string ``text`` with the standard library tokenizer.

    Returns the lazy token iterator produced by ``tokenize.generate_tokens``.
    """
    # tokenize.generate_tokens() pulls str lines through a readline callable, so expose the text
    # as an in-memory text stream.
    # FIXME: Remove cast once https://github.com/python/typeshed/issues/7003 gets fixed
    readline = cast(Callable[[], str], io.StringIO(text).readline)
    return tokenize.generate_tokens(readline)
def iter_children_func(node):
    # type: (AST) -> Callable
    """
    Choose the child-iteration helper that matches the kind of tree ``node`` belongs to.

    Nodes exposing ``get_children()`` are treated as astroid nodes; anything else is treated as
    a stdlib ``ast`` node. The returned callable yields a node's direct children. The ``ast``
    variant skips singleton nodes.
    """
    if hasattr(node, 'get_children'):
        return iter_children_astroid
    return iter_children_ast
def iter_children_astroid(node, include_joined_str=False):
    # type: (NodeNG, bool) -> Union[Iterator, List]
    """
    Return the direct children of astroid ``node`` via its ``get_children()`` method.

    For a JoinedStr (f-string) node an empty list is returned instead, unless
    ``include_joined_str`` is true.
    """
    if include_joined_str or not is_joined_str(node):
        return node.get_children()
    return []
# ast node classes treated as singletons: expression contexts and the boolean, binary, unary and
# comparison operators (including those abstract base classes themselves). They don't reflect
# particular positions in the code, so child iteration over ast trees skips them.
SINGLETONS = {
    cls
    for cls in vars(ast).values()
    if isinstance(cls, type)
    and issubclass(cls, (ast.expr_context, ast.boolop, ast.operator, ast.unaryop, ast.cmpop))
}
def iter_children_ast(node, include_joined_str=False):
    # type: (AST, bool) -> Iterator[Union[AST, expr]]
    """
    Yield the direct children of the stdlib ``ast`` node ``node``.

    - A JoinedStr (f-string) node yields nothing unless ``include_joined_str`` is true.
    - A Dict node yields its keys and values interleaved in source order
      (key1, value1, key2, value2, ...). Keys that are None are skipped.
    - Children whose class is in SINGLETONS (contexts and operators) are skipped.
    """
    if not include_joined_str and is_joined_str(node):
        return

    if isinstance(node, ast.Dict):
        # ast stores all keys, then all values; walk them pairwise to restore source order.
        for key, value in zip(node.keys, node.values):
            if key is not None:
                yield key
            yield value
        return

    for child in ast.iter_child_nodes(node):
        # Singletons break the assumption that the tree consists of distinct positioned nodes.
        # A set lookup on the exact class is cheaper than calling isinstance() each time.
        if child.__class__ in SINGLETONS:
            continue
        yield child
def _ast_class_names(base):
    """Return the names of all classes in the ``ast`` module that subclass ``base`` (inclusive)."""
    return {name for name, cls in vars(ast).items() if isinstance(cls, type) and issubclass(cls, base)}


# Names of statement / expression node classes. The expression set additionally lists astroid
# node class names (AssignName, Const, ...), so nodes from either library can be classified by name.
stmt_class_names = _ast_class_names(ast.stmt)
expr_class_names = _ast_class_names(ast.expr) | {'AssignName', 'DelName', 'Const', 'AssignAttr', 'DelAttr'}
- # These feel hacky compared to isinstance() but allow us to work with both ast and astroid nodes
- # in the same way, and without even importing astroid.
def is_expr(node):
    # type: (AstNode) -> bool
    """Tell whether ``node`` is an expression node (stdlib ast or astroid), judged by class name."""
    class_name = node.__class__.__name__
    return class_name in expr_class_names
def is_stmt(node):
    # type: (AstNode) -> bool
    """Tell whether ``node`` is a statement node, judged by class name."""
    class_name = node.__class__.__name__
    return class_name in stmt_class_names
def is_module(node):
    # type: (AstNode) -> bool
    """Tell whether ``node``'s class is named ``Module``."""
    return 'Module' == node.__class__.__name__
def is_joined_str(node):
    # type: (AstNode) -> bool
    """
    Tell whether ``node`` is a JoinedStr, the node type that represents f-strings.
    """
    # Nodes below a JoinedStr currently carry wrong line/col info, and trying to process them
    # only leads to errors; callers use this check to skip them.
    return 'JoinedStr' == node.__class__.__name__
def is_expr_stmt(node):
    # type: (AstNode) -> bool
    """Tell whether ``node`` is an ``Expr`` node: a statement consisting of a bare expression."""
    return 'Expr' == node.__class__.__name__
# Node classes that represent literal constants: the stdlib ast.Constant, plus astroid's Const
# when astroid can be imported.
CONSTANT_CLASSES: Tuple[Type, ...] = (ast.Constant,)
try:
    from astroid.nodes import Const
except ImportError:  # pragma: no cover
    pass  # astroid is not available; only stdlib constants are recognised
else:
    CONSTANT_CLASSES += (Const,)
def is_constant(node):
    # type: (AstNode) -> bool
    """Tell whether ``node`` is an instance of one of the CONSTANT_CLASSES."""
    constant_types = CONSTANT_CLASSES
    return isinstance(node, constant_types)
def is_ellipsis(node):
    # type: (AstNode) -> bool
    """Tell whether ``node`` is a constant node whose value is the ``Ellipsis`` singleton."""
    if not is_constant(node):
        return False
    return node.value is Ellipsis  # type: ignore
def is_starred(node):
    # type: (AstNode) -> bool
    """Tell whether ``node`` is a ``Starred`` expression node (e.g. ``*args``)."""
    return 'Starred' == node.__class__.__name__
def is_slice(node):
    # type: (AstNode) -> bool
    """
    Tell whether ``node`` represents a slice, e.g. ``1:2`` in ``x[1:2]``.

    A Tuple counts as a slice when any of its elements is one, recursively.
    """
    kind = node.__class__.__name__
    # Before Python 3.9 a tuple containing a slice was an ExtSlice; that node was removed in
    # https://bugs.python.org/issue34822 in favour of a plain Tuple.
    if kind in ('Slice', 'ExtSlice'):
        return True
    if kind != 'Tuple':
        return False
    elements = cast(ast.Tuple, node).elts
    return any(is_slice(element) for element in elements)
def is_empty_astroid_slice(node):
    # type: (AstNode) -> bool
    """
    Tell whether ``node`` is a non-stdlib (astroid) ``Slice`` with no lower, upper or step part,
    i.e. a bare ``:``.
    """
    if node.__class__.__name__ != "Slice":
        return False
    if isinstance(node, ast.AST):
        # A stdlib ast.Slice never counts, even when all of its parts are missing.
        return False
    return node.lower is node.upper is node.step is None
# Sentinel used by visit_tree() to mark stack entries whose previsit step has not run yet.
_PREVISIT = object()


def visit_tree(node, previsit, postvisit):
    # type: (Module, Callable[[AstNode, Optional[Token]], Tuple[Optional[Token], Optional[Token]]], Optional[Callable[[AstNode, Optional[Token], Optional[Token]], Any]]) -> Any
    """
    Scan the tree under ``node`` depth-first using an explicit stack instead of recursion, so deep
    trees cannot hit Python's 'maximum recursion depth exceeded' error.

    It calls ``previsit()`` and ``postvisit()`` as follows:

    * ``previsit(node, par_value)`` - must return a pair ``(child_par_value, value)``.
      ``par_value`` is the first element returned by ``previsit()`` of the parent (None for the
      root). ``child_par_value`` is passed as ``par_value`` to each child's callbacks.

    * ``postvisit(node, par_value, value)`` - called after all of the node's children have been
      fully visited. ``par_value`` is the same one given to this node's ``previsit()``, and
      ``value`` is the second element that ``previsit()`` returned for this node.

    The root node is finished last, so ``visit_tree()`` returns the root's ``postvisit()``
    result; the results for all other nodes are discarded. ``postvisit`` may be None, in which
    case a no-op is used and None is returned.
    """
    if not postvisit:
        postvisit = lambda node, pvalue, value: None

    iter_children = iter_children_func(node)
    done = set()
    ret = None
    stack = [(node, None, _PREVISIT)]  # type: List[Tuple[AstNode, Optional[Token], Union[Optional[Token], object]]]
    while stack:
        current, par_value, value = stack.pop()
        if value is _PREVISIT:
            assert current not in done  # protect against infinite loop in case of a bad tree.
            done.add(current)

            pvalue, post_value = previsit(current, par_value)
            # Re-push this node for its postvisit step; it sits below its children, so it is
            # popped only after every child subtree has been completed.
            stack.append((current, par_value, post_value))

            # Insert all children in reverse order (so that first child ends up on top of the stack).
            ins = len(stack)
            for n in iter_children(current):
                stack.insert(ins, (n, pvalue, _PREVISIT))
        else:
            ret = postvisit(current, par_value, cast(Optional[Token], value))
    return ret
def walk(node, include_joined_str=False):
    # type: (AST, bool) -> Iterator[Union[Module, AstNode]]
    """
    Yield ``node`` and every node beneath it, depth-first, each parent before its children.

    Like ``ast.walk()``, but in a fixed pre-order, and usable on both ``ast`` and ``astroid``
    trees. As with ``iter_children()``, singleton nodes generated by ``ast`` are skipped.

    ``JoinedStr`` (f-string) nodes are yielded, but by default their contents are not, because
    they previously couldn't be handled. Pass ``include_joined_str=True`` to descend into them.
    """
    children_of = iter_children_func(node)
    seen = set()
    pending = [node]
    while pending:
        current = pending.pop()
        # Reaching a node twice means a malformed tree; fail rather than loop forever.
        assert current not in seen
        seen.add(current)
        yield current

        # Insert every child at the same index so the first child lands on top of the stack.
        # This is faster than building a list and reversing it.
        slot = len(pending)
        for child in children_of(current, include_joined_str):
            pending.insert(slot, child)
def replace(text, replacements):
    # type: (str, List[Tuple[int, int, str]]) -> str
    """
    Return ``text`` with several slices swapped for new values in a single pass.

    Handy for code edits over ranges such as those from ``ASTTokens.get_text_range(node)``.
    ``replacements`` is an iterable of ``(start, end, new_text)`` tuples; they are applied in
    sorted order.

    For example, ``replace("this is a test", [(0, 4, "X"), (8, 9, "THE")])`` produces
    ``"X is THE test"``.
    """
    pieces = []
    cursor = 0
    for start, end, new_text in sorted(replacements):
        # Keep the untouched text up to this range, then the replacement for the range itself.
        pieces += [text[cursor:start], new_text]
        cursor = end
    pieces.append(text[cursor:])
    return ''.join(pieces)
class NodeMethods:
    """
    Resolve ``visit_<node type>`` methods on a visitor object, remembering the answer per class.
    """

    def __init__(self):
        # type: () -> None
        # Maps a node class to the visit method that was resolved for it.
        self._cache = {}  # type: Dict[Union[ABCMeta, type], Callable[[AstNode, Token, Token], Tuple[Token, Token]]]

    def get(self, obj, cls):
        # type: (Any, Union[ABCMeta, type]) -> Callable
        """
        Return ``obj.visit_<lowercased class name of cls>``, or ``obj.visit_default`` when the
        visitor has no such method.

        NOTE(review): the cache is keyed on ``cls`` only, so the method resolved on the first
        ``obj`` is returned for later calls even if a different ``obj`` is passed. Presumably an
        instance is only ever used with one visitor; confirm with callers.
        """
        cached = self._cache.get(cls)
        if cached:
            return cached
        method = getattr(obj, "visit_" + cls.__name__.lower(), obj.visit_default)
        self._cache[cls] = method
        return method
def patched_generate_tokens(original_tokens):
    # type: (Iterable[TokenInfo]) -> Iterator[TokenInfo]
    """
    Post-process ``tokenize.generate_tokens`` output so more non-ASCII identifiers survive.

    Workaround for https://github.com/python/cpython/issues/68382.

    Runs of NAME, NUMBER and/or ERRORTOKEN tokens that directly touch one another (no whitespace
    in between) are collected and passed to ``combine_tokens``, which may merge them into a
    single NAME token. Only use this on text known to be valid syntax, because error tokens are
    assumed not to be real errors.
    """
    pending = []  # type: List[tokenize.TokenInfo]
    for tok in original_tokens:
        joinable = tok.type in (tokenize.NAME, tokenize.ERRORTOKEN, tokenize.NUMBER)
        if joinable and (not pending or pending[-1].end == tok.start):
            pending.append(tok)
            continue
        # The run is broken: emit whatever was gathered, then this token as-is.
        yield from combine_tokens(pending)
        pending = []
        yield tok
    yield from combine_tokens(pending)
def combine_tokens(group):
    # type: (List[tokenize.TokenInfo]) -> List[tokenize.TokenInfo]
    """
    Merge a run of adjacent tokens into a single NAME token when that is safe.

    The run is merged only if it contains at least one ERRORTOKEN and all of its tokens come from
    the same source line. Otherwise ``group`` itself is returned unchanged.
    """
    if not any(t.type == tokenize.ERRORTOKEN for t in group):
        return group
    if len({t.line for t in group}) != 1:
        return group
    first, last = group[0], group[-1]
    merged = tokenize.TokenInfo(
        type=tokenize.NAME,
        string="".join(t.string for t in group),
        start=first.start,
        end=last.end,
        line=first.line,
    )
    return [merged]
def last_stmt(node):
    # type: (AstNode) -> AstNode
    """
    Descend to the innermost last statement under ``node``.

    At each level, follow the last direct child that is a statement or whose class name is one
    of the handler/case names listed below. Stop at the first node with no such child and return
    it. A node with no such children is therefore returned unchanged.
    """
    block_like_names = (
        "excepthandler",
        "ExceptHandler",
        "match_case",
        "MatchCase",
        "TryExcept",
        "TryFinally",
    )
    current = node
    while True:
        candidates = [
            child for child in iter_children_func(current)(current)
            if is_stmt(child) or type(child).__name__ in block_like_names
        ]
        if not candidates:
            return current
        current = candidates[-1]
@lru_cache(maxsize=None)
def fstring_positions_work():
    # type: () -> bool
    """
    Probe whether this interpreter reports correct positions for names inside f-strings.

    Positions of nodes inside f-string FormattedValues had bugs that were fixed in Python 3.9.7
    (https://github.com/python/cpython/pull/27729). Rather than relying on the version number,
    this parses a sample that exercises the problem cases:

    - values with a format spec or conversion
    - repeated (i.e. identical-looking) expressions
    - f-strings implicitly concatenated over multiple lines
    - multiline, triple-quoted f-strings

    It returns True only if every Name gets a distinct position and its source segment is exactly
    its identifier. The result is computed once and cached.
    """
    source = """(
f"a {b}{b} c {d!r} e {f:g} h {i:{j}} k {l:{m:n}}"
f"a {b}{b} c {d!r} e {f:g} h {i:{j}} k {l:{m:n}}"
f"{x + y + z} {x} {y} {z} {z} {z!a} {z:z}"
f'''
{s} {t}
{u} {v}
'''
)"""
    names = [n for n in ast.walk(ast.parse(source)) if isinstance(n, ast.Name)]
    distinct_positions = {(n.lineno, n.col_offset) for n in names}
    unique = len(distinct_positions) == len(names)
    segments_ok = all(ast.get_source_segment(source, n) == n.id for n in names)
    return unique and segments_ok
def annotate_fstring_nodes(tree):
    # type: (ast.AST) -> None
    """
    Set ``_broken_positions = True`` on nodes inside f-strings whose lineno/col_offset cannot be
    trusted.

    Does nothing on Python 3.12+, where PEP 701 (https://peps.python.org/pep-0701/) gave f-string
    internals sensible positions.
    """
    if sys.version_info >= (3, 12):
        return

    for node in walk(tree, include_joined_str=True):
        if not isinstance(node, ast.JoinedStr):
            continue
        for part in node.values:
            # The standard positions of these FormattedValue/Constant parts span the whole
            # f-string, so they are always unreliable.
            setattr(part, '_broken_positions', True)  # setattr keeps mypy quiet
            if not isinstance(part, ast.FormattedValue):
                continue
            if not fstring_positions_work():
                for inner in walk(part.value):
                    setattr(inner, '_broken_positions', True)
            if part.format_spec:  # the format spec is itself a JoinedStr
                # Its standard positions likewise span the full f-string.
                setattr(part.format_spec, '_broken_positions', True)
|