executing.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108
  1. """
  2. MIT License
  3. Copyright (c) 2021 Alex Hall
  4. Permission is hereby granted, free of charge, to any person obtaining a copy
  5. of this software and associated documentation files (the "Software"), to deal
  6. in the Software without restriction, including without limitation the rights
  7. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  8. copies of the Software, and to permit persons to whom the Software is
  9. furnished to do so, subject to the following conditions:
  10. The above copyright notice and this permission notice shall be included in all
  11. copies or substantial portions of the Software.
  12. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  13. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  15. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  18. SOFTWARE.
  19. """
  20. import __future__
  21. import ast
  22. import dis
  23. import inspect
  24. import io
  25. import linecache
  26. import re
  27. import sys
  28. import types
  29. from collections import defaultdict
  30. from copy import deepcopy
  31. from functools import lru_cache
  32. from itertools import islice
  33. from itertools import zip_longest
  34. from operator import attrgetter
  35. from pathlib import Path
  36. from threading import RLock
  37. from tokenize import detect_encoding
  38. from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Sized, Tuple, Type, TypeVar, Union, cast
  39. from ._utils import mangled_name,assert_, EnhancedAST,EnhancedInstruction,Instruction,get_instructions
  40. if TYPE_CHECKING: # pragma: no cover
  41. from asttokens import ASTTokens, ASTText
  42. from asttokens.asttokens import ASTTextBase
# AST node classes that represent a function definition (sync or async).
function_node_types = (ast.FunctionDef, ast.AsyncFunctionDef)  # type: Tuple[Type, ...]

# Memoising decorator with no size limit.
cache = lru_cache(maxsize=None)

