| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750 |
- """
- This module provides functionality for resuming Python execution at specific points in code,
- primarily used by PyTorch Dynamo for control flow handling and optimization. It implements
- bytecode transformation and execution state management to enable:
- - Resuming execution at arbitrary points in Python bytecode
- - Managing context managers and their state across execution boundaries
- - Transforming and generating new code objects with preserved execution state
- - Supporting Python 3.11+ exception handling and block management
- - Restoring torch function mode stacks and other execution context
- The module is critical for PyTorch Dynamo's ability to optimize code while preserving
- Python semantics and execution state.
- """
- import copy
- import dataclasses
- import sys
- import types
- from collections.abc import Callable, Iterable
- from contextlib import AbstractContextManager
- from typing import Any, cast, Optional
- from .bytecode_transformation import (
- add_push_null,
- bytecode_from_template,
- create_binary_subscr,
- create_call_function,
- create_call_function_ex,
- create_instruction,
- create_jump_absolute,
- create_load_const,
- Instruction,
- overwrite_instruction,
- transform_code_object,
- unique_id,
- )
- from .utils import ExactWeakKeyDictionary
# Code-object flag bits (the values stored in `co_flags`), taken from code.h in cpython.
# Within this file, ContinueExecutionCache.generate() asserts on CO_OPTIMIZED and on the
# four generator/coroutine bits, and clears CO_VARARGS / CO_VARKEYWORDS on the code it emits.
CO_OPTIMIZED = 0x0001
CO_NEWLOCALS = 0x0002
CO_VARARGS = 0x0004
CO_VARKEYWORDS = 0x0008
CO_NESTED = 0x0010
CO_GENERATOR = 0x0020
CO_NOFREE = 0x0040
CO_COROUTINE = 0x0080
CO_ITERABLE_COROUTINE = 0x0100
CO_ASYNC_GENERATOR = 0x0200
# trace_rules.py imports this constant for consistency.
# It is the prefix of the `co_name` given to every generated resume function.
TORCH_DYNAMO_RESUME_IN_PREFIX = "torch_dynamo_resume_in"
# Name of the local variable that a resume function's prologue sets to True on entry
# and back to False once the prologue is finished.
IS_TRACING_RESUME_PROLOGUE_VARNAME = "__is_tracing_resume_prologue"
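# If is_resume - this codegen is for a resume function
# NOTE(review): no `is_resume` parameter is visible anywhere in this file, so the
# comment above looks stale - confirm what it referred to, then update or remove it.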
def _initial_push_null(insts: list[Instruction]) -> None:
    """
    Append the version-dependent NULL setup to ``insts`` (mutated in place).

    - before Python 3.11: nothing is appended
    - Python 3.11 and 3.12: ``PUSH_NULL`` followed by ``SWAP 2``
    - Python 3.13+: ``PUSH_NULL`` only
    """
    version = sys.version_info
    if version < (3, 11):
        return

    insts.append(create_instruction("PUSH_NULL"))
    if version < (3, 13):
        # On 3.11/3.12 the freshly pushed NULL is swapped with the value
        # that was previously on top of the stack.
        insts.append(create_instruction("SWAP", arg=2))
- # Generates bytecode from template and splits the code where LOAD_FAST dummy is present.
def _bytecode_from_template_with_split(
    template: Callable[..., Any],
    stack_index: int,
    varname_map: Optional[dict[str, Any]] = None,
) -> tuple[list[Instruction], list[Instruction]]:
    """
    Turn ``template`` into bytecode and cut it in two at its ``dummy`` placeholder.

    The template must contain a bare ``dummy`` expression statement, i.e. a
    ``LOAD_FAST dummy`` (or ``LOAD_FAST_BORROW dummy``) immediately followed by
    ``POP_TOP``. Both instructions are overwritten with ``NOP`` and the
    instruction list is split between them.

    Args:
        template: function whose bytecode is used as the template.
        stack_index: amount added to the ``depth`` of every exception table
            entry found on the template instructions.
        varname_map: forwarded unchanged to ``bytecode_from_template``.

    Returns:
        ``(before, after)``: ``before`` ends with the first ``NOP``, ``after``
        starts with the second ``NOP``.
    """
    insts = bytecode_from_template(template, varname_map=varname_map)
    insts.append(create_instruction("POP_TOP"))

    # adjust exception table entry depth
    for inst in insts:
        if inst.exn_tab_entry:
            inst.exn_tab_entry.depth += stack_index

    # locate the first load of the `dummy` placeholder
    placeholder_idx = None
    for idx, inst in enumerate(insts):
        if inst.opname in ("LOAD_FAST", "LOAD_FAST_BORROW") and inst.argval == "dummy":
            placeholder_idx = idx
            break
    assert placeholder_idx is not None

    # first NOP: marks the start of the exception area
    overwrite_instruction(insts[placeholder_idx], [create_instruction("NOP")])

    # second NOP: the POP_TOP that followed the placeholder marks the end of the area
    split_at = placeholder_idx + 1
    trailing_pop = insts[split_at]
    assert trailing_pop.opname == "POP_TOP"
    overwrite_instruction(trailing_pop, [create_instruction("NOP")])

    return insts[:split_at], insts[split_at:]
def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None:
    # Bytecode template used by ReenterWith.try_except_torch_function_mode via
    # _bytecode_from_template_with_split; it is compiled for its instructions.
    # `dummy` is the placeholder at which the generated bytecode is split, and
    # `stack_var_name` is renamed through `varname_map` by the caller.
    # Only `#` comments are used here (no docstring) so the function's code object
    # is left exactly as it was.
    #
    # NOTE: Make sure this name matches what is generated by symbolic_convert:import_source
    # on torch._dynamo.utils.
    # pyrefly: ignore [unknown-name]
    global __import_torch_dot__dynamo_dot_utils
    try:
        dummy
    except:  # noqa: E722, B001
        # Any exception: hand `stack_var_name` to set_torch_function_mode_stack, then re-raise.
        __import_torch_dot__dynamo_dot_utils.set_torch_function_mode_stack(  # type: ignore[name-defined]
            stack_var_name
        )
        raise
@dataclasses.dataclass(frozen=True)
class ReenterWith:
    """
    Codegen helper that wraps the remainder of a resume function in a
    try/except, try/finally or ``with`` block.

    Each method returns the instructions that open the block and prepends the
    instructions that close it to the caller-supplied ``cleanup`` list.

    Attributes (documented here rather than inline):
        stack_index: passed to ``_bytecode_from_template_with_split`` as the
            exception-table depth adjustment; also part of the generated
            context manager variable name in ``try_finally``.
        target_values: constants loaded as call arguments for the value on top
            of the stack; ``None`` or empty means "call with no arguments".
    """

    stack_index: int
    target_values: Optional[tuple[Any, ...]] = None

    def try_except_torch_function_mode(
        self, code_options: dict[str, Any], cleanup: list[Instruction]
    ) -> list[Instruction]:
        """
        Codegen based off of:
        try:
            (rest)
        except:
            (restore previous tf mode stack)
            raise

        Returns the instructions up to ``(rest)``; the handler part is
        prepended to ``cleanup`` in place. ``code_options`` is not used.
        """
        from .variables.torch_function import get_prev_stack_var_name

        name_map = {"stack_var_name": get_prev_stack_var_name()}
        block_head, block_tail = _bytecode_from_template_with_split(
            _try_except_tf_mode_template,
            self.stack_index,
            varname_map=name_map,
        )
        cleanup[:0] = block_tail
        return block_head

    # If we do not want to destroy the stack, we can do the same thing as a
    # `SETUP_WITH` block, only that we store the context manager in a local_symbol
    def try_finally(
        self, code_options: dict[str, Any], cleanup: list[Instruction]
    ) -> list[Instruction]:
        """
        Codegen based off of:
        load args
        enter context
        try:
            (rest)
        finally:
            exit context

        Adds the context manager's local name to ``code_options["co_varnames"]``
        and ``__enter__`` / ``__exit__`` to ``code_options["co_names"]`` when
        missing. The ``finally`` part is prepended to ``cleanup`` in place.
        """
        # NOTE: we assume that TOS is a context manager CLASS!
        ctor_args = [create_load_const(v) for v in self.target_values or ()]

        ctx_name = unique_id(f"___context_manager_{self.stack_index}")
        if ctx_name not in code_options["co_varnames"]:
            code_options["co_varnames"] += (ctx_name,)
        for dunder in ("__enter__", "__exit__"):
            if dunder not in code_options["co_names"]:
                code_options["co_names"] += (dunder,)

        # call TOS with the constant args and stash the result in a local
        construct: list[Instruction] = []
        _initial_push_null(construct)
        construct += ctor_args
        construct += create_call_function(len(ctor_args), False)
        construct.append(create_instruction("STORE_FAST", argval=ctx_name))

        # Bytecode template - kept verbatim, its instructions are what gets emitted.
        def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None:
            ctx.__enter__()
            try:
                dummy
            finally:
                ctx.__exit__(None, None, None)

        block_head, block_tail = _bytecode_from_template_with_split(
            _template, self.stack_index, varname_map={"ctx": ctx_name}
        )
        cleanup[:0] = block_tail
        return construct + block_head

    def __call__(
        self, code_options: dict[str, Any], cleanup: list[Instruction]
    ) -> tuple[list[Instruction], Optional[Instruction]]:
        """
        Codegen based off of:
        with ctx(args):
            (rest)

        Returns ``(instructions, push_exc_info)`` where ``push_exc_info`` is the
        single ``PUSH_EXC_INFO`` instruction of the generated epilogue, or
        ``None`` if the epilogue has none. The epilogue itself is prepended to
        ``cleanup`` in place. ``code_options`` is not used.
        """
        # NOTE: we assume that TOS is a context manager CLASS!
        ctor_args = [create_load_const(v) for v in self.target_values or ()]

        construct: list[Instruction] = []
        # Do not push NULL in Python 3.14+ since the NULL should be on the symbolic stack.
        if sys.version_info < (3, 14):
            _initial_push_null(construct)
        construct += ctor_args
        construct += create_call_function(len(ctor_args), False)

        # Bytecode template - kept verbatim, its instructions are what gets emitted.
        def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None:
            with ctx:
                dummy

        block_head, block_tail = _bytecode_from_template_with_split(
            _template, self.stack_index
        )
        cleanup[:0] = block_tail

        ctx_load = None
        for inst in block_head:
            if inst.opname in ("LOAD_FAST", "LOAD_FAST_BORROW") and inst.argval == "ctx":
                ctx_load = inst
                break
        assert ctx_load is not None
        # ctx already loaded on stack before the template - no need to LOAD_FAST
        overwrite_instruction(ctx_load, [create_instruction("NOP")])

        # 3.11+ only; expect at most 1 PUSH_EXC_INFO in epilogue
        exc_info_insts = [
            inst for inst in block_tail if inst.opname == "PUSH_EXC_INFO"
        ]
        assert len(exc_info_insts) <= 1
        exn_target = exc_info_insts[0] if exc_info_insts else None

        return construct + block_head, exn_target
@dataclasses.dataclass
class ResumeFunctionMetadata:
    """
    Bookkeeping stored for every generated resume function.

    ContinueExecutionCache.generate() creates one instance per generated code
    object; generate_based_on_original_code_object() reads it to translate
    offsets in the generated code back to offsets in ``code``.
    """

    # the code object the resume function was generated from
    code: types.CodeType
    # deep copy of that code object's instructions, taken before the resume prefix was added
    instructions: list[Instruction] = dataclasses.field(default_factory=list)
    # Python 3.11+ fields
    # NOTE: Python 3.11 removed blocks, but for our purposes, a "block" consists
    # of instructions of all exception table entries that have the same target.
    # map from PUSH_EXC_INFO's in the prefix to original block target offset
    prefix_block_target_offset_remap: list[int] = dataclasses.field(
        default_factory=list
    )
    # per-offset map from new block target offsets to original block target offsets
    # (outer key is the (init offset, resume offset) pair in terms of `code`)
    block_target_offset_remap: dict[tuple[int, int], dict[int, int]] = (
        dataclasses.field(default_factory=dict)
    )
- def _filter_iter(
- l1: Iterable[Any],
- l2: Iterable[Any],
- cond: Callable[[Any, Any], bool],
- ) -> list[Any]:
- """
- Two-pointer conditional filter.
- e.g. _filter_iter(insts, sorted_offsets, lambda i, o: i.offset == o)
- returns the instructions with offsets in sorted_offsets
- """
- it = iter(l2)
- res: list[Instruction] = []
- try:
- cur = next(it)
- for val in l1:
- if cond(val, cur):
- res.append(val)
- cur = next(it)
- except StopIteration:
- pass
- return res
def _load_tuple_and_call(tup: tuple[Any, ...]) -> list[Instruction]:
    """
    Build instructions that call the value on top of the stack with the
    elements of ``tup`` loaded as constant positional arguments.
    """
    call_insts: list[Instruction] = []
    _initial_push_null(call_insts)
    call_insts += [create_load_const(item) for item in tup]
    call_insts += create_call_function(len(tup), False)
    return call_insts
class ContinueExecutionCache:
    """
    Generates, and caches, the code objects of resume functions.

    A resume function is a copy of a code object with a prefix prepended that
    reloads the stack from arguments, re-enters context managers and then jumps
    to the instruction at which execution should continue.
    """

    # code object -> {key tuple -> generated resume code object}; see lookup()
    cache = ExactWeakKeyDictionary()
    # generated resume code object -> ResumeFunctionMetadata; filled by generate()
    generated_code_metadata = ExactWeakKeyDictionary()

    @classmethod
    def lookup(
        cls, code: types.CodeType, lineno: int, init_offset: int, *key: Any
    ) -> types.CodeType:
        """
        Return the resume code object cached for ``(code, key)``, generating it on a miss.

        ``lineno`` and ``init_offset`` are forwarded to generate() but are NOT
        part of the cache key; ``key`` holds the remaining positional
        arguments of generate(), starting with ``resume_offset``.
        """
        if code not in cls.cache:
            cls.cache[code] = {}
        key = tuple(key)
        if key not in cls.cache[code]:
            cls.cache[code][key] = cls.generate(code, lineno, init_offset, *key)
        return cls.cache[code][key]

    @classmethod
    def generate(
        cls,
        code: types.CodeType,
        lineno: int,
        init_offset: int,
        resume_offset: int,
        setup_fn_target_offsets: tuple[int, ...],  # only used in Python 3.11+
        nstack: int,
        argnames: tuple[str, ...],
        argnames_null: tuple[str, ...],
        setup_fns: tuple[ReenterWith, ...],
        handle_inactive_ctx: bool,
        stack_ctx_vars: tuple[tuple[int, tuple[Any, ...]], ...],
        argnames_ctx_vars: tuple[tuple[str, tuple[Any, ...]], ...],
        null_idxes: tuple[int, ...],
        # mainly used to ensure distinct code objects per stack trace,
        # which prevents excessive recompilation of inner frames
        nested_code_objs: tuple[types.CodeType, ...],
        # Are we currently graph breaking on an instruction that doesn't push
        # its result to the stack? If so, and we are not the leaf resume, then we need to pop
        # the result of calling the next resume function.
        pop_nested_resume_result: bool,
    ) -> types.CodeType:
        """
        Build the resume code object that continues ``code`` at ``resume_offset``.

        If ``code`` is itself a generated resume function, the work is
        delegated to generate_based_on_original_code_object() so that the
        result is always derived from the original code object.

        Args:
            code: code object to resume; must have CO_OPTIMIZED set and must
                not be a generator or coroutine of any kind.
            lineno: becomes ``co_firstlineno`` and the ``_at_<lineno>`` suffix
                of the generated name.
            init_offset: not used directly here; only forwarded when delegating.
            resume_offset: offset of the instruction the prefix jumps to.
            setup_fn_target_offsets: one offset per entry of ``setup_fns``
                (same order).
            nstack: number of ``___stack<i>`` arguments.
            argnames: further argument names, appended after the stack arguments.
            argnames_null: local names that get NULL stored into them by the
                prefix (asserted to be Python 3.12+ when non-empty).
            setup_fns: hooks, keyed by their ``stack_index``, called right
                after the matching stack slot has been pushed.
            handle_inactive_ctx: when true, the values named in
                ``stack_ctx_vars`` / ``argnames_ctx_vars`` are called with the
                given constant arguments in the prefix.
            stack_ctx_vars: ``(stack index, constructor args)`` pairs.
            argnames_ctx_vars: ``(local name, constructor args)`` pairs.
            null_idxes: stack positions at which ``PUSH_NULL`` is emitted
                instead of loading a ``___stack<i>`` argument
                (presumably sorted ascending, since they are consumed with a
                single forward pointer - confirm against the caller).
            nested_code_objs: only its truthiness is inspected here; when
                non-empty the prefix calls ``__nested_resume_fns[-1]``.
            pop_nested_resume_result: when true (and ``nested_code_objs`` is
                non-empty) the result of that nested call is popped.

        Returns:
            The generated code object; its metadata is recorded in
            ``generated_code_metadata``.
        """
        assert resume_offset is not None
        assert not (
            code.co_flags
            & (CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR)
        )
        assert code.co_flags & CO_OPTIMIZED
        # `code` is already a resume function: regenerate from the original code instead.
        if code in ContinueExecutionCache.generated_code_metadata:
            return cls.generate_based_on_original_code_object(
                code,
                lineno,
                init_offset,
                resume_offset,
                setup_fn_target_offsets,
                nstack,
                argnames,
                argnames_null,
                setup_fns,
                handle_inactive_ctx,
                stack_ctx_vars,
                argnames_ctx_vars,
                null_idxes,
                nested_code_objs,
                pop_nested_resume_result,
            )

        is_py311_plus = sys.version_info >= (3, 11)
        meta = ResumeFunctionMetadata(code)

        # Transformation callback handed to transform_code_object: mutates
        # `instructions` and `code_options` in place.
        def update(
            instructions: list[Instruction], code_options: dict[str, Any]
        ) -> None:
            # snapshot of the untouched instructions, used later for offset matching
            meta.instructions = copy.deepcopy(instructions)

            # --- signature / code-object attributes of the resume function ---
            args = ["__nested_resume_fns", "__nested_frame_values"]
            args += [f"___stack{i}" for i in range(nstack)]
            args.extend(v for v in argnames if v not in args)
            # cell vars and free vars of the source code all become free vars here
            freevars = tuple(code_options["co_cellvars"] or []) + tuple(
                code_options["co_freevars"] or []
            )
            freevars = tuple(sorted(freevars))
            code_options["co_name"] = (
                f"{TORCH_DYNAMO_RESUME_IN_PREFIX}_{code_options['co_name']}_at_{lineno}"
            )
            if is_py311_plus:
                # rename only the last component of the qualified name
                qualified_path = code_options["co_qualname"].rsplit(".", maxsplit=1)
                if len(qualified_path) == 1:
                    code_options["co_qualname"] = code_options["co_name"]
                else:
                    assert len(qualified_path) == 2
                    module_name, co_name = qualified_path
                    code_options["co_qualname"] = (
                        f"{module_name}.{TORCH_DYNAMO_RESUME_IN_PREFIX}_{co_name}_at_{lineno}"
                    )
            code_options["co_firstlineno"] = lineno
            code_options["co_cellvars"] = ()
            code_options["co_freevars"] = freevars
            code_options["co_argcount"] = len(args)
            code_options["co_posonlyargcount"] = 0
            code_options["co_kwonlyargcount"] = 0
            # arguments first, then NULL-initialized names, then the remaining
            # original locals, then the prologue flag
            code_options["co_varnames"] = tuple(
                args
                + [v for v in argnames_null if v not in args]
                + [v for v in code_options["co_varnames"] if v not in args]
                + [IS_TRACING_RESUME_PROLOGUE_VARNAME]
            )
            # the resume function takes plain positional arguments only
            code_options["co_flags"] = code_options["co_flags"] & ~(
                CO_VARARGS | CO_VARKEYWORDS
            )
            # instruction to jump to once the prefix has run
            target = next(i for i in instructions if i.offset == resume_offset)

            # --- build the prefix ---
            prefix = []
            if is_py311_plus:
                if freevars:
                    prefix.append(
                        create_instruction("COPY_FREE_VARS", arg=len(freevars))
                    )
                prefix.append(create_instruction("RESUME", arg=0))

            # Set is_tracing_resume_prologue to prevent graph breaks.
            # This doesn't really do anything at runtime, but dynamo will trace this
            # and will know that we're in a resume function prologue.
            prefix.extend(
                [
                    create_instruction("LOAD_CONST", argval=True),
                    create_instruction(
                        "STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
                    ),
                ]
            )

            # instructions emitted by the hooks that must go after the jump
            cleanup: list[Instruction] = []
            hooks = {fn.stack_index: fn for fn in setup_fns}
            hook_target_offsets = {
                fn.stack_index: setup_fn_target_offsets[i]
                for i, fn in enumerate(setup_fns)
            }
            offset_to_inst = {inst.offset: inst for inst in instructions}
            # map old hook targets to new targets generated by the hook
            # pyrefly: ignore [implicit-any]
            old_hook_target_remap = {}
            # `i` walks every stack slot; `stack_i` counts the non-NULL slots
            # (i.e. the ___stack<n> args) and `null_i` the consumed NULL positions
            stack_i = 0
            null_i = 0
            stack_ctx_vars_d = dict(stack_ctx_vars)  # type: ignore[var-annotated,arg-type]
            for i in range(nstack + len(null_idxes)):
                if null_i < len(null_idxes) and null_idxes[null_i] == i:
                    prefix.append(create_instruction("PUSH_NULL"))
                    null_i += 1
                else:
                    prefix.append(
                        create_instruction("LOAD_FAST", argval=f"___stack{stack_i}")
                    )
                    if handle_inactive_ctx and stack_i in stack_ctx_vars_d:
                        # NOTE: we assume that current stack var is a context manager CLASS!
                        # Load args for context variable and construct it
                        prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[stack_i]))
                    stack_i += 1
                if i in hooks:
                    # a setup fn is registered for this slot: emit its block opener now
                    hook = hooks.pop(i)
                    hook_insts, exn_target = hook(code_options, cleanup)
                    prefix.extend(hook_insts)
                    if is_py311_plus:
                        hook_target_offset = hook_target_offsets.pop(i)
                        old_hook_target = offset_to_inst[hook_target_offset]
                        meta.prefix_block_target_offset_remap.append(hook_target_offset)
                        old_hook_target_remap[old_hook_target] = exn_target
            if is_py311_plus:
                # reverse the mapping since targets of later/nested contexts are inserted
                # into the mapping later, but show up earlier in the prefix.
                meta.prefix_block_target_offset_remap = list(
                    reversed(meta.prefix_block_target_offset_remap)
                )
            # every setup fn must have matched a stack slot
            assert not hooks
            # NOTE: we assume that local var is a context manager CLASS!
            # initialize inactive context vars in argnames
            if handle_inactive_ctx:
                for name, vals in argnames_ctx_vars:
                    prefix.append(create_instruction("LOAD_FAST", argval=name))
                    prefix.extend(_load_tuple_and_call(vals))
                    prefix.append(create_instruction("STORE_FAST", argval=name))
            # 3.12+: store NULL into variables that were NULL
            if argnames_null:
                assert sys.version_info >= (3, 12)
                for v in argnames_null:
                    assert v not in args
                    prefix.extend(
                        [
                            create_instruction("PUSH_NULL"),
                            create_instruction("STORE_FAST", argval=v),
                        ]
                    )
            # Call nested resume function
            if nested_code_objs:
                prefix.extend(
                    [
                        # set up __nested_resume_fns[-1] call
                        *add_push_null(
                            [
                                create_instruction(
                                    "LOAD_FAST", argval="__nested_resume_fns"
                                ),
                                create_instruction("LOAD_CONST", argval=-1),
                                create_binary_subscr(),
                            ]
                        ),
                        # del __nested_resume_fns[-1]
                        create_instruction("LOAD_FAST", argval="__nested_resume_fns"),
                        create_instruction("LOAD_CONST", argval=-1),
                        create_instruction("DELETE_SUBSCR"),
                        # load [__nested_resume_fns, __nested_frame_values]
                        create_instruction("LOAD_FAST", argval="__nested_resume_fns"),
                        create_instruction("LOAD_FAST", argval="__nested_frame_values"),
                        create_instruction("BUILD_LIST", arg=2),
                        # load __nested_frame_values[-1]
                        create_instruction("LOAD_FAST", argval="__nested_frame_values"),
                        create_instruction("LOAD_CONST", argval=-1),
                        create_binary_subscr(),
                        # create [
                        #     __nested_resume_fns,
                        #     __nested_frame_values,
                        #     *__nested_frame_values[-1],
                        # ]
                        create_instruction("LIST_EXTEND", arg=1),
                        # del __nested_frame_values[-1]
                        create_instruction("LOAD_FAST", argval="__nested_frame_values"),
                        create_instruction("LOAD_CONST", argval=-1),
                        create_instruction("DELETE_SUBSCR"),
                        # delete __nested values
                        create_instruction("DELETE_FAST", argval="__nested_resume_fns"),
                        create_instruction(
                            "DELETE_FAST", argval="__nested_frame_values"
                        ),
                        # Set is_tracing_resume_prologue back to allow graph breaks
                        # in the nested resume
                        create_instruction("LOAD_CONST", argval=False),
                        create_instruction(
                            "STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
                        ),
                        # finish the call
                        *create_call_function_ex(False, False),
                    ]
                )
                if pop_nested_resume_result:
                    # pop the result of calling the nested resume function
                    prefix.append(create_instruction("POP_TOP"))
            else:
                # Set is_tracing_resume_prologue back to allow graph breaks after the jump
                prefix.extend(
                    [
                        create_instruction("LOAD_CONST", argval=False),
                        create_instruction(
                            "STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
                        ),
                    ]
                )
            prefix.append(create_jump_absolute(target))
            # because the line number table monotonically increases from co_firstlineno
            # remove starts_line for any instructions before the graph break instruction
            # this will ensure the instructions after the break have the correct line numbers
            for inst in instructions:
                if inst.offset == target.offset:
                    break
                inst.starts_line = None
                if sys.version_info >= (3, 11):
                    inst.positions = None
            if cleanup:
                # hook epilogues go after the jump, followed by an unreachable `raise`
                prefix.extend(cleanup)
                prefix.extend(cls.unreachable_codes(code_options))
            # remap original instructions' exception table entries
            if old_hook_target_remap:
                # pyrefly: ignore [unbound-name]
                assert is_py311_plus
                for inst in instructions:
                    if (
                        inst.exn_tab_entry
                        and inst.exn_tab_entry.target in old_hook_target_remap
                    ):
                        inst.exn_tab_entry.target = old_hook_target_remap[  # type: ignore[assignment]
                            inst.exn_tab_entry.target
                        ]
            # TODO(jansel): add dead code elimination here
            instructions[:] = prefix + instructions

        new_code, _ = transform_code_object(code, update)
        ContinueExecutionCache.generated_code_metadata[new_code] = meta
        return new_code

    @staticmethod
    def unreachable_codes(code_options: dict[str, Any]) -> list[Instruction]:
        """Codegen a `raise None` to make analysis work for unreachable code"""
        # `code_options` is accepted but not used.
        return [
            create_load_const(None),
            create_instruction("RAISE_VARARGS", arg=1),
        ]

    @classmethod
    def generate_based_on_original_code_object(
        cls,
        code: types.CodeType,
        lineno: int,
        init_offset: int,
        resume_offset: int,
        setup_fn_target_offsets: tuple[int, ...],
        *args: Any,
    ) -> types.CodeType:
        """
        This handles the case of generating a resume into code generated
        to resume something else.  We want to always generate starting
        from the original code object so that if control flow paths
        converge we only generated 1 resume function (rather than 2^n
        resume functions).

        ``code`` must be a key of ``generated_code_metadata``. ``init_offset``,
        ``resume_offset`` and (on Python 3.11+) ``setup_fn_target_offsets`` are
        translated from offsets in ``code`` to offsets in the original code
        object, then lookup() is called on the original code object with the
        remaining ``args`` passed through unchanged.
        """
        meta: ResumeFunctionMetadata = ContinueExecutionCache.generated_code_metadata[
            code
        ]

        def find_orig_offset(cur_offset: int) -> int:
            # Translate an offset in `code` to the offset of the same
            # instruction in the original code; -1 if there is no counterpart.
            orig_offset = -1

            def find_orig_offset_transform(
                instructions: list[Instruction], code_options: dict[str, Any]
            ) -> None:
                nonlocal orig_offset
                # exactly one instruction must have this offset
                (target,) = (i for i in instructions if i.offset == cur_offset)
                # match the functions starting at the last instruction as we have added a prefix
                new_target_tuple = tuple(
                    i2
                    for i1, i2 in zip(
                        reversed(instructions), reversed(meta.instructions)
                    )
                    if i1 is target
                )
                if not new_target_tuple:
                    # Instruction with cur_offset in instructions was not found
                    # in the original code - orig_offset left as -1.
                    # Caller expected to handle this case.
                    return
                assert len(new_target_tuple) == 1
                new_target = new_target_tuple[0]
                assert target.opcode == new_target.opcode
                assert new_target.offset is not None
                orig_offset = new_target.offset

            # transform_code_object is used here only to get at the instruction list of `code`
            transform_code_object(code, find_orig_offset_transform)
            return orig_offset

        orig_init_offset = find_orig_offset(init_offset)
        # It is fine if the initial instruction is not found in the original code;
        # this means we graph broke in the prefix, which only happens with nested graph breaks.
        # We should not be running into ambiguous graph break issues here.
        orig_resume_offset = find_orig_offset(resume_offset)
        assert orig_resume_offset > -1, (
            "resume instruction not found in original code - this is a bug."
        )
        if sys.version_info >= (3, 11):
            # setup_fn_target_offsets currently contains the target offset of
            # each setup_fn, based on `code`. When we codegen the resume function
            # based on the original code object, `meta.code`, the offsets in
            # setup_fn_target_offsets must be based on `meta.code` instead.
            offset_key = (orig_init_offset, orig_resume_offset)
            # NOTE: we key by offset_key since the same resume function may graph
            # break in multiple places and we need different block_target_offset_remap's
            # for each graph break location. Keying by orig_resume_offset may not be enough
            # if 2 graph breaks on different initial offsets resume on the same instruction
            # (although this is rare and not tested anywhere).
            if offset_key not in meta.block_target_offset_remap:
                block_target_offset_remap = meta.block_target_offset_remap[
                    offset_key
                    # pyrefly: ignore [implicit-any]
                ] = {}

                def remap_block_offsets(
                    instructions: list[Instruction], code_options: dict[str, Any]
                ) -> None:
                    # NOTE: each prefix block generates exactly one PUSH_EXC_INFO,
                    # so we can tell which block a prefix PUSH_EXC_INFO belongs to,
                    # by counting. Then we can use meta.prefix_block_target_offset_remap
                    # to determine where in the original code the PUSH_EXC_INFO offset
                    # replaced.
                    prefix_blocks: list[Instruction] = []
                    for inst in instructions:
                        # NOTE meta.prefix_block_target_offset_remap is based off of how we codegen'd
                        # context managers at the prefix/prologue of the resume function. It is the same for
                        # every graph break in the same resume function, so we do not need to recompute
                        # for each graph break (unlike for meta.block_target_offset_remap)
                        if len(prefix_blocks) == len(
                            meta.prefix_block_target_offset_remap
                        ):
                            break
                        if inst.opname == "PUSH_EXC_INFO":
                            prefix_blocks.append(inst)
                    # remap block target offsets for blocks generated in the resume prefix
                    for inst, o in zip(
                        prefix_blocks, meta.prefix_block_target_offset_remap
                    ):
                        block_target_offset_remap[cast(int, inst.offset)] = o
                    # current bytecode targets are after the prefix PUSH_EXC_INFO's
                    cur_start_offset = (
                        cast(int, prefix_blocks[-1].offset) if prefix_blocks else -1
                    )
                    # get the remaining block target offsets of the current bytecode
                    cur_inst_offsets = sorted(
                        n for n in setup_fn_target_offsets if n > cur_start_offset
                    )
                    targets = _filter_iter(
                        instructions, cur_inst_offsets, lambda inst, o: inst.offset == o
                    )
                    # The original code and resume code should have matching suffixes.
                    # Match the post-prefix block target offsets of the current resume code
                    # and the original code.
                    orig_targets = reversed(
                        _filter_iter(
                            zip(reversed(instructions), reversed(meta.instructions)),
                            reversed(targets),
                            lambda v1, v2: v1[0] is v2,
                        )
                    )
                    # `orig` is a (current instruction, original instruction) pair
                    for orig, cur in zip(orig_targets, targets):
                        block_target_offset_remap[cur.offset] = orig[1].offset

                transform_code_object(code, remap_block_offsets)
            # if offset_key or offset is not in setup_fn_target_offsets, it is an error
            # that needs to be fixed
            setup_fn_target_offsets = tuple(
                meta.block_target_offset_remap[offset_key][n]
                for n in setup_fn_target_offsets
            )
        return ContinueExecutionCache.lookup(
            meta.code,
            lineno,
            orig_init_offset,
            orig_resume_offset,
            setup_fn_target_offsets,
            *args,
        )
|