resume_execution.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750
  1. """
  2. This module provides functionality for resuming Python execution at specific points in code,
  3. primarily used by PyTorch Dynamo for control flow handling and optimization. It implements
  4. bytecode transformation and execution state management to enable:
  5. - Resuming execution at arbitrary points in Python bytecode
  6. - Managing context managers and their state across execution boundaries
  7. - Transforming and generating new code objects with preserved execution state
  8. - Supporting Python 3.11+ exception handling and block management
  9. - Restoring torch function mode stacks and other execution context
  10. The module is critical for PyTorch Dynamo's ability to optimize code while preserving
  11. Python semantics and execution state.
  12. """
  13. import copy
  14. import dataclasses
  15. import sys
  16. import types
  17. from collections.abc import Callable, Iterable
  18. from contextlib import AbstractContextManager
  19. from typing import Any, cast, Optional
  20. from .bytecode_transformation import (
  21. add_push_null,
  22. bytecode_from_template,
  23. create_binary_subscr,
  24. create_call_function,
  25. create_call_function_ex,
  26. create_instruction,
  27. create_jump_absolute,
  28. create_load_const,
  29. Instruction,
  30. overwrite_instruction,
  31. transform_code_object,
  32. unique_id,
  33. )
  34. from .utils import ExactWeakKeyDictionary
  35. # taken from code.h in cpython
  36. CO_OPTIMIZED = 0x0001
  37. CO_NEWLOCALS = 0x0002
  38. CO_VARARGS = 0x0004
  39. CO_VARKEYWORDS = 0x0008
  40. CO_NESTED = 0x0010
  41. CO_GENERATOR = 0x0020
  42. CO_NOFREE = 0x0040
  43. CO_COROUTINE = 0x0080
  44. CO_ITERABLE_COROUTINE = 0x0100
  45. CO_ASYNC_GENERATOR = 0x0200
  46. # trace_rules.py import this constant for consistency
  47. TORCH_DYNAMO_RESUME_IN_PREFIX = "torch_dynamo_resume_in"
  48. IS_TRACING_RESUME_PROLOGUE_VARNAME = "__is_tracing_resume_prologue"
  49. # If is_resume - this codegen is for a resume function
  50. def _initial_push_null(insts: list[Instruction]) -> None:
  51. if sys.version_info >= (3, 11):
  52. insts.append(create_instruction("PUSH_NULL"))
  53. if sys.version_info < (3, 13):
  54. insts.append(create_instruction("SWAP", arg=2))
  55. # Generates bytecode from template and splits the code where LOAD_FAST dummy is present.
  56. def _bytecode_from_template_with_split(
  57. template: Callable[..., Any],
  58. stack_index: int,
  59. varname_map: Optional[dict[str, Any]] = None,
  60. ) -> tuple[list[Instruction], list[Instruction]]:
  61. template_code = bytecode_from_template(template, varname_map=varname_map)
  62. template_code.append(create_instruction("POP_TOP"))
  63. # adjust exception table entry depth
  64. for inst in template_code:
  65. if inst.exn_tab_entry:
  66. inst.exn_tab_entry.depth += stack_index
  67. # search for LOAD_FAST dummy and replace it with 2 NOPs (we can break up the bytecode between them)
  68. dummy_idx, dummy_inst = next(
  69. (
  70. (i, inst)
  71. for i, inst in enumerate(template_code)
  72. if inst.opname in ("LOAD_FAST", "LOAD_FAST_BORROW")
  73. and inst.argval == "dummy"
  74. ),
  75. (None, None),
  76. )
  77. assert dummy_idx is not None and dummy_inst is not None
  78. # replace LOAD_FAST dummy with first NOP marking exception area
  79. overwrite_instruction(dummy_inst, [create_instruction("NOP")])
  80. # POP_TOP follows LOAD_FAST dummy - replace with NOP marking end of exception area
  81. assert template_code[dummy_idx + 1].opname == "POP_TOP"
  82. overwrite_instruction(template_code[dummy_idx + 1], [create_instruction("NOP")])
  83. return template_code[: dummy_idx + 1], template_code[dummy_idx + 1 :]