# When truthy, errors that are normally swallowed while locating a node
# are re-raised instead (see the `if TESTING:` checks in this module).
TESTING = 0
  46. class NotOneValueFound(Exception):
  47. def __init__(self,msg,values=[]):
  48. # type: (str, Sequence) -> None
  49. self.values=values
  50. super(NotOneValueFound,self).__init__(msg)
  51. T = TypeVar('T')
  52. def only(it):
  53. # type: (Iterable[T]) -> T
  54. if isinstance(it, Sized):
  55. if len(it) != 1:
  56. raise NotOneValueFound('Expected one value, found %s' % len(it))
  57. # noinspection PyTypeChecker
  58. return list(it)[0]
  59. lst = tuple(islice(it, 2))
  60. if len(lst) == 0:
  61. raise NotOneValueFound('Expected one value, found 0')
  62. if len(lst) > 1:
  63. raise NotOneValueFound('Expected one value, found several',lst)
  64. return lst[0]
  65. class Source(object):
  66. """
  67. The source code of a single file and associated metadata.
  68. The main method of interest is the classmethod `executing(frame)`.
  69. If you want an instance of this class, don't construct it.
  70. Ideally use the classmethod `for_frame(frame)`.
  71. If you don't have a frame, use `for_filename(filename [, module_globals])`.
  72. These methods cache instances by filename, so at most one instance exists per filename.
  73. Attributes:
  74. - filename
  75. - text
  76. - lines
  77. - tree: AST parsed from text, or None if text is not valid Python
  78. All nodes in the tree have an extra `parent` attribute
  79. Other methods of interest:
  80. - statements_at_line
  81. - asttokens
  82. - code_qualname
  83. """
  84. def __init__(self, filename, lines):
  85. # type: (str, Sequence[str]) -> None
  86. """
  87. Don't call this constructor, see the class docstring.
  88. """
  89. self.filename = filename
  90. self.text = ''.join(lines)
  91. self.lines = [line.rstrip('\r\n') for line in lines]
  92. self._nodes_by_line = defaultdict(list)
  93. self.tree = None
  94. self._qualnames = {}
  95. self._asttokens = None # type: Optional[ASTTokens]
  96. self._asttext = None # type: Optional[ASTText]
  97. try:
  98. self.tree = ast.parse(self.text, filename=filename)
  99. except (SyntaxError, ValueError):
  100. pass
  101. else:
  102. for node in ast.walk(self.tree):
  103. for child in ast.iter_child_nodes(node):
  104. cast(EnhancedAST, child).parent = cast(EnhancedAST, node)
  105. for lineno in node_linenos(node):
  106. self._nodes_by_line[lineno].append(node)
  107. visitor = QualnameVisitor()
  108. visitor.visit(self.tree)
  109. self._qualnames = visitor.qualnames
  110. @classmethod
  111. def for_frame(cls, frame, use_cache=True):
  112. # type: (types.FrameType, bool) -> "Source"
  113. """
  114. Returns the `Source` object corresponding to the file the frame is executing in.
  115. """
  116. return cls.for_filename(frame.f_code.co_filename, frame.f_globals or {}, use_cache)
  117. @classmethod
  118. def for_filename(
  119. cls,
  120. filename,
  121. module_globals=None,
  122. use_cache=True, # noqa no longer used
  123. ):
  124. # type: (Union[str, Path], Optional[Dict[str, Any]], bool) -> "Source"
  125. if isinstance(filename, Path):
  126. filename = str(filename)
  127. def get_lines():
  128. # type: () -> List[str]
  129. return linecache.getlines(cast(str, filename), module_globals)
  130. # Save the current linecache entry, then ensure the cache is up to date.
  131. entry = linecache.cache.get(filename) # type: ignore[attr-defined]
  132. linecache.checkcache(filename)
  133. lines = get_lines()
  134. if entry is not None and not lines:
  135. # There was an entry, checkcache removed it, and nothing replaced it.
  136. # This means the file wasn't simply changed (because the `lines` wouldn't be empty)
  137. # but rather the file was found not to exist, probably because `filename` was fake.
  138. # Restore the original entry so that we still have something.
  139. linecache.cache[filename] = entry # type: ignore[attr-defined]
  140. lines = get_lines()
  141. return cls._for_filename_and_lines(filename, tuple(lines))
  142. @classmethod
  143. def _for_filename_and_lines(cls, filename, lines):
  144. # type: (str, Sequence[str]) -> "Source"
  145. source_cache = cls._class_local('__source_cache_with_lines', {}) # type: Dict[Tuple[str, Sequence[str]], Source]
  146. try:
  147. return source_cache[(filename, lines)]
  148. except KeyError:
  149. pass
  150. result = source_cache[(filename, lines)] = cls(filename, lines)
  151. return result
  152. @classmethod
  153. def lazycache(cls, frame):
  154. # type: (types.FrameType) -> None
  155. linecache.lazycache(frame.f_code.co_filename, frame.f_globals)
  156. @classmethod
  157. def executing(cls, frame_or_tb):
  158. # type: (Union[types.TracebackType, types.FrameType]) -> "Executing"
  159. """
  160. Returns an `Executing` object representing the operation
  161. currently executing in the given frame or traceback object.
  162. """
  163. if isinstance(frame_or_tb, types.TracebackType):
  164. # https://docs.python.org/3/reference/datamodel.html#traceback-objects
  165. # "tb_lineno gives the line number where the exception occurred;
  166. # tb_lasti indicates the precise instruction.
  167. # The line number and last instruction in the traceback may differ
  168. # from the line number of its frame object
  169. # if the exception occurred in a try statement with no matching except clause
  170. # or with a finally clause."
  171. tb = frame_or_tb
  172. frame = tb.tb_frame
  173. lineno = tb.tb_lineno
  174. lasti = tb.tb_lasti
  175. else:
  176. frame = frame_or_tb
  177. lineno = frame.f_lineno
  178. lasti = frame.f_lasti
  179. code = frame.f_code
  180. key = (code, id(code), lasti)
  181. executing_cache = cls._class_local('__executing_cache', {}) # type: Dict[Tuple[types.CodeType, int, int], Any]
  182. args = executing_cache.get(key)
  183. if not args:
  184. node = stmts = decorator = None
  185. source = cls.for_frame(frame)
  186. tree = source.tree
  187. if tree:
  188. try:
  189. stmts = source.statements_at_line(lineno)
  190. if stmts:
  191. if is_ipython_cell_code(code):
  192. decorator, node = find_node_ipython(frame, lasti, stmts, source)
  193. else:
  194. node_finder = NodeFinder(frame, stmts, tree, lasti, source)
  195. node = node_finder.result
  196. decorator = node_finder.decorator
  197. if node:
  198. new_stmts = {statement_containing_node(node)}
  199. assert_(new_stmts <= stmts)
  200. stmts = new_stmts
  201. except Exception:
  202. if TESTING:
  203. raise
  204. executing_cache[key] = args = source, node, stmts, decorator
  205. return Executing(frame, *args)
  206. @classmethod
  207. def _class_local(cls, name, default):
  208. # type: (str, T) -> T
  209. """
  210. Returns an attribute directly associated with this class
  211. (as opposed to subclasses), setting default if necessary
  212. """
  213. # classes have a mappingproxy preventing us from using setdefault
  214. result = cls.__dict__.get(name, default)
  215. setattr(cls, name, result)
  216. return result
  217. @cache
  218. def statements_at_line(self, lineno):
  219. # type: (int) -> Set[EnhancedAST]
  220. """
  221. Returns the statement nodes overlapping the given line.
  222. Returns at most one statement unless semicolons are present.
  223. If the `text` attribute is not valid python, meaning
  224. `tree` is None, returns an empty set.
  225. Otherwise, `Source.for_frame(frame).statements_at_line(frame.f_lineno)`
  226. should return at least one statement.
  227. """
  228. return {
  229. statement_containing_node(node)
  230. for node in
  231. self._nodes_by_line[lineno]
  232. }
  233. def asttext(self):
  234. # type: () -> ASTText
  235. """
  236. Returns an ASTText object for getting the source of specific AST nodes.
  237. See http://asttokens.readthedocs.io/en/latest/api-index.html
  238. """
  239. from asttokens import ASTText # must be installed separately
  240. if self._asttext is None:
  241. self._asttext = ASTText(self.text, tree=self.tree, filename=self.filename)
  242. return self._asttext
  243. def asttokens(self):
  244. # type: () -> ASTTokens
  245. """
  246. Returns an ASTTokens object for getting the source of specific AST nodes.
  247. See http://asttokens.readthedocs.io/en/latest/api-index.html
  248. """
  249. import asttokens # must be installed separately
  250. if self._asttokens is None:
  251. if hasattr(asttokens, 'ASTText'):
  252. self._asttokens = self.asttext().asttokens
  253. else: # pragma: no cover
  254. self._asttokens = asttokens.ASTTokens(self.text, tree=self.tree, filename=self.filename)
  255. return self._asttokens
  256. def _asttext_base(self):
  257. # type: () -> ASTTextBase
  258. import asttokens # must be installed separately
  259. if hasattr(asttokens, 'ASTText'):
  260. return self.asttext()
  261. else: # pragma: no cover
  262. return self.asttokens()
  263. @staticmethod
  264. def decode_source(source):
  265. # type: (Union[str, bytes]) -> str
  266. if isinstance(source, bytes):
  267. encoding = Source.detect_encoding(source)
  268. return source.decode(encoding)
  269. else:
  270. return source
  271. @staticmethod
  272. def detect_encoding(source):
  273. # type: (bytes) -> str
  274. return detect_encoding(io.BytesIO(source).readline)[0]
  275. def code_qualname(self, code):
  276. # type: (types.CodeType) -> str
  277. """
  278. Imitates the __qualname__ attribute of functions for code objects.
  279. Given:
  280. - A function `func`
  281. - A frame `frame` for an execution of `func`, meaning:
  282. `frame.f_code is func.__code__`
  283. `Source.for_frame(frame).code_qualname(frame.f_code)`
  284. will be equal to `func.__qualname__`*. Works for Python 2 as well,
  285. where of course no `__qualname__` attribute exists.
  286. Falls back to `code.co_name` if there is no appropriate qualname.
  287. Based on https://github.com/wbolster/qualname
  288. (* unless `func` is a lambda
  289. nested inside another lambda on the same line, in which case
  290. the outer lambda's qualname will be returned for the codes
  291. of both lambdas)
  292. """
  293. assert_(code.co_filename == self.filename)
  294. return self._qualnames.get((code.co_name, code.co_firstlineno), code.co_name)
  295. class Executing(object):
  296. """
  297. Information about the operation a frame is currently executing.
  298. Generally you will just want `node`, which is the AST node being executed,
  299. or None if it's unknown.
  300. If a decorator is currently being called, then:
  301. - `node` is a function or class definition
  302. - `decorator` is the expression in `node.decorator_list` being called
  303. - `statements == {node}`
  304. """
  305. def __init__(self, frame, source, node, stmts, decorator):
  306. # type: (types.FrameType, Source, EnhancedAST, Set[ast.stmt], Optional[EnhancedAST]) -> None
  307. self.frame = frame
  308. self.source = source
  309. self.node = node
  310. self.statements = stmts
  311. self.decorator = decorator
  312. def code_qualname(self):
  313. # type: () -> str
  314. return self.source.code_qualname(self.frame.f_code)
  315. def text(self):
  316. # type: () -> str
  317. return self.source._asttext_base().get_text(self.node)
  318. def text_range(self):
  319. # type: () -> Tuple[int, int]
  320. return self.source._asttext_base().get_text_range(self.node)
  321. class QualnameVisitor(ast.NodeVisitor):
  322. def __init__(self):
  323. # type: () -> None
  324. super(QualnameVisitor, self).__init__()
  325. self.stack = [] # type: List[str]
  326. self.qualnames = {} # type: Dict[Tuple[str, int], str]
  327. def add_qualname(self, node, name=None):
  328. # type: (ast.AST, Optional[str]) -> None
  329. name = name or node.name # type: ignore[attr-defined]
  330. self.stack.append(name)
  331. if getattr(node, 'decorator_list', ()):
  332. lineno = node.decorator_list[0].lineno # type: ignore[attr-defined]
  333. else:
  334. lineno = node.lineno # type: ignore[attr-defined]
  335. self.qualnames.setdefault((name, lineno), ".".join(self.stack))
  336. def visit_FunctionDef(self, node, name=None):
  337. # type: (ast.AST, Optional[str]) -> None
  338. assert isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)), node
  339. self.add_qualname(node, name)
  340. self.stack.append('<locals>')
  341. children = [] # type: Sequence[ast.AST]
  342. if isinstance(node, ast.Lambda):
  343. children = [node.body]
  344. else:
  345. children = node.body
  346. for child in children:
  347. self.visit(child)
  348. self.stack.pop()
  349. self.stack.pop()
  350. # Find lambdas in the function definition outside the body,
  351. # e.g. decorators or default arguments
  352. # Based on iter_child_nodes
  353. for field, child in ast.iter_fields(node):
  354. if field == 'body':
  355. continue
  356. if isinstance(child, ast.AST):
  357. self.visit(child)
  358. elif isinstance(child, list):
  359. for grandchild in child:
  360. if isinstance(grandchild, ast.AST):
  361. self.visit(grandchild)
  362. visit_AsyncFunctionDef = visit_FunctionDef
  363. def visit_Lambda(self, node):
  364. # type: (ast.AST) -> None
  365. assert isinstance(node, ast.Lambda)
  366. self.visit_FunctionDef(node, '<lambda>')
  367. def visit_ClassDef(self, node):
  368. # type: (ast.AST) -> None
  369. assert isinstance(node, ast.ClassDef)
  370. self.add_qualname(node)
  371. self.generic_visit(node)
  372. self.stack.pop()
  373. future_flags = sum(
  374. getattr(__future__, fname).compiler_flag for fname in __future__.all_feature_names
  375. )
  376. def compile_similar_to(source, matching_code):
  377. # type: (ast.Module, types.CodeType) -> Any
  378. return compile(
  379. source,
  380. matching_code.co_filename,
  381. 'exec',
  382. flags=future_flags & matching_code.co_flags,
  383. dont_inherit=True,
  384. )
