| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108 |
- """
- MIT License
- Copyright (c) 2021 Alex Hall
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
- """
- import __future__
- import ast
- import dis
- import inspect
- import io
- import linecache
- import re
- import sys
- import types
- from collections import defaultdict
- from copy import deepcopy
- from functools import lru_cache
- from itertools import islice
- from itertools import zip_longest
- from operator import attrgetter
- from pathlib import Path
- from threading import RLock
- from tokenize import detect_encoding
- from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Sized, Tuple, Type, TypeVar, Union, cast
- from ._utils import mangled_name,assert_, EnhancedAST,EnhancedInstruction,Instruction,get_instructions
- if TYPE_CHECKING: # pragma: no cover
- from asttokens import ASTTokens, ASTText
- from asttokens.asttokens import ASTTextBase
- function_node_types = (ast.FunctionDef, ast.AsyncFunctionDef) # type: Tuple[Type, ...]
- cache = lru_cache(maxsize=None)
- TESTING = 0
- class NotOneValueFound(Exception):
- def __init__(self,msg,values=[]):
- # type: (str, Sequence) -> None
- self.values=values
- super(NotOneValueFound,self).__init__(msg)
- T = TypeVar('T')
- def only(it):
- # type: (Iterable[T]) -> T
- if isinstance(it, Sized):
- if len(it) != 1:
- raise NotOneValueFound('Expected one value, found %s' % len(it))
- # noinspection PyTypeChecker
- return list(it)[0]
- lst = tuple(islice(it, 2))
- if len(lst) == 0:
- raise NotOneValueFound('Expected one value, found 0')
- if len(lst) > 1:
- raise NotOneValueFound('Expected one value, found several',lst)
- return lst[0]
- class Source(object):
- """
- The source code of a single file and associated metadata.
- The main method of interest is the classmethod `executing(frame)`.
- If you want an instance of this class, don't construct it.
- Ideally use the classmethod `for_frame(frame)`.
- If you don't have a frame, use `for_filename(filename [, module_globals])`.
- These methods cache instances by filename, so at most one instance exists per filename.
- Attributes:
- - filename
- - text
- - lines
- - tree: AST parsed from text, or None if text is not valid Python
- All nodes in the tree have an extra `parent` attribute
- Other methods of interest:
- - statements_at_line
- - asttokens
- - code_qualname
- """
- def __init__(self, filename, lines):
- # type: (str, Sequence[str]) -> None
- """
- Don't call this constructor, see the class docstring.
- """
- self.filename = filename
- self.text = ''.join(lines)
- self.lines = [line.rstrip('\r\n') for line in lines]
- self._nodes_by_line = defaultdict(list)
- self.tree = None
- self._qualnames = {}
- self._asttokens = None # type: Optional[ASTTokens]
- self._asttext = None # type: Optional[ASTText]
- try:
- self.tree = ast.parse(self.text, filename=filename)
- except (SyntaxError, ValueError):
- pass
- else:
- for node in ast.walk(self.tree):
- for child in ast.iter_child_nodes(node):
- cast(EnhancedAST, child).parent = cast(EnhancedAST, node)
- for lineno in node_linenos(node):
- self._nodes_by_line[lineno].append(node)
- visitor = QualnameVisitor()
- visitor.visit(self.tree)
- self._qualnames = visitor.qualnames
- @classmethod
- def for_frame(cls, frame, use_cache=True):
- # type: (types.FrameType, bool) -> "Source"
- """
- Returns the `Source` object corresponding to the file the frame is executing in.
- """
- return cls.for_filename(frame.f_code.co_filename, frame.f_globals or {}, use_cache)
- @classmethod
- def for_filename(
- cls,
- filename,
- module_globals=None,
- use_cache=True, # noqa no longer used
- ):
- # type: (Union[str, Path], Optional[Dict[str, Any]], bool) -> "Source"
- if isinstance(filename, Path):
- filename = str(filename)
- def get_lines():
- # type: () -> List[str]
- return linecache.getlines(cast(str, filename), module_globals)
- # Save the current linecache entry, then ensure the cache is up to date.
- entry = linecache.cache.get(filename) # type: ignore[attr-defined]
- linecache.checkcache(filename)
- lines = get_lines()
- if entry is not None and not lines:
- # There was an entry, checkcache removed it, and nothing replaced it.
- # This means the file wasn't simply changed (because the `lines` wouldn't be empty)
- # but rather the file was found not to exist, probably because `filename` was fake.
- # Restore the original entry so that we still have something.
- linecache.cache[filename] = entry # type: ignore[attr-defined]
- lines = get_lines()
- return cls._for_filename_and_lines(filename, tuple(lines))
- @classmethod
- def _for_filename_and_lines(cls, filename, lines):
- # type: (str, Sequence[str]) -> "Source"
- source_cache = cls._class_local('__source_cache_with_lines', {}) # type: Dict[Tuple[str, Sequence[str]], Source]
- try:
- return source_cache[(filename, lines)]
- except KeyError:
- pass
- result = source_cache[(filename, lines)] = cls(filename, lines)
- return result
- @classmethod
- def lazycache(cls, frame):
- # type: (types.FrameType) -> None
- linecache.lazycache(frame.f_code.co_filename, frame.f_globals)
- @classmethod
- def executing(cls, frame_or_tb):
- # type: (Union[types.TracebackType, types.FrameType]) -> "Executing"
- """
- Returns an `Executing` object representing the operation
- currently executing in the given frame or traceback object.
- """
- if isinstance(frame_or_tb, types.TracebackType):
- # https://docs.python.org/3/reference/datamodel.html#traceback-objects
- # "tb_lineno gives the line number where the exception occurred;
- # tb_lasti indicates the precise instruction.
- # The line number and last instruction in the traceback may differ
- # from the line number of its frame object
- # if the exception occurred in a try statement with no matching except clause
- # or with a finally clause."
- tb = frame_or_tb
- frame = tb.tb_frame
- lineno = tb.tb_lineno
- lasti = tb.tb_lasti
- else:
- frame = frame_or_tb
- lineno = frame.f_lineno
- lasti = frame.f_lasti
- code = frame.f_code
- key = (code, id(code), lasti)
- executing_cache = cls._class_local('__executing_cache', {}) # type: Dict[Tuple[types.CodeType, int, int], Any]
- args = executing_cache.get(key)
- if not args:
- node = stmts = decorator = None
- source = cls.for_frame(frame)
- tree = source.tree
- if tree:
- try:
- stmts = source.statements_at_line(lineno)
- if stmts:
- if is_ipython_cell_code(code):
- decorator, node = find_node_ipython(frame, lasti, stmts, source)
- else:
- node_finder = NodeFinder(frame, stmts, tree, lasti, source)
- node = node_finder.result
- decorator = node_finder.decorator
- if node:
- new_stmts = {statement_containing_node(node)}
- assert_(new_stmts <= stmts)
- stmts = new_stmts
- except Exception:
- if TESTING:
- raise
- executing_cache[key] = args = source, node, stmts, decorator
- return Executing(frame, *args)
- @classmethod
- def _class_local(cls, name, default):
- # type: (str, T) -> T
- """
- Returns an attribute directly associated with this class
- (as opposed to subclasses), setting default if necessary
- """
- # classes have a mappingproxy preventing us from using setdefault
- result = cls.__dict__.get(name, default)
- setattr(cls, name, result)
- return result
- @cache
- def statements_at_line(self, lineno):
- # type: (int) -> Set[EnhancedAST]
- """
- Returns the statement nodes overlapping the given line.
- Returns at most one statement unless semicolons are present.
- If the `text` attribute is not valid python, meaning
- `tree` is None, returns an empty set.
- Otherwise, `Source.for_frame(frame).statements_at_line(frame.f_lineno)`
- should return at least one statement.
- """
- return {
- statement_containing_node(node)
- for node in
- self._nodes_by_line[lineno]
- }
- def asttext(self):
- # type: () -> ASTText
- """
- Returns an ASTText object for getting the source of specific AST nodes.
- See http://asttokens.readthedocs.io/en/latest/api-index.html
- """
- from asttokens import ASTText # must be installed separately
- if self._asttext is None:
- self._asttext = ASTText(self.text, tree=self.tree, filename=self.filename)
- return self._asttext
- def asttokens(self):
- # type: () -> ASTTokens
- """
- Returns an ASTTokens object for getting the source of specific AST nodes.
- See http://asttokens.readthedocs.io/en/latest/api-index.html
- """
- import asttokens # must be installed separately
- if self._asttokens is None:
- if hasattr(asttokens, 'ASTText'):
- self._asttokens = self.asttext().asttokens
- else: # pragma: no cover
- self._asttokens = asttokens.ASTTokens(self.text, tree=self.tree, filename=self.filename)
- return self._asttokens
- def _asttext_base(self):
- # type: () -> ASTTextBase
- import asttokens # must be installed separately
- if hasattr(asttokens, 'ASTText'):
- return self.asttext()
- else: # pragma: no cover
- return self.asttokens()
- @staticmethod
- def decode_source(source):
- # type: (Union[str, bytes]) -> str
- if isinstance(source, bytes):
- encoding = Source.detect_encoding(source)
- return source.decode(encoding)
- else:
- return source
- @staticmethod
- def detect_encoding(source):
- # type: (bytes) -> str
- return detect_encoding(io.BytesIO(source).readline)[0]
- def code_qualname(self, code):
- # type: (types.CodeType) -> str
- """
- Imitates the __qualname__ attribute of functions for code objects.
- Given:
- - A function `func`
- - A frame `frame` for an execution of `func`, meaning:
- `frame.f_code is func.__code__`
- `Source.for_frame(frame).code_qualname(frame.f_code)`
- will be equal to `func.__qualname__`*. Works for Python 2 as well,
- where of course no `__qualname__` attribute exists.
- Falls back to `code.co_name` if there is no appropriate qualname.
- Based on https://github.com/wbolster/qualname
- (* unless `func` is a lambda
- nested inside another lambda on the same line, in which case
- the outer lambda's qualname will be returned for the codes
- of both lambdas)
- """
- assert_(code.co_filename == self.filename)
- return self._qualnames.get((code.co_name, code.co_firstlineno), code.co_name)
- class Executing(object):
- """
- Information about the operation a frame is currently executing.
- Generally you will just want `node`, which is the AST node being executed,
- or None if it's unknown.
- If a decorator is currently being called, then:
- - `node` is a function or class definition
- - `decorator` is the expression in `node.decorator_list` being called
- - `statements == {node}`
- """
- def __init__(self, frame, source, node, stmts, decorator):
- # type: (types.FrameType, Source, EnhancedAST, Set[ast.stmt], Optional[EnhancedAST]) -> None
- self.frame = frame
- self.source = source
- self.node = node
- self.statements = stmts
- self.decorator = decorator
- def code_qualname(self):
- # type: () -> str
- return self.source.code_qualname(self.frame.f_code)
- def text(self):
- # type: () -> str
- return self.source._asttext_base().get_text(self.node)
- def text_range(self):
- # type: () -> Tuple[int, int]
- return self.source._asttext_base().get_text_range(self.node)
- class QualnameVisitor(ast.NodeVisitor):
- def __init__(self):
- # type: () -> None
- super(QualnameVisitor, self).__init__()
- self.stack = [] # type: List[str]
- self.qualnames = {} # type: Dict[Tuple[str, int], str]
- def add_qualname(self, node, name=None):
- # type: (ast.AST, Optional[str]) -> None
- name = name or node.name # type: ignore[attr-defined]
- self.stack.append(name)
- if getattr(node, 'decorator_list', ()):
- lineno = node.decorator_list[0].lineno # type: ignore[attr-defined]
- else:
- lineno = node.lineno # type: ignore[attr-defined]
- self.qualnames.setdefault((name, lineno), ".".join(self.stack))
- def visit_FunctionDef(self, node, name=None):
- # type: (ast.AST, Optional[str]) -> None
- assert isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)), node
- self.add_qualname(node, name)
- self.stack.append('<locals>')
- children = [] # type: Sequence[ast.AST]
- if isinstance(node, ast.Lambda):
- children = [node.body]
- else:
- children = node.body
- for child in children:
- self.visit(child)
- self.stack.pop()
- self.stack.pop()
- # Find lambdas in the function definition outside the body,
- # e.g. decorators or default arguments
- # Based on iter_child_nodes
- for field, child in ast.iter_fields(node):
- if field == 'body':
- continue
- if isinstance(child, ast.AST):
- self.visit(child)
- elif isinstance(child, list):
- for grandchild in child:
- if isinstance(grandchild, ast.AST):
- self.visit(grandchild)
- visit_AsyncFunctionDef = visit_FunctionDef
- def visit_Lambda(self, node):
- # type: (ast.AST) -> None
- assert isinstance(node, ast.Lambda)
- self.visit_FunctionDef(node, '<lambda>')
- def visit_ClassDef(self, node):
- # type: (ast.AST) -> None
- assert isinstance(node, ast.ClassDef)
- self.add_qualname(node)
- self.generic_visit(node)
- self.stack.pop()
- future_flags = sum(
- getattr(__future__, fname).compiler_flag for fname in __future__.all_feature_names
- )
- def compile_similar_to(source, matching_code):
- # type: (ast.Module, types.CodeType) -> Any
- return compile(
- source,
- matching_code.co_filename,
- 'exec',
- flags=future_flags & matching_code.co_flags,
- dont_inherit=True,
- )
- sentinel = 'io8urthglkjdghvljusketgIYRFYUVGHFRTBGVHKGF78678957647698'
- def is_rewritten_by_pytest(code):
- # type: (types.CodeType) -> bool
- return any(
- bc.opname != "LOAD_CONST" and isinstance(bc.argval,str) and bc.argval.startswith("@py")
- for bc in get_instructions(code)
- )
- class SentinelNodeFinder(object):
- result = None # type: EnhancedAST
- def __init__(self, frame, stmts, tree, lasti, source):
- # type: (types.FrameType, Set[EnhancedAST], ast.Module, int, Source) -> None
- assert_(stmts)
- self.frame = frame
- self.tree = tree
- self.code = code = frame.f_code
- self.is_pytest = is_rewritten_by_pytest(code)
- if self.is_pytest:
- self.ignore_linenos = frozenset(assert_linenos(tree))
- else:
- self.ignore_linenos = frozenset()
- self.decorator = None
- self.instruction = instruction = self.get_actual_current_instruction(lasti)
- op_name = instruction.opname
- extra_filter = lambda e: True # type: Callable[[Any], bool]
- ctx = type(None) # type: Type
- typ = type(None) # type: Type
- if op_name.startswith('CALL_'):
- typ = ast.Call
- elif op_name.startswith(('BINARY_SUBSCR', 'SLICE+')):
- typ = ast.Subscript
- ctx = ast.Load
- elif op_name.startswith('BINARY_'):
- typ = ast.BinOp
- op_type = dict(
- BINARY_POWER=ast.Pow,
- BINARY_MULTIPLY=ast.Mult,
- BINARY_MATRIX_MULTIPLY=getattr(ast, "MatMult", ()),
- BINARY_FLOOR_DIVIDE=ast.FloorDiv,
- BINARY_TRUE_DIVIDE=ast.Div,
- BINARY_MODULO=ast.Mod,
- BINARY_ADD=ast.Add,
- BINARY_SUBTRACT=ast.Sub,
- BINARY_LSHIFT=ast.LShift,
- BINARY_RSHIFT=ast.RShift,
- BINARY_AND=ast.BitAnd,
- BINARY_XOR=ast.BitXor,
- BINARY_OR=ast.BitOr,
- )[op_name]
- extra_filter = lambda e: isinstance(e.op, op_type)
- elif op_name.startswith('UNARY_'):
- typ = ast.UnaryOp
- op_type = dict(
- UNARY_POSITIVE=ast.UAdd,
- UNARY_NEGATIVE=ast.USub,
- UNARY_NOT=ast.Not,
- UNARY_INVERT=ast.Invert,
- )[op_name]
- extra_filter = lambda e: isinstance(e.op, op_type)
- elif op_name in ('LOAD_ATTR', 'LOAD_METHOD', 'LOOKUP_METHOD'):
- typ = ast.Attribute
- ctx = ast.Load
- extra_filter = lambda e:mangled_name(e) == instruction.argval
- elif op_name in ('LOAD_NAME', 'LOAD_GLOBAL', 'LOAD_FAST', 'LOAD_DEREF', 'LOAD_CLASSDEREF'):
- typ = ast.Name
- ctx = ast.Load
- extra_filter = lambda e:mangled_name(e) == instruction.argval
- elif op_name in ('COMPARE_OP', 'IS_OP', 'CONTAINS_OP'):
- typ = ast.Compare
- extra_filter = lambda e: len(e.ops) == 1
- elif op_name.startswith(('STORE_SLICE', 'STORE_SUBSCR')):
- ctx = ast.Store
- typ = ast.Subscript
- elif op_name.startswith('STORE_ATTR'):
- ctx = ast.Store
- typ = ast.Attribute
- extra_filter = lambda e:mangled_name(e) == instruction.argval
- else:
- raise RuntimeError(op_name)
- with lock:
- exprs = {
- cast(EnhancedAST, node)
- for stmt in stmts
- for node in ast.walk(stmt)
- if isinstance(node, typ)
- if isinstance(getattr(node, "ctx", None), ctx)
- if extra_filter(node)
- if statement_containing_node(node) == stmt
- }
- if ctx == ast.Store:
- # No special bytecode tricks here.
- # We can handle multiple assigned attributes with different names,
- # but only one assigned subscript.
- self.result = only(exprs)
- return
- matching = list(self.matching_nodes(exprs))
- if not matching and typ == ast.Call:
- self.find_decorator(stmts)
- else:
- self.result = only(matching)
- def find_decorator(self, stmts):
- # type: (Union[List[EnhancedAST], Set[EnhancedAST]]) -> None
- stmt = only(stmts)
- assert_(isinstance(stmt, (ast.ClassDef, function_node_types)))
- decorators = stmt.decorator_list # type: ignore[attr-defined]
- assert_(decorators)
- line_instructions = [
- inst
- for inst in self.clean_instructions(self.code)
- if inst.lineno == self.frame.f_lineno
- ]
- last_decorator_instruction_index = [
- i
- for i, inst in enumerate(line_instructions)
- if inst.opname == "CALL_FUNCTION"
- ][-1]
- assert_(
- line_instructions[last_decorator_instruction_index + 1].opname.startswith(
- "STORE_"
- )
- )
- decorator_instructions = line_instructions[
- last_decorator_instruction_index
- - len(decorators)
- + 1 : last_decorator_instruction_index
- + 1
- ]
- assert_({inst.opname for inst in decorator_instructions} == {"CALL_FUNCTION"})
- decorator_index = decorator_instructions.index(self.instruction)
- decorator = decorators[::-1][decorator_index]
- self.decorator = decorator
- self.result = stmt
- def clean_instructions(self, code):
- # type: (types.CodeType) -> List[EnhancedInstruction]
- return [
- inst
- for inst in get_instructions(code)
- if inst.opname not in ("EXTENDED_ARG", "NOP")
- if inst.lineno not in self.ignore_linenos
- ]
- def get_original_clean_instructions(self):
- # type: () -> List[EnhancedInstruction]
- result = self.clean_instructions(self.code)
- # pypy sometimes (when is not clear)
- # inserts JUMP_IF_NOT_DEBUG instructions in bytecode
- # If they're not present in our compiled instructions,
- # ignore them in the original bytecode
- if not any(
- inst.opname == "JUMP_IF_NOT_DEBUG"
- for inst in self.compile_instructions()
- ):
- result = [
- inst for inst in result
- if inst.opname != "JUMP_IF_NOT_DEBUG"
- ]
- return result
- def matching_nodes(self, exprs):
- # type: (Set[EnhancedAST]) -> Iterator[EnhancedAST]
- original_instructions = self.get_original_clean_instructions()
- original_index = only(
- i
- for i, inst in enumerate(original_instructions)
- if inst == self.instruction
- )
- for expr_index, expr in enumerate(exprs):
- setter = get_setter(expr)
- assert setter is not None
- # noinspection PyArgumentList
- replacement = ast.BinOp(
- left=expr,
- op=ast.Pow(),
- right=ast.Str(s=sentinel),
- )
- ast.fix_missing_locations(replacement)
- setter(replacement)
- try:
- instructions = self.compile_instructions()
- finally:
- setter(expr)
- if sys.version_info >= (3, 10):
- try:
- handle_jumps(instructions, original_instructions)
- except Exception:
- # Give other candidates a chance
- if TESTING or expr_index < len(exprs) - 1:
- continue
- raise
- indices = [
- i
- for i, instruction in enumerate(instructions)
- if instruction.argval == sentinel
- ]
- # There can be several indices when the bytecode is duplicated,
- # as happens in a finally block in 3.9+
- # First we remove the opcodes caused by our modifications
- for index_num, sentinel_index in enumerate(indices):
- # Adjustment for removing sentinel instructions below
- # in past iterations
- sentinel_index -= index_num * 2
- assert_(instructions.pop(sentinel_index).opname == 'LOAD_CONST')
- assert_(instructions.pop(sentinel_index).opname == 'BINARY_POWER')
- # Then we see if any of the instruction indices match
- for index_num, sentinel_index in enumerate(indices):
- sentinel_index -= index_num * 2
- new_index = sentinel_index - 1
- if new_index != original_index:
- continue
- original_inst = original_instructions[original_index]
- new_inst = instructions[new_index]
- # In Python 3.9+, changing 'not x in y' to 'not sentinel_transformation(x in y)'
- # changes a CONTAINS_OP(invert=1) to CONTAINS_OP(invert=0),<sentinel stuff>,UNARY_NOT
- if (
- original_inst.opname == new_inst.opname in ('CONTAINS_OP', 'IS_OP')
- and original_inst.arg != new_inst.arg # type: ignore[attr-defined]
- and (
- original_instructions[original_index + 1].opname
- != instructions[new_index + 1].opname == 'UNARY_NOT'
- )):
- # Remove the difference for the upcoming assert
- instructions.pop(new_index + 1)
- # Check that the modified instructions don't have anything unexpected
- # 3.10 is a bit too weird to assert this in all cases but things still work
- if sys.version_info < (3, 10):
- for inst1, inst2 in zip_longest(
- original_instructions, instructions
- ):
- assert_(inst1 and inst2 and opnames_match(inst1, inst2))
- yield expr
- def compile_instructions(self):
- # type: () -> List[EnhancedInstruction]
- module_code = compile_similar_to(self.tree, self.code)
- code = only(self.find_codes(module_code))
- return self.clean_instructions(code)
- def find_codes(self, root_code):
- # type: (types.CodeType) -> list
- checks = [
- attrgetter('co_firstlineno'),
- attrgetter('co_freevars'),
- attrgetter('co_cellvars'),
- lambda c: is_ipython_cell_code_name(c.co_name) or c.co_name,
- ] # type: List[Callable]
- if not self.is_pytest:
- checks += [
- attrgetter('co_names'),
- attrgetter('co_varnames'),
- ]
- def matches(c):
- # type: (types.CodeType) -> bool
- return all(
- f(c) == f(self.code)
- for f in checks
- )
- code_options = []
- if matches(root_code):
- code_options.append(root_code)
- def finder(code):
- # type: (types.CodeType) -> None
- for const in code.co_consts:
- if not inspect.iscode(const):
- continue
- if matches(const):
- code_options.append(const)
- finder(const)
- finder(root_code)
- return code_options
- def get_actual_current_instruction(self, lasti):
- # type: (int) -> EnhancedInstruction
- """
- Get the instruction corresponding to the current
- frame offset, skipping EXTENDED_ARG instructions
- """
- # Don't use get_original_clean_instructions
- # because we need the actual instructions including
- # EXTENDED_ARG
- instructions = list(get_instructions(self.code))
- index = only(
- i
- for i, inst in enumerate(instructions)
- if inst.offset == lasti
- )
- while True:
- instruction = instructions[index]
- if instruction.opname != "EXTENDED_ARG":
- return instruction
- index += 1
- def non_sentinel_instructions(instructions, start):
- # type: (List[EnhancedInstruction], int) -> Iterator[Tuple[int, EnhancedInstruction]]
- """
- Yields (index, instruction) pairs excluding the basic
- instructions introduced by the sentinel transformation
- """
- skip_power = False
- for i, inst in islice(enumerate(instructions), start, None):
- if inst.argval == sentinel:
- assert_(inst.opname == "LOAD_CONST")
- skip_power = True
- continue
- elif skip_power:
- assert_(inst.opname == "BINARY_POWER")
- skip_power = False
- continue
- yield i, inst
- def walk_both_instructions(original_instructions, original_start, instructions, start):
- # type: (List[EnhancedInstruction], int, List[EnhancedInstruction], int) -> Iterator[Tuple[int, EnhancedInstruction, int, EnhancedInstruction]]
- """
- Yields matching indices and instructions from the new and original instructions,
- leaving out changes made by the sentinel transformation.
- """
- original_iter = islice(enumerate(original_instructions), original_start, None)
- new_iter = non_sentinel_instructions(instructions, start)
- inverted_comparison = False
- while True:
- try:
- original_i, original_inst = next(original_iter)
- new_i, new_inst = next(new_iter)
- except StopIteration:
- return
- if (
- inverted_comparison
- and original_inst.opname != new_inst.opname == "UNARY_NOT"
- ):
- new_i, new_inst = next(new_iter)
- inverted_comparison = (
- original_inst.opname == new_inst.opname in ("CONTAINS_OP", "IS_OP")
- and original_inst.arg != new_inst.arg # type: ignore[attr-defined]
- )
- yield original_i, original_inst, new_i, new_inst
- def handle_jumps(instructions, original_instructions):
- # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> None
- """
- Transforms instructions in place until it looks more like original_instructions.
- This is only needed in 3.10+ where optimisations lead to more drastic changes
- after the sentinel transformation.
- Replaces JUMP instructions that aren't also present in original_instructions
- with the sections that they jump to until a raise or return.
- In some other cases duplication found in `original_instructions`
- is replicated in `instructions`.
- """
- while True:
- for original_i, original_inst, new_i, new_inst in walk_both_instructions(
- original_instructions, 0, instructions, 0
- ):
- if opnames_match(original_inst, new_inst):
- continue
- if "JUMP" in new_inst.opname and "JUMP" not in original_inst.opname:
- # Find where the new instruction is jumping to, ignoring
- # instructions which have been copied in previous iterations
- start = only(
- i
- for i, inst in enumerate(instructions)
- if inst.offset == new_inst.argval
- and not getattr(inst, "_copied", False)
- )
- # Replace the jump instruction with the jumped to section of instructions
- # That section may also be deleted if it's not similarly duplicated
- # in original_instructions
- new_instructions = handle_jump(
- original_instructions, original_i, instructions, start
- )
- assert new_instructions is not None
- instructions[new_i : new_i + 1] = new_instructions
- else:
- # Extract a section of original_instructions from original_i to return/raise
- orig_section = []
- for section_inst in original_instructions[original_i:]:
- orig_section.append(section_inst)
- if section_inst.opname in ("RETURN_VALUE", "RAISE_VARARGS"):
- break
- else:
- # No return/raise - this is just a mismatch we can't handle
- raise AssertionError
- instructions[new_i:new_i] = only(find_new_matching(orig_section, instructions))
- # instructions has been modified, the for loop can't sensibly continue
- # Restart it from the beginning, checking for other issues
- break
- else: # No mismatched jumps found, we're done
- return
- def find_new_matching(orig_section, instructions):
- # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> Iterator[List[EnhancedInstruction]]
- """
- Yields sections of `instructions` which match `orig_section`.
- The yielded sections include sentinel instructions, but these
- are ignored when checking for matches.
- """
- for start in range(len(instructions) - len(orig_section)):
- indices, dup_section = zip(
- *islice(
- non_sentinel_instructions(instructions, start),
- len(orig_section),
- )
- )
- if len(dup_section) < len(orig_section):
- return
- if sections_match(orig_section, dup_section):
- yield instructions[start:indices[-1] + 1]
- def handle_jump(original_instructions, original_start, instructions, start):
- # type: (List[EnhancedInstruction], int, List[EnhancedInstruction], int) -> Optional[List[EnhancedInstruction]]
- """
- Returns the section of instructions starting at `start` and ending
- with a RETURN_VALUE or RAISE_VARARGS instruction.
- There should be a matching section in original_instructions starting at original_start.
- If that section doesn't appear elsewhere in original_instructions,
- then also delete the returned section of instructions.
- """
- for original_j, original_inst, new_j, new_inst in walk_both_instructions(
- original_instructions, original_start, instructions, start
- ):
- assert_(opnames_match(original_inst, new_inst))
- if original_inst.opname in ("RETURN_VALUE", "RAISE_VARARGS"):
- inlined = deepcopy(instructions[start : new_j + 1])
- for inl in inlined:
- inl._copied = True
- orig_section = original_instructions[original_start : original_j + 1]
- if not check_duplicates(
- original_start, orig_section, original_instructions
- ):
- instructions[start : new_j + 1] = []
- return inlined
-
- return None
- def check_duplicates(original_i, orig_section, original_instructions):
- # type: (int, List[EnhancedInstruction], List[EnhancedInstruction]) -> bool
- """
- Returns True if a section of original_instructions starting somewhere other
- than original_i and matching orig_section is found, i.e. orig_section is duplicated.
- """
- for dup_start in range(len(original_instructions)):
- if dup_start == original_i:
- continue
- dup_section = original_instructions[dup_start : dup_start + len(orig_section)]
- if len(dup_section) < len(orig_section):
- return False
- if sections_match(orig_section, dup_section):
- return True
-
- return False
- def sections_match(orig_section, dup_section):
- # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> bool
- """
- Returns True if the given lists of instructions have matching linenos and opnames.
- """
- return all(
- (
- orig_inst.lineno == dup_inst.lineno
- # POP_BLOCKs have been found to have differing linenos in innocent cases
- or "POP_BLOCK" == orig_inst.opname == dup_inst.opname
- )
- and opnames_match(orig_inst, dup_inst)
- for orig_inst, dup_inst in zip(orig_section, dup_section)
- )
- def opnames_match(inst1, inst2):
- # type: (Instruction, Instruction) -> bool
- return (
- inst1.opname == inst2.opname
- or "JUMP" in inst1.opname
- and "JUMP" in inst2.opname
- or (inst1.opname == "PRINT_EXPR" and inst2.opname == "POP_TOP")
- or (
- inst1.opname in ("LOAD_METHOD", "LOOKUP_METHOD")
- and inst2.opname == "LOAD_ATTR"
- )
- or (inst1.opname == "CALL_METHOD" and inst2.opname == "CALL_FUNCTION")
- )
- def get_setter(node):
- # type: (EnhancedAST) -> Optional[Callable[[ast.AST], None]]
- parent = node.parent
- for name, field in ast.iter_fields(parent):
- if field is node:
- def setter(new_node):
- # type: (ast.AST) -> None
- return setattr(parent, name, new_node)
- return setter
- elif isinstance(field, list):
- for i, item in enumerate(field):
- if item is node:
- def setter(new_node):
- # type: (ast.AST) -> None
- field[i] = new_node
- return setter
- return None
- lock = RLock()
- @cache
- def statement_containing_node(node):
- # type: (ast.AST) -> EnhancedAST
- while not isinstance(node, ast.stmt):
- node = cast(EnhancedAST, node).parent
- return cast(EnhancedAST, node)
- def assert_linenos(tree):
- # type: (ast.AST) -> Iterator[int]
- for node in ast.walk(tree):
- if (
- hasattr(node, 'parent') and
- isinstance(statement_containing_node(node), ast.Assert)
- ):
- for lineno in node_linenos(node):
- yield lineno
- def _extract_ipython_statement(stmt):
- # type: (EnhancedAST) -> ast.Module
- # IPython separates each statement in a cell to be executed separately
- # So NodeFinder should only compile one statement at a time or it
- # will find a code mismatch.
- while not isinstance(stmt.parent, ast.Module):
- stmt = stmt.parent
- # use `ast.parse` instead of `ast.Module` for better portability
- # python3.8 changes the signature of `ast.Module`
- # Inspired by https://github.com/pallets/werkzeug/pull/1552/files
- tree = ast.parse("")
- tree.body = [cast(ast.stmt, stmt)]
- ast.copy_location(tree, stmt)
- return tree
- def is_ipython_cell_code_name(code_name):
- # type: (str) -> bool
- return bool(re.match(r"(<module>|<cell line: \d+>)$", code_name))
- def is_ipython_cell_filename(filename):
- # type: (str) -> bool
- return bool(re.search(r"<ipython-input-|[/\\]ipykernel_\d+[/\\]", filename))
- def is_ipython_cell_code(code_obj):
- # type: (types.CodeType) -> bool
- return (
- is_ipython_cell_filename(code_obj.co_filename) and
- is_ipython_cell_code_name(code_obj.co_name)
- )
- def find_node_ipython(frame, lasti, stmts, source):
- # type: (types.FrameType, int, Set[EnhancedAST], Source) -> Tuple[Optional[Any], Optional[Any]]
- node = decorator = None
- for stmt in stmts:
- tree = _extract_ipython_statement(stmt)
- try:
- node_finder = NodeFinder(frame, stmts, tree, lasti, source)
- if (node or decorator) and (node_finder.result or node_finder.decorator):
- # Found potential nodes in separate statements,
- # cannot resolve ambiguity, give up here
- return None, None
- node = node_finder.result
- decorator = node_finder.decorator
- except Exception:
- pass
- return decorator, node
- def node_linenos(node):
- # type: (ast.AST) -> Iterator[int]
- if hasattr(node, "lineno"):
- linenos = [] # type: Sequence[int]
- if hasattr(node, "end_lineno") and isinstance(node, ast.expr):
- assert node.end_lineno is not None # type: ignore[attr-defined]
- linenos = range(node.lineno, node.end_lineno + 1) # type: ignore[attr-defined]
- else:
- linenos = [node.lineno] # type: ignore[attr-defined]
- for lineno in linenos:
- yield lineno
- if sys.version_info >= (3, 11):
- from ._position_node_finder import PositionNodeFinder as NodeFinder
- else:
- NodeFinder = SentinelNodeFinder
|