# Bytecode template, not a function meant to be called directly: it is handed to
# _bytecode_from_template_with_split (see ReenterWith.try_except_torch_function_mode),
# which splits the compiled bytecode at the bare `dummy` expression below.
# Keep the statements exactly as they are - the compiled bytecode is the product.
def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None:
    # NOTE: Make sure this name matches what is generated by symbolic_convert:import_source
    # on torch._dynamo.utils.
    # pyrefly: ignore [unknown-name]
    global __import_torch_dot__dynamo_dot_utils
    try:
        # placeholder marking where the bytecode is split
        dummy
    except:  # noqa: E722, B001
        # on any exception, pass `stack_var_name` to set_torch_function_mode_stack,
        # then re-raise
        __import_torch_dot__dynamo_dot_utils.set_torch_function_mode_stack(  # type: ignore[name-defined]
            stack_var_name
        )
        raise
  96. @dataclasses.dataclass(frozen=True)
  97. class ReenterWith:
  98. stack_index: int
  99. target_values: Optional[tuple[Any, ...]] = None
  100. def try_except_torch_function_mode(
  101. self, code_options: dict[str, Any], cleanup: list[Instruction]
  102. ) -> list[Instruction]:
  103. """
  104. Codegen based off of:
  105. try:
  106. (rest)
  107. except:
  108. (restore previous tf mode stack)
  109. raise
  110. """
  111. from .variables.torch_function import get_prev_stack_var_name
  112. setup_try_except, epilogue = _bytecode_from_template_with_split(
  113. _try_except_tf_mode_template,
  114. self.stack_index,
  115. varname_map={"stack_var_name": get_prev_stack_var_name()},
  116. )
  117. cleanup[:] = epilogue + cleanup
  118. return setup_try_except
  119. # If we do not want to destroy the stack, we can do the same thing as a
  120. # `SETUP_WITH` block, only that we store the context manager in a local_symbol
  121. def try_finally(
  122. self, code_options: dict[str, Any], cleanup: list[Instruction]
  123. ) -> list[Instruction]:
  124. """
  125. Codegen based off of:
  126. load args
  127. enter context
  128. try:
  129. (rest)
  130. finally:
  131. exit context
  132. """
  133. # NOTE: we assume that TOS is a context manager CLASS!
  134. # pyrefly: ignore [implicit-any]
  135. load_args = []
  136. if self.target_values:
  137. load_args = [create_load_const(val) for val in self.target_values]
  138. ctx_name = unique_id(f"___context_manager_{self.stack_index}")
  139. if ctx_name not in code_options["co_varnames"]:
  140. code_options["co_varnames"] += (ctx_name,)
  141. for name in ["__enter__", "__exit__"]:
  142. if name not in code_options["co_names"]:
  143. code_options["co_names"] += (name,)
  144. create_ctx: list[Instruction] = []
  145. _initial_push_null(create_ctx)
  146. create_ctx.extend(
  147. [
  148. *load_args,
  149. *create_call_function(len(load_args), False),
  150. create_instruction("STORE_FAST", argval=ctx_name),
  151. ]
  152. )
  153. def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None:
  154. ctx.__enter__()
  155. try:
  156. dummy
  157. finally:
  158. ctx.__exit__(None, None, None)
  159. setup_try_finally, epilogue = _bytecode_from_template_with_split(
  160. _template, self.stack_index, varname_map={"ctx": ctx_name}
  161. )
  162. cleanup[:] = epilogue + cleanup
  163. return create_ctx + setup_try_finally
  164. def __call__(
  165. self, code_options: dict[str, Any], cleanup: list[Instruction]
  166. ) -> tuple[list[Instruction], Optional[Instruction]]:
  167. """
  168. Codegen based off of:
  169. with ctx(args):
  170. (rest)
  171. """
  172. # NOTE: we assume that TOS is a context manager CLASS!
  173. # pyrefly: ignore [implicit-any]
  174. load_args = []
  175. if self.target_values:
  176. load_args = [create_load_const(val) for val in self.target_values]
  177. create_ctx: list[Instruction] = []
  178. # Do not push NULL in Python 3.14+ since the NULL should be on the symbolic stack.
  179. if sys.version_info < (3, 14):
  180. _initial_push_null(create_ctx)
  181. create_ctx.extend(
  182. [
  183. *load_args,
  184. *create_call_function(len(load_args), False),
  185. ]
  186. )
  187. def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None:
  188. with ctx:
  189. dummy
  190. setup_with, epilogue = _bytecode_from_template_with_split(
  191. _template, self.stack_index
  192. )
  193. cleanup[:] = epilogue + cleanup
  194. load_fast_ctx_inst = next(
  195. (
  196. inst
  197. for inst in setup_with
  198. if inst.opname in ("LOAD_FAST", "LOAD_FAST_BORROW")
  199. and inst.argval == "ctx"
  200. ),
  201. None,
  202. )
  203. assert load_fast_ctx_inst is not None
  204. # ctx already loaded on stack before the template - no need to LOAD_FAST
  205. overwrite_instruction(load_fast_ctx_inst, [create_instruction("NOP")])
  206. # 3.11+ only
  207. push_exc_info_gen = (
  208. inst for inst in epilogue if inst.opname == "PUSH_EXC_INFO"
  209. )
  210. push_exc_info_inst = next(push_exc_info_gen, None)
  211. # expect only 1 PUSH_EXC_INFO in epilogue
  212. assert next(push_exc_info_gen, None) is None
  213. return create_ctx + setup_with, push_exc_info_inst