# Marker string that SentinelNodeFinder wraps around a candidate expression
# (as the right operand of a `**`), and then looks for in the recompiled
# instructions via `argval == sentinel`.
sentinel = 'io8urthglkjdghvljusketgIYRFYUVGHFRTBGVHKGF78678957647698'
  386. def is_rewritten_by_pytest(code):
  387. # type: (types.CodeType) -> bool
  388. return any(
  389. bc.opname != "LOAD_CONST" and isinstance(bc.argval,str) and bc.argval.startswith("@py")
  390. for bc in get_instructions(code)
  391. )
class SentinelNodeFinder(object):
    """
    Locates the AST node that a frame's current instruction belongs to.

    The instruction's opname decides which kind of node to look for. Each
    candidate node in the given statements is then temporarily wrapped in
    `<candidate> ** <sentinel>`, the tree is recompiled, and the position
    of the sentinel in the new bytecode is compared with the position of
    the current instruction in the original bytecode.

    After construction:
        - `result` is the node found
        - `decorator` is the decorator expression being called, or None
    """

    result = None  # type: EnhancedAST

    def __init__(self, frame, stmts, tree, lasti, source):
        # type: (types.FrameType, Set[EnhancedAST], ast.Module, int, Source) -> None
        """
        Find the node for the instruction at offset `lasti` of `frame`'s code,
        searching only inside `stmts`, and store it in `self.result`.

        Raises RuntimeError for an opname this finder does not handle, and
        NotOneValueFound (via `only`) when the node is not unique.
        """
        assert_(stmts)
        self.frame = frame
        self.tree = tree
        self.code = code = frame.f_code
        self.is_pytest = is_rewritten_by_pytest(code)

        # For pytest-rewritten code, instructions on the lines reported by
        # `assert_linenos` are left out of every comparison.
        if self.is_pytest:
            self.ignore_linenos = frozenset(assert_linenos(tree))
        else:
            self.ignore_linenos = frozenset()

        self.decorator = None

        self.instruction = instruction = self.get_actual_current_instruction(lasti)
        op_name = instruction.opname
        extra_filter = lambda e: True  # type: Callable[[Any], bool]
        # With NoneType as the default, the ctx check below accepts exactly
        # the nodes that have no `ctx` attribute (getattr falls back to None).
        ctx = type(None)  # type: Type

        # Map the opname to the node type (and ctx / extra condition) to look for.
        typ = type(None)  # type: Type
        if op_name.startswith('CALL_'):
            typ = ast.Call
        elif op_name.startswith(('BINARY_SUBSCR', 'SLICE+')):
            typ = ast.Subscript
            ctx = ast.Load
        elif op_name.startswith('BINARY_'):
            typ = ast.BinOp
            op_type = dict(
                BINARY_POWER=ast.Pow,
                BINARY_MULTIPLY=ast.Mult,
                BINARY_MATRIX_MULTIPLY=getattr(ast, "MatMult", ()),
                BINARY_FLOOR_DIVIDE=ast.FloorDiv,
                BINARY_TRUE_DIVIDE=ast.Div,
                BINARY_MODULO=ast.Mod,
                BINARY_ADD=ast.Add,
                BINARY_SUBTRACT=ast.Sub,
                BINARY_LSHIFT=ast.LShift,
                BINARY_RSHIFT=ast.RShift,
                BINARY_AND=ast.BitAnd,
                BINARY_XOR=ast.BitXor,
                BINARY_OR=ast.BitOr,
            )[op_name]
            extra_filter = lambda e: isinstance(e.op, op_type)
        elif op_name.startswith('UNARY_'):
            typ = ast.UnaryOp
            op_type = dict(
                UNARY_POSITIVE=ast.UAdd,
                UNARY_NEGATIVE=ast.USub,
                UNARY_NOT=ast.Not,
                UNARY_INVERT=ast.Invert,
            )[op_name]
            extra_filter = lambda e: isinstance(e.op, op_type)
        elif op_name in ('LOAD_ATTR', 'LOAD_METHOD', 'LOOKUP_METHOD'):
            typ = ast.Attribute
            ctx = ast.Load
            extra_filter = lambda e:mangled_name(e) == instruction.argval
        elif op_name in ('LOAD_NAME', 'LOAD_GLOBAL', 'LOAD_FAST', 'LOAD_DEREF', 'LOAD_CLASSDEREF'):
            typ = ast.Name
            ctx = ast.Load
            extra_filter = lambda e:mangled_name(e) == instruction.argval
        elif op_name in ('COMPARE_OP', 'IS_OP', 'CONTAINS_OP'):
            typ = ast.Compare
            # Only comparisons with a single operator are considered.
            extra_filter = lambda e: len(e.ops) == 1
        elif op_name.startswith(('STORE_SLICE', 'STORE_SUBSCR')):
            ctx = ast.Store
            typ = ast.Subscript
        elif op_name.startswith('STORE_ATTR'):
            ctx = ast.Store
            typ = ast.Attribute
            extra_filter = lambda e:mangled_name(e) == instruction.argval
        else:
            raise RuntimeError(op_name)

        # NOTE(review): `lock` is defined outside this class; presumably it
        # serialises the temporary tree edits made by matching_nodes — confirm.
        with lock:
            # Candidates: nodes of the right type/ctx that sit directly in one
            # of the given statements (not in a nested statement).
            exprs = {
                cast(EnhancedAST, node)
                for stmt in stmts
                for node in ast.walk(stmt)
                if isinstance(node, typ)
                if isinstance(getattr(node, "ctx", None), ctx)
                if extra_filter(node)
                if statement_containing_node(node) == stmt
            }

            if ctx == ast.Store:
                # No special bytecode tricks here.
                # We can handle multiple assigned attributes with different names,
                # but only one assigned subscript.
                self.result = only(exprs)
                return

            matching = list(self.matching_nodes(exprs))
            # A call instruction that matches no Call node is treated as a
            # decorator being applied.
            if not matching and typ == ast.Call:
                self.find_decorator(stmts)
            else:
                self.result = only(matching)

    def find_decorator(self, stmts):
        # type: (Union[List[EnhancedAST], Set[EnhancedAST]]) -> None
        """
        Work out which decorator of the single class/function definition in
        `stmts` the current instruction calls; set `self.decorator` to that
        expression and `self.result` to the definition itself.
        """
        stmt = only(stmts)
        assert_(isinstance(stmt, (ast.ClassDef, function_node_types)))
        decorators = stmt.decorator_list  # type: ignore[attr-defined]
        assert_(decorators)
        line_instructions = [
            inst
            for inst in self.clean_instructions(self.code)
            if inst.lineno == self.frame.f_lineno
        ]
        # The last CALL_FUNCTION on the line must be directly followed by the
        # STORE_* instruction (checked just below).
        last_decorator_instruction_index = [
            i
            for i, inst in enumerate(line_instructions)
            if inst.opname == "CALL_FUNCTION"
        ][-1]
        assert_(
            line_instructions[last_decorator_instruction_index + 1].opname.startswith(
                "STORE_"
            )
        )
        # The len(decorators) instructions ending at that call, all of which
        # must be CALL_FUNCTION: one per decorator.
        decorator_instructions = line_instructions[
            last_decorator_instruction_index
            - len(decorators)
            + 1 : last_decorator_instruction_index
            + 1
        ]
        assert_({inst.opname for inst in decorator_instructions} == {"CALL_FUNCTION"})
        decorator_index = decorator_instructions.index(self.instruction)
        # The call instructions map to decorator_list in reverse order.
        decorator = decorators[::-1][decorator_index]
        self.decorator = decorator
        self.result = stmt

    def clean_instructions(self, code):
        # type: (types.CodeType) -> List[EnhancedInstruction]
        """
        Return the instructions of `code` without EXTENDED_ARG and NOP
        instructions and without those on lines in `self.ignore_linenos`.
        """
        return [
            inst
            for inst in get_instructions(code)
            if inst.opname not in ("EXTENDED_ARG", "NOP")
            if inst.lineno not in self.ignore_linenos
        ]

    def get_original_clean_instructions(self):
        # type: () -> List[EnhancedInstruction]
        """
        Return the cleaned instructions of the frame's own code, dropping
        JUMP_IF_NOT_DEBUG instructions when a fresh compile has none.
        """
        result = self.clean_instructions(self.code)

        # pypy sometimes (when is not clear)
        # inserts JUMP_IF_NOT_DEBUG instructions in bytecode
        # If they're not present in our compiled instructions,
        # ignore them in the original bytecode
        if not any(
            inst.opname == "JUMP_IF_NOT_DEBUG"
            for inst in self.compile_instructions()
        ):
            result = [
                inst for inst in result
                if inst.opname != "JUMP_IF_NOT_DEBUG"
            ]

        return result

    def matching_nodes(self, exprs):
        # type: (Set[EnhancedAST]) -> Iterator[EnhancedAST]
        """
        Yield each candidate in `exprs` whose sentinel-marked recompilation
        puts the sentinel right after the index of the current instruction.
        """
        original_instructions = self.get_original_clean_instructions()
        original_index = only(
            i
            for i, inst in enumerate(original_instructions)
            if inst == self.instruction
        )
        for expr_index, expr in enumerate(exprs):
            setter = get_setter(expr)
            assert setter is not None
            # Temporarily replace `expr` in the tree with `expr ** sentinel`.
            # noinspection PyArgumentList
            replacement = ast.BinOp(
                left=expr,
                op=ast.Pow(),
                right=ast.Str(s=sentinel),
            )
            ast.fix_missing_locations(replacement)
            setter(replacement)
            try:
                instructions = self.compile_instructions()
            finally:
                # Always put the original expression back.
                setter(expr)

            if sys.version_info >= (3, 10):
                try:
                    handle_jumps(instructions, original_instructions)
                except Exception:
                    # Give other candidates a chance
                    if TESTING or expr_index < len(exprs) - 1:
                        continue
                    raise

            indices = [
                i
                for i, instruction in enumerate(instructions)
                if instruction.argval == sentinel
            ]

            # There can be several indices when the bytecode is duplicated,
            # as happens in a finally block in 3.9+
            # First we remove the opcodes caused by our modifications
            for index_num, sentinel_index in enumerate(indices):
                # Adjustment for removing sentinel instructions below
                # in past iterations
                sentinel_index -= index_num * 2

                assert_(instructions.pop(sentinel_index).opname == 'LOAD_CONST')
                assert_(instructions.pop(sentinel_index).opname == 'BINARY_POWER')

            # Then we see if any of the instruction indices match
            for index_num, sentinel_index in enumerate(indices):
                sentinel_index -= index_num * 2
                # The candidate's own instruction sits just before the sentinel.
                new_index = sentinel_index - 1

                if new_index != original_index:
                    continue

                original_inst = original_instructions[original_index]
                new_inst = instructions[new_index]

                # In Python 3.9+, changing 'not x in y' to 'not sentinel_transformation(x in y)'
                # changes a CONTAINS_OP(invert=1) to CONTAINS_OP(invert=0),<sentinel stuff>,UNARY_NOT
                if (
                    original_inst.opname == new_inst.opname in ('CONTAINS_OP', 'IS_OP')
                    and original_inst.arg != new_inst.arg  # type: ignore[attr-defined]
                    and (
                        original_instructions[original_index + 1].opname
                        != instructions[new_index + 1].opname == 'UNARY_NOT'
                    )):
                    # Remove the difference for the upcoming assert
                    instructions.pop(new_index + 1)

                # Check that the modified instructions don't have anything unexpected
                # 3.10 is a bit too weird to assert this in all cases but things still work
                if sys.version_info < (3, 10):
                    for inst1, inst2 in zip_longest(
                        original_instructions, instructions
                    ):
                        assert_(inst1 and inst2 and opnames_match(inst1, inst2))

                yield expr

    def compile_instructions(self):
        # type: () -> List[EnhancedInstruction]
        """
        Compile `self.tree` in its current state and return the cleaned
        instructions of the one code object that matches `self.code`.
        """
        module_code = compile_similar_to(self.tree, self.code)
        code = only(self.find_codes(module_code))
        return self.clean_instructions(code)

    def find_codes(self, root_code):
        # type: (types.CodeType) -> list
        """
        Return every code object in `root_code` (itself included, nested
        constants searched recursively) that agrees with `self.code` on
        all of the checks below.
        """
        checks = [
            attrgetter('co_firstlineno'),
            attrgetter('co_freevars'),
            attrgetter('co_cellvars'),
            lambda c: is_ipython_cell_code_name(c.co_name) or c.co_name,
        ]  # type: List[Callable]
        # co_names and co_varnames are not compared for pytest-rewritten code.
        if not self.is_pytest:
            checks += [
                attrgetter('co_names'),
                attrgetter('co_varnames'),
            ]

        def matches(c):
            # type: (types.CodeType) -> bool
            return all(
                f(c) == f(self.code)
                for f in checks
            )

        code_options = []
        if matches(root_code):
            code_options.append(root_code)

        def finder(code):
            # type: (types.CodeType) -> None
            for const in code.co_consts:
                if not inspect.iscode(const):
                    continue

                if matches(const):
                    code_options.append(const)
                finder(const)

        finder(root_code)
        return code_options

    def get_actual_current_instruction(self, lasti):
        # type: (int) -> EnhancedInstruction
        """
        Get the instruction corresponding to the current
        frame offset, skipping EXTENDED_ARG instructions
        """
        # Don't use get_original_clean_instructions
        # because we need the actual instructions including
        # EXTENDED_ARG
        instructions = list(get_instructions(self.code))
        index = only(
            i
            for i, inst in enumerate(instructions)
            if inst.offset == lasti
        )

        # Step forward past any EXTENDED_ARG prefixes.
        while True:
            instruction = instructions[index]
            if instruction.opname != "EXTENDED_ARG":
                return instruction
            index += 1
  669. def non_sentinel_instructions(instructions, start):
  670. # type: (List[EnhancedInstruction], int) -> Iterator[Tuple[int, EnhancedInstruction]]
  671. """
  672. Yields (index, instruction) pairs excluding the basic
  673. instructions introduced by the sentinel transformation
  674. """
  675. skip_power = False
  676. for i, inst in islice(enumerate(instructions), start, None):
  677. if inst.argval == sentinel:
  678. assert_(inst.opname == "LOAD_CONST")
  679. skip_power = True
  680. continue
  681. elif skip_power:
  682. assert_(inst.opname == "BINARY_POWER")
  683. skip_power = False
  684. continue
  685. yield i, inst
  686. def walk_both_instructions(original_instructions, original_start, instructions, start):
  687. # type: (List[EnhancedInstruction], int, List[EnhancedInstruction], int) -> Iterator[Tuple[int, EnhancedInstruction, int, EnhancedInstruction]]
  688. """
  689. Yields matching indices and instructions from the new and original instructions,
  690. leaving out changes made by the sentinel transformation.
  691. """
  692. original_iter = islice(enumerate(original_instructions), original_start, None)
  693. new_iter = non_sentinel_instructions(instructions, start)
  694. inverted_comparison = False
  695. while True:
  696. try:
  697. original_i, original_inst = next(original_iter)
  698. new_i, new_inst = next(new_iter)
  699. except StopIteration:
  700. return
  701. if (
  702. inverted_comparison
  703. and original_inst.opname != new_inst.opname == "UNARY_NOT"
  704. ):
  705. new_i, new_inst = next(new_iter)
  706. inverted_comparison = (
  707. original_inst.opname == new_inst.opname in ("CONTAINS_OP", "IS_OP")
  708. and original_inst.arg != new_inst.arg # type: ignore[attr-defined]
  709. )
  710. yield original_i, original_inst, new_i, new_inst