@dataclasses.dataclass
class ResumeFunctionMetadata:
    """
    Bookkeeping stored for each generated resume function.

    Instances are created in ContinueExecutionCache.generate and later read by
    ContinueExecutionCache.generate_based_on_original_code_object to map
    offsets in a generated resume function back to the code it came from.
    """

    # code object the resume function was generated from
    code: types.CodeType
    # deep copy of that code's instructions, taken before the prefix was added
    instructions: list[Instruction] = dataclasses.field(default_factory=list)
    # Python 3.11+ fields
    # NOTE: Python 3.11 removed blocks, but for our purposes, a "block" consists
    # of instructions of all exception table entries that have the same target.
    # map from PUSH_EXC_INFO's in the prefix to original block target offset
    prefix_block_target_offset_remap: list[int] = dataclasses.field(
        default_factory=list
    )
    # per-offset map from new block target offsets to original block target offsets
    # (outer key is the (orig_init_offset, orig_resume_offset) pair)
    block_target_offset_remap: dict[tuple[int, int], dict[int, int]] = (
        dataclasses.field(default_factory=dict)
    )
  229. def _filter_iter(
  230. l1: Iterable[Any],
  231. l2: Iterable[Any],
  232. cond: Callable[[Any, Any], bool],
  233. ) -> list[Any]:
  234. """
  235. Two-pointer conditional filter.
  236. e.g. _filter_iter(insts, sorted_offsets, lambda i, o: i.offset == o)
  237. returns the instructions with offsets in sorted_offsets
  238. """
  239. it = iter(l2)
  240. res: list[Instruction] = []
  241. try:
  242. cur = next(it)
  243. for val in l1:
  244. if cond(val, cur):
  245. res.append(val)
  246. cur = next(it)
  247. except StopIteration:
  248. pass
  249. return res
  250. def _load_tuple_and_call(tup: tuple[Any, ...]) -> list[Instruction]:
  251. insts: list[Instruction] = []
  252. _initial_push_null(insts)
  253. insts.extend(create_load_const(val) for val in tup)
  254. insts.extend(create_call_function(len(tup), False))
  255. return insts