def handle_jumps(instructions, original_instructions):
    # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> None
    """
    Transforms instructions in place until it looks more like original_instructions.
    This is only needed in 3.10+ where optimisations lead to more drastic changes
    after the sentinel transformation.
    Replaces JUMP instructions that aren't also present in original_instructions
    with the sections that they jump to until a raise or return.
    In some other cases duplication found in `original_instructions`
    is replicated in `instructions`.

    Both lists are walked side by side with `walk_both_instructions`. At the
    first pair whose opnames don't match, `instructions` is patched and the
    walk is restarted from the beginning. The function returns once a
    complete walk finds no mismatch.

    Raises AssertionError if a mismatch that isn't a new JUMP is found and
    `original_instructions` contains no RETURN_VALUE or RAISE_VARARGS from
    that point on.
    """
    while True:
        for original_i, original_inst, new_i, new_inst in walk_both_instructions(
            original_instructions, 0, instructions, 0
        ):
            if opnames_match(original_inst, new_inst):
                continue

            if "JUMP" in new_inst.opname and "JUMP" not in original_inst.opname:
                # Find where the new instruction is jumping to, ignoring
                # instructions which have been copied in previous iterations
                start = only(
                    i
                    for i, inst in enumerate(instructions)
                    if inst.offset == new_inst.argval
                    and not getattr(inst, "_copied", False)
                )
                # Replace the jump instruction with the jumped to section of instructions
                # That section may also be deleted if it's not similarly duplicated
                # in original_instructions
                new_instructions = handle_jump(
                    original_instructions, original_i, instructions, start
                )
                # handle_jump returns None when it never reaches a return/raise
                assert new_instructions is not None
                # The one-element slice is the jump instruction itself
                instructions[new_i : new_i + 1] = new_instructions
            else:
                # Extract a section of original_instructions from original_i to return/raise
                orig_section = []
                for section_inst in original_instructions[original_i:]:
                    orig_section.append(section_inst)
                    if section_inst.opname in ("RETURN_VALUE", "RAISE_VARARGS"):
                        break
                else:
                    # No return/raise - this is just a mismatch we can't handle
                    raise AssertionError

                # Insert (without replacing anything) the single section of
                # `instructions` that matches the extracted original section
                instructions[new_i:new_i] = only(find_new_matching(orig_section, instructions))

            # instructions has been modified, the for loop can't sensibly continue
            # Restart it from the beginning, checking for other issues
            break

        else:  # No mismatched jumps found, we're done
            return
  761. def find_new_matching(orig_section, instructions):
  762. # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> Iterator[List[EnhancedInstruction]]
  763. """
  764. Yields sections of `instructions` which match `orig_section`.
  765. The yielded sections include sentinel instructions, but these
  766. are ignored when checking for matches.
  767. """
  768. for start in range(len(instructions) - len(orig_section)):
  769. indices, dup_section = zip(
  770. *islice(
  771. non_sentinel_instructions(instructions, start),
  772. len(orig_section),
  773. )
  774. )
  775. if len(dup_section) < len(orig_section):
  776. return
  777. if sections_match(orig_section, dup_section):
  778. yield instructions[start:indices[-1] + 1]
  779. def handle_jump(original_instructions, original_start, instructions, start):
  780. # type: (List[EnhancedInstruction], int, List[EnhancedInstruction], int) -> Optional[List[EnhancedInstruction]]
  781. """
  782. Returns the section of instructions starting at `start` and ending
  783. with a RETURN_VALUE or RAISE_VARARGS instruction.
  784. There should be a matching section in original_instructions starting at original_start.
  785. If that section doesn't appear elsewhere in original_instructions,
  786. then also delete the returned section of instructions.
  787. """
  788. for original_j, original_inst, new_j, new_inst in walk_both_instructions(
  789. original_instructions, original_start, instructions, start
  790. ):
  791. assert_(opnames_match(original_inst, new_inst))
  792. if original_inst.opname in ("RETURN_VALUE", "RAISE_VARARGS"):
  793. inlined = deepcopy(instructions[start : new_j + 1])
  794. for inl in inlined:
  795. inl._copied = True
  796. orig_section = original_instructions[original_start : original_j + 1]
  797. if not check_duplicates(
  798. original_start, orig_section, original_instructions
  799. ):
  800. instructions[start : new_j + 1] = []
  801. return inlined
  802. return None
  803. def check_duplicates(original_i, orig_section, original_instructions):
  804. # type: (int, List[EnhancedInstruction], List[EnhancedInstruction]) -> bool
  805. """
  806. Returns True if a section of original_instructions starting somewhere other
  807. than original_i and matching orig_section is found, i.e. orig_section is duplicated.
  808. """
  809. for dup_start in range(len(original_instructions)):
  810. if dup_start == original_i:
  811. continue
  812. dup_section = original_instructions[dup_start : dup_start + len(orig_section)]
  813. if len(dup_section) < len(orig_section):
  814. return False
  815. if sections_match(orig_section, dup_section):
  816. return True
  817. return False
  818. def sections_match(orig_section, dup_section):
  819. # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> bool
  820. """
  821. Returns True if the given lists of instructions have matching linenos and opnames.
  822. """
  823. return all(
  824. (
  825. orig_inst.lineno == dup_inst.lineno
  826. # POP_BLOCKs have been found to have differing linenos in innocent cases
  827. or "POP_BLOCK" == orig_inst.opname == dup_inst.opname
  828. )
  829. and opnames_match(orig_inst, dup_inst)
  830. for orig_inst, dup_inst in zip(orig_section, dup_section)
  831. )
  832. def opnames_match(inst1, inst2):
  833. # type: (Instruction, Instruction) -> bool
  834. return (
  835. inst1.opname == inst2.opname
  836. or "JUMP" in inst1.opname
  837. and "JUMP" in inst2.opname
  838. or (inst1.opname == "PRINT_EXPR" and inst2.opname == "POP_TOP")
  839. or (
  840. inst1.opname in ("LOAD_METHOD", "LOOKUP_METHOD")
  841. and inst2.opname == "LOAD_ATTR"
  842. )
  843. or (inst1.opname == "CALL_METHOD" and inst2.opname == "CALL_FUNCTION")
  844. )
  845. def get_setter(node):
  846. # type: (EnhancedAST) -> Optional[Callable[[ast.AST], None]]
  847. parent = node.parent
  848. for name, field in ast.iter_fields(parent):
  849. if field is node:
  850. def setter(new_node):
  851. # type: (ast.AST) -> None
  852. return setattr(parent, name, new_node)
  853. return setter
  854. elif isinstance(field, list):
  855. for i, item in enumerate(field):
  856. if item is node:
  857. def setter(new_node):
  858. # type: (ast.AST) -> None
  859. field[i] = new_node
  860. return setter
  861. return None