class ContinueExecutionCache:
    """
    Generates and caches "resume" code objects.

    A resume code object is a copy of a code object with a prefix prepended.
    The prefix reloads the saved stack values, re-enters contexts and then
    jumps to the instruction at which execution should continue.
    """

    # code object -> {key tuple -> generated resume code object}
    cache = ExactWeakKeyDictionary()
    # generated resume code object -> ResumeFunctionMetadata describing its origin
    generated_code_metadata = ExactWeakKeyDictionary()

    @classmethod
    def lookup(
        cls, code: types.CodeType, lineno: int, init_offset: int, *key: Any
    ) -> types.CodeType:
        """
        Return the resume code object for ``code`` and ``key``, generating it
        on a cache miss.

        ``lineno`` and ``init_offset`` are forwarded to ``generate`` but are
        not part of the cache key; only ``key`` (the remaining positional
        arguments of ``generate``) is.
        """
        if code not in cls.cache:
            cls.cache[code] = {}
        key = tuple(key)
        if key not in cls.cache[code]:
            cls.cache[code][key] = cls.generate(code, lineno, init_offset, *key)
        return cls.cache[code][key]

    @classmethod
    def generate(
        cls,
        code: types.CodeType,
        lineno: int,
        init_offset: int,
        resume_offset: int,
        setup_fn_target_offsets: tuple[int, ...],  # only used in Python 3.11+
        nstack: int,
        argnames: tuple[str, ...],
        argnames_null: tuple[str, ...],
        setup_fns: tuple[ReenterWith, ...],
        handle_inactive_ctx: bool,
        stack_ctx_vars: tuple[tuple[int, tuple[Any, ...]], ...],
        argnames_ctx_vars: tuple[tuple[str, tuple[Any, ...]], ...],
        null_idxes: tuple[int, ...],
        # mainly used to ensure distinct code objects per stack trace,
        # which prevents excessive recompilation of inner frames
        nested_code_objs: tuple[types.CodeType, ...],
        # Are we currently graph breaking on an instruction that doesn't push
        # its result to the stack? If so, and we are not the leaf resume, then we need to pop
        # the result of calling the next resume function.
        pop_nested_resume_result: bool,
    ) -> types.CodeType:
        """
        Build a code object that restores state and resumes ``code`` at
        ``resume_offset``.

        If ``code`` is itself a generated resume function (it has an entry in
        ``generated_code_metadata``), the work is delegated to
        ``generate_based_on_original_code_object``.

        Args:
            code: code object to resume in. Must have CO_OPTIMIZED set and must
                not be a generator, coroutine or async generator (asserted).
            lineno: becomes ``co_firstlineno`` and is embedded in the generated
                ``co_name``/``co_qualname``.
            init_offset: only forwarded to
                ``generate_based_on_original_code_object``.
            resume_offset: offset of the instruction the prefix jumps to.
            setup_fn_target_offsets: for each entry of ``setup_fns`` (matched
                by index), the offset of its original block target.
            nstack: number of ``___stack{i}`` arguments to reload.
            argnames: further argument names, appended after the stack args.
            argnames_null: locals that get NULL stored into them by the prefix
                (asserted to be 3.12+ only).
            setup_fns: hooks run while rebuilding the stack, keyed by their
                ``stack_index``.
            handle_inactive_ctx: when true, the context manager classes named
                by ``stack_ctx_vars`` / ``argnames_ctx_vars`` are called with
                their saved arguments.
            stack_ctx_vars: ``(stack index, constructor args)`` pairs.
            argnames_ctx_vars: ``(local name, constructor args)`` pairs.
            null_idxes: positions in the rebuilt stack where a ``PUSH_NULL`` is
                emitted instead of a ``___stack{i}`` load.
            nested_code_objs: when non-empty, the prefix calls
                ``__nested_resume_fns[-1]`` instead of falling through.
            pop_nested_resume_result: whether to pop the result of that call.

        Returns:
            The new code object, also registered in ``generated_code_metadata``.
        """
        assert resume_offset is not None
        assert not (
            code.co_flags
            & (CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR)
        )
        assert code.co_flags & CO_OPTIMIZED
        if code in ContinueExecutionCache.generated_code_metadata:
            return cls.generate_based_on_original_code_object(
                code,
                lineno,
                init_offset,
                resume_offset,
                setup_fn_target_offsets,
                nstack,
                argnames,
                argnames_null,
                setup_fns,
                handle_inactive_ctx,
                stack_ctx_vars,
                argnames_ctx_vars,
                null_idxes,
                nested_code_objs,
                pop_nested_resume_result,
            )

        is_py311_plus = sys.version_info >= (3, 11)
        meta = ResumeFunctionMetadata(code)

        # Callback for transform_code_object: rewrites `instructions` and
        # `code_options` in place to produce the resume function.
        def update(
            instructions: list[Instruction], code_options: dict[str, Any]
        ) -> None:
            # snapshot of the instructions before any prefix is added
            meta.instructions = copy.deepcopy(instructions)

            # argument order: the two nested-resume lists, the stack values,
            # then any argnames not already present
            args = ["__nested_resume_fns", "__nested_frame_values"]
            args += [f"___stack{i}" for i in range(nstack)]
            args.extend(v for v in argnames if v not in args)
            # cellvars and freevars are merged into a single sorted freevars tuple
            freevars = tuple(code_options["co_cellvars"] or []) + tuple(
                code_options["co_freevars"] or []
            )
            freevars = tuple(sorted(freevars))
            code_options["co_name"] = (
                f"{TORCH_DYNAMO_RESUME_IN_PREFIX}_{code_options['co_name']}_at_{lineno}"
            )
            if is_py311_plus:
                # rename only the last component of the qualified name
                qualified_path = code_options["co_qualname"].rsplit(".", maxsplit=1)
                if len(qualified_path) == 1:
                    code_options["co_qualname"] = code_options["co_name"]
                else:
                    assert len(qualified_path) == 2
                    module_name, co_name = qualified_path
                    code_options["co_qualname"] = (
                        f"{module_name}.{TORCH_DYNAMO_RESUME_IN_PREFIX}_{co_name}_at_{lineno}"
                    )
            code_options["co_firstlineno"] = lineno
            code_options["co_cellvars"] = ()
            code_options["co_freevars"] = freevars
            code_options["co_argcount"] = len(args)
            code_options["co_posonlyargcount"] = 0
            code_options["co_kwonlyargcount"] = 0
            code_options["co_varnames"] = tuple(
                args
                + [v for v in argnames_null if v not in args]
                + [v for v in code_options["co_varnames"] if v not in args]
                + [IS_TRACING_RESUME_PROLOGUE_VARNAME]
            )
            code_options["co_flags"] = code_options["co_flags"] & ~(
                CO_VARARGS | CO_VARKEYWORDS
            )
            # instruction to jump to once the prefix has run
            target = next(i for i in instructions if i.offset == resume_offset)

            prefix = []
            if is_py311_plus:
                if freevars:
                    prefix.append(
                        create_instruction("COPY_FREE_VARS", arg=len(freevars))
                    )
                prefix.append(create_instruction("RESUME", arg=0))

            # Set is_tracing_resume_prologue to prevent graph breaks.
            # This doesn't really do anything at runtime, but dynamo will trace this
            # and will know that we're in a resume function prologue.
            prefix.extend(
                [
                    create_instruction("LOAD_CONST", argval=True),
                    create_instruction(
                        "STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
                    ),
                ]
            )

            # exit/handler instructions contributed by the hooks
            cleanup: list[Instruction] = []
            hooks = {fn.stack_index: fn for fn in setup_fns}
            hook_target_offsets = {
                fn.stack_index: setup_fn_target_offsets[i]
                for i, fn in enumerate(setup_fns)
            }
            offset_to_inst = {inst.offset: inst for inst in instructions}
            # map old hook targets to new targets generated by the hook
            # pyrefly: ignore [implicit-any]
            old_hook_target_remap = {}
            # stack_i counts ___stack{i} loads, null_i counts consumed null_idxes;
            # i is the position in the rebuilt stack (values and NULLs together)
            stack_i = 0
            null_i = 0
            stack_ctx_vars_d = dict(stack_ctx_vars)  # type: ignore[var-annotated,arg-type]
            for i in range(nstack + len(null_idxes)):
                if null_i < len(null_idxes) and null_idxes[null_i] == i:
                    prefix.append(create_instruction("PUSH_NULL"))
                    null_i += 1
                else:
                    prefix.append(
                        create_instruction("LOAD_FAST", argval=f"___stack{stack_i}")
                    )
                    if handle_inactive_ctx and stack_i in stack_ctx_vars_d:
                        # NOTE: we assume that current stack var is a context manager CLASS!
                        # Load args for context variable and construct it
                        prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[stack_i]))
                    stack_i += 1

                if i in hooks:
                    hook = hooks.pop(i)
                    hook_insts, exn_target = hook(code_options, cleanup)
                    prefix.extend(hook_insts)
                    if is_py311_plus:
                        hook_target_offset = hook_target_offsets.pop(i)
                        old_hook_target = offset_to_inst[hook_target_offset]
                        meta.prefix_block_target_offset_remap.append(hook_target_offset)
                        old_hook_target_remap[old_hook_target] = exn_target

            if is_py311_plus:
                # reverse the mapping since targets of later/nested contexts are inserted
                # into the mapping later, but show up earlier in the prefix.
                meta.prefix_block_target_offset_remap = list(
                    reversed(meta.prefix_block_target_offset_remap)
                )

            # every hook must have been consumed by the loop above
            assert not hooks

            # NOTE: we assume that local var is a context manager CLASS!
            # initialize inactive context vars in argnames
            if handle_inactive_ctx:
                for name, vals in argnames_ctx_vars:
                    prefix.append(create_instruction("LOAD_FAST", argval=name))
                    prefix.extend(_load_tuple_and_call(vals))
                    prefix.append(create_instruction("STORE_FAST", argval=name))

            # 3.12+: store NULL into variables that were NULL
            if argnames_null:
                assert sys.version_info >= (3, 12)
                for v in argnames_null:
                    assert v not in args
                    prefix.extend(
                        [
                            create_instruction("PUSH_NULL"),
                            create_instruction("STORE_FAST", argval=v),
                        ]
                    )

            # Call nested resume function
            if nested_code_objs:
                prefix.extend(
                    [
                        # set up __nested_resume_fns[-1] call
                        *add_push_null(
                            [
                                create_instruction(
                                    "LOAD_FAST", argval="__nested_resume_fns"
                                ),
                                create_instruction("LOAD_CONST", argval=-1),
                                create_binary_subscr(),
                            ]
                        ),
                        # del __nested_resume_fns[-1]
                        create_instruction("LOAD_FAST", argval="__nested_resume_fns"),
                        create_instruction("LOAD_CONST", argval=-1),
                        create_instruction("DELETE_SUBSCR"),
                        # load [__nested_resume_fns, __nested_frame_values]
                        create_instruction("LOAD_FAST", argval="__nested_resume_fns"),
                        create_instruction("LOAD_FAST", argval="__nested_frame_values"),
                        create_instruction("BUILD_LIST", arg=2),
                        # load __nested_frame_values[-1]
                        create_instruction("LOAD_FAST", argval="__nested_frame_values"),
                        create_instruction("LOAD_CONST", argval=-1),
                        create_binary_subscr(),
                        # create [
                        #     __nested_resume_fns,
                        #     __nested_frame_values,
                        #     *__nested_frame_values[-1],
                        # ]
                        create_instruction("LIST_EXTEND", arg=1),
                        # del __nested_frame_values[-1]
                        create_instruction("LOAD_FAST", argval="__nested_frame_values"),
                        create_instruction("LOAD_CONST", argval=-1),
                        create_instruction("DELETE_SUBSCR"),
                        # delete __nested values
                        create_instruction("DELETE_FAST", argval="__nested_resume_fns"),
                        create_instruction(
                            "DELETE_FAST", argval="__nested_frame_values"
                        ),
                        # Set is_tracing_resume_prologue back to allow graph breaks
                        # in the nested resume
                        create_instruction("LOAD_CONST", argval=False),
                        create_instruction(
                            "STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
                        ),
                        # finish the call
                        *create_call_function_ex(False, False),
                    ]
                )
                if pop_nested_resume_result:
                    # pop the result of calling the nested resume function
                    prefix.append(create_instruction("POP_TOP"))
            else:
                # Set is_tracing_resume_prologue back to allow graph breaks after the jump
                prefix.extend(
                    [
                        create_instruction("LOAD_CONST", argval=False),
                        create_instruction(
                            "STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
                        ),
                    ]
                )

            prefix.append(create_jump_absolute(target))

            # because the line number table monotonically increases from co_firstlineno
            # remove starts_line for any instructions before the graph break instruction
            # this will ensure the instructions after the break have the correct line numbers
            for inst in instructions:
                if inst.offset == target.offset:
                    break
                inst.starts_line = None
                if sys.version_info >= (3, 11):
                    inst.positions = None

            if cleanup:
                # hook cleanup code goes after the jump, followed by the
                # `raise None` sequence from unreachable_codes
                prefix.extend(cleanup)
                prefix.extend(cls.unreachable_codes(code_options))

            # remap original instructions' exception table entries
            if old_hook_target_remap:
                # pyrefly: ignore [unbound-name]
                assert is_py311_plus
                for inst in instructions:
                    if (
                        inst.exn_tab_entry
                        and inst.exn_tab_entry.target in old_hook_target_remap
                    ):
                        inst.exn_tab_entry.target = old_hook_target_remap[  # type: ignore[assignment]
                            inst.exn_tab_entry.target
                        ]

            # TODO(jansel): add dead code elimination here
            instructions[:] = prefix + instructions

        new_code, _ = transform_code_object(code, update)
        # remember where new_code came from so that later resumes into it can
        # be generated from the original code object
        ContinueExecutionCache.generated_code_metadata[new_code] = meta
        return new_code

    @staticmethod
    def unreachable_codes(code_options: dict[str, Any]) -> list[Instruction]:
        """Codegen a `raise None` to make analysis work for unreachable code"""
        return [
            create_load_const(None),
            create_instruction("RAISE_VARARGS", arg=1),
        ]

    @classmethod
    def generate_based_on_original_code_object(
        cls,
        code: types.CodeType,
        lineno: int,
        init_offset: int,
        resume_offset: int,
        setup_fn_target_offsets: tuple[int, ...],
        *args: Any,
    ) -> types.CodeType:
        """
        This handles the case of generating a resume into code generated
        to resume something else.  We want to always generate starting
        from the original code object so that if control flow paths
        converge we only generated 1 resume function (rather than 2^n
        resume functions).

        ``init_offset``, ``resume_offset`` and ``setup_fn_target_offsets`` are
        given relative to ``code`` (a generated resume function). They are
        translated to offsets in ``meta.code`` and the request is forwarded to
        ``lookup`` on ``meta.code``; ``*args`` is passed through unchanged.
        """

        meta: ResumeFunctionMetadata = ContinueExecutionCache.generated_code_metadata[
            code
        ]

        # Translate an offset in `code` to the offset of the corresponding
        # instruction in meta.instructions; returns -1 when there is none.
        def find_orig_offset(cur_offset: int) -> int:
            orig_offset = -1

            def find_orig_offset_transform(
                instructions: list[Instruction], code_options: dict[str, Any]
            ) -> None:
                nonlocal orig_offset
                # exactly one instruction must have this offset (unpacking fails otherwise)
                (target,) = (i for i in instructions if i.offset == cur_offset)
                # match the functions starting at the last instruction as we have added a prefix
                new_target_tuple = tuple(
                    i2
                    for i1, i2 in zip(
                        reversed(instructions), reversed(meta.instructions)
                    )
                    if i1 is target
                )

                if not new_target_tuple:
                    # Instruction with cur_offset in instructions was not found
                    # in the original code - orig_offset left as -1.
                    # Caller expected to handle this case.
                    return

                assert len(new_target_tuple) == 1
                new_target = new_target_tuple[0]

                assert target.opcode == new_target.opcode
                assert new_target.offset is not None
                orig_offset = new_target.offset

            # transform_code_object is used here to get at the instructions of
            # `code`; the callback itself does not modify them
            transform_code_object(code, find_orig_offset_transform)
            return orig_offset

        orig_init_offset = find_orig_offset(init_offset)
        # It is fine if the initial instruction is not found in the original code;
        # this means we graph broke in the prefix, which only happens with nested graph breaks.
        # We should not be running into ambiguous graph break issues here.
        orig_resume_offset = find_orig_offset(resume_offset)
        assert orig_resume_offset > -1, (
            "resume instruction not found in original code - this is a bug."
        )

        if sys.version_info >= (3, 11):
            # setup_fn_target_offsets currently contains the target offset of
            # each setup_fn, based on `code`. When we codegen the resume function
            # based on the original code object, `meta.code`, the offsets in
            # setup_fn_target_offsets must be based on `meta.code` instead.
            offset_key = (orig_init_offset, orig_resume_offset)
            # NOTE: we key by offset_key since the same resume function may graph
            # break in multiple places and we need different block_target_offset_remap's
            # for each graph break location. Keying by orig_resume_offset may not be enough
            # if 2 graph breaks on different initial offsets resume on the same instruction
            # (although this is rare and not tested anywhere).
            if offset_key not in meta.block_target_offset_remap:
                block_target_offset_remap = meta.block_target_offset_remap[
                    offset_key
                    # pyrefly: ignore [implicit-any]
                ] = {}

                # Callback for transform_code_object: fills
                # block_target_offset_remap (offset in `code` -> offset in meta.code).
                def remap_block_offsets(
                    instructions: list[Instruction], code_options: dict[str, Any]
                ) -> None:
                    # NOTE: each prefix block generates exactly one PUSH_EXC_INFO,
                    # so we can tell which block a prefix PUSH_EXC_INFO belongs to,
                    # by counting. Then we can use meta.prefix_block_target_offset_remap
                    # to determine where in the original code the PUSH_EXC_INFO offset
                    # replaced.
                    prefix_blocks: list[Instruction] = []
                    for inst in instructions:
                        # NOTE meta.prefix_block_target_offset_remap is based off of how we codegen'd
                        # context managers at the prefix/prologue of the resume function. It is the same for
                        # every graph break in the same resume function, so we do not need to recompute
                        # for each graph break (unlike for meta.block_target_offset_remap)
                        if len(prefix_blocks) == len(
                            meta.prefix_block_target_offset_remap
                        ):
                            break
                        if inst.opname == "PUSH_EXC_INFO":
                            prefix_blocks.append(inst)

                    # remap block target offsets for blocks generated in the resume prefix
                    for inst, o in zip(
                        prefix_blocks, meta.prefix_block_target_offset_remap
                    ):
                        block_target_offset_remap[cast(int, inst.offset)] = o

                    # current bytecode targets are after the prefix PUSH_EXC_INFO's
                    cur_start_offset = (
                        cast(int, prefix_blocks[-1].offset) if prefix_blocks else -1
                    )
                    # get the remaining block target offsets of the current bytecode
                    cur_inst_offsets = sorted(
                        n for n in setup_fn_target_offsets if n > cur_start_offset
                    )
                    targets = _filter_iter(
                        instructions, cur_inst_offsets, lambda inst, o: inst.offset == o
                    )
                    # The original code and resume code should have matching suffixes.
                    # Match the post-prefix block target offsets of the current resume code
                    # and the original code.
                    orig_targets = reversed(
                        _filter_iter(
                            zip(reversed(instructions), reversed(meta.instructions)),
                            reversed(targets),
                            lambda v1, v2: v1[0] is v2,
                        )
                    )
                    for orig, cur in zip(orig_targets, targets):
                        block_target_offset_remap[cur.offset] = orig[1].offset

                transform_code_object(code, remap_block_offsets)

            # if offset_key or offset is not in setup_fn_target_offsets, it is an error
            # that needs to be fixed
            setup_fn_target_offsets = tuple(
                meta.block_target_offset_remap[offset_key][n]
                for n in setup_fn_target_offsets
            )
        return ContinueExecutionCache.lookup(
            meta.code,
            lineno,
            orig_init_offset,
            orig_resume_offset,
            setup_fn_target_offsets,
            *args,
        )