# Lock object created once at module level.
# NOTE(review): nothing in this part of the file acquires it; presumably it
# guards shared state elsewhere in the module - confirm against its users.
lock = RLock()
  863. @cache
  864. def statement_containing_node(node):
  865. # type: (ast.AST) -> EnhancedAST
  866. while not isinstance(node, ast.stmt):
  867. node = cast(EnhancedAST, node).parent
  868. return cast(EnhancedAST, node)
  869. def assert_linenos(tree):
  870. # type: (ast.AST) -> Iterator[int]
  871. for node in ast.walk(tree):
  872. if (
  873. hasattr(node, 'parent') and
  874. isinstance(statement_containing_node(node), ast.Assert)
  875. ):
  876. for lineno in node_linenos(node):
  877. yield lineno
  878. def _extract_ipython_statement(stmt):
  879. # type: (EnhancedAST) -> ast.Module
  880. # IPython separates each statement in a cell to be executed separately
  881. # So NodeFinder should only compile one statement at a time or it
  882. # will find a code mismatch.
  883. while not isinstance(stmt.parent, ast.Module):
  884. stmt = stmt.parent
  885. # use `ast.parse` instead of `ast.Module` for better portability
  886. # python3.8 changes the signature of `ast.Module`
  887. # Inspired by https://github.com/pallets/werkzeug/pull/1552/files
  888. tree = ast.parse("")
  889. tree.body = [cast(ast.stmt, stmt)]
  890. ast.copy_location(tree, stmt)
  891. return tree
  892. def is_ipython_cell_code_name(code_name):
  893. # type: (str) -> bool
  894. return bool(re.match(r"(<module>|<cell line: \d+>)$", code_name))
  895. def is_ipython_cell_filename(filename):
  896. # type: (str) -> bool
  897. return bool(re.search(r"<ipython-input-|[/\\]ipykernel_\d+[/\\]", filename))
  898. def is_ipython_cell_code(code_obj):
  899. # type: (types.CodeType) -> bool
  900. return (
  901. is_ipython_cell_filename(code_obj.co_filename) and
  902. is_ipython_cell_code_name(code_obj.co_name)
  903. )
  904. def find_node_ipython(frame, lasti, stmts, source):
  905. # type: (types.FrameType, int, Set[EnhancedAST], Source) -> Tuple[Optional[Any], Optional[Any]]
  906. node = decorator = None
  907. for stmt in stmts:
  908. tree = _extract_ipython_statement(stmt)
  909. try:
  910. node_finder = NodeFinder(frame, stmts, tree, lasti, source)
  911. if (node or decorator) and (node_finder.result or node_finder.decorator):
  912. # Found potential nodes in separate statements,
  913. # cannot resolve ambiguity, give up here
  914. return None, None
  915. node = node_finder.result
  916. decorator = node_finder.decorator
  917. except Exception:
  918. pass
  919. return decorator, node
  920. def node_linenos(node):
  921. # type: (ast.AST) -> Iterator[int]
  922. if hasattr(node, "lineno"):
  923. linenos = [] # type: Sequence[int]
  924. if hasattr(node, "end_lineno") and isinstance(node, ast.expr):
  925. assert node.end_lineno is not None # type: ignore[attr-defined]
  926. linenos = range(node.lineno, node.end_lineno + 1) # type: ignore[attr-defined]
  927. else:
  928. linenos = [node.lineno] # type: ignore[attr-defined]
  929. for lineno in linenos:
  930. yield lineno
# Choose the NodeFinder implementation for the running interpreter:
# Python 3.11+ uses PositionNodeFinder (imported only on those versions),
# older versions use SentinelNodeFinder.
if sys.version_info >= (3, 11):
    from ._position_node_finder import PositionNodeFinder as NodeFinder
else:
    NodeFinder = SentinelNodeFinder