| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374 |
- """
- Side effect tracking and management for TorchDynamo's compilation system.
- This module provides infrastructure for tracking and managing side effects that occur
- during symbolic execution, including:
- - Tracking mutations to objects, attributes, and variables
- - Managing context changes (cell variables, global namespace modifications)
- - Handling aliasing and object identity preservation
- - Managing stack frame state and local variable changes
- - Tracking function calls with side effects
- Key classes:
- - SideEffects: Main container for tracking all side effects during execution
- - MutableSideEffects: Specialization for mutable object tracking
- - AttributeMutation/ValueMutation: Track specific types of mutations
- - Various specialized side effect classes for different scenarios
- The side effect system ensures that mutations performed during symbolic execution
- are properly replayed during runtime, maintaining the correctness of compiled code
- while enabling optimizations where safe.
- """
- import collections
- import contextlib
- import inspect
- import textwrap
- import traceback
- import warnings
- import weakref
- from collections.abc import Generator, MutableMapping
- from types import CellType
- from typing import Any, Optional, TYPE_CHECKING
- import torch
- import torch.nn
- from torch._dynamo.variables.misc import AutogradFunctionContextVariable
- from . import config, graph_break_hints, utils, variables
- from .bytecode_transformation import (
- bytecode_from_template,
- create_call_function,
- create_call_method,
- create_instruction,
- )
- from .codegen import PyCodegen
- from .exc import collapse_resume_frames, get_stack_above_dynamo, unimplemented
- from .source import GlobalSource, LocalCellSource, Source, TempLocalSource
- from .utils import is_frozen_dataclass, nn_module_new, object_new
- from .variables.base import (
- AttributeMutation,
- AttributeMutationExisting,
- AttributeMutationNew,
- is_side_effect_safe,
- ValueMutationExisting,
- ValueMutationNew,
- VariableTracker,
- )
- from .variables.user_defined import FrozenDataClassVariable
- if TYPE_CHECKING:
- from torch._dynamo.output_graph import OutputGraph
- from torch._dynamo.symbolic_convert import InstructionTranslatorBase
- from torch._dynamo.variables.lists import ListVariable
# Module logger obtained from torch._logging for the "side_effects" artifact.
side_effects_log = torch._logging.getArtifactLogger(__name__, "side_effects")
- def _manual_dict_setitem(
- dict_from: dict[Any, Any], dict_to: dict[Any, Any], mro_index: int
- ) -> None:
- # Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have
- # to be careful because we don't want to trigger the user defined object
- # setitem or clear. The mro_index is used to find the dict/OrderedDict from
- # the class mro.
- dict_class = type(dict_to).__mro__[mro_index]
- dict_class.clear(dict_to) # type: ignore[attr-defined]
- for k, v in dict_from.items():
- dict_class.__setitem__(dict_to, k, v) # type: ignore[index]
- def _manual_list_update(list_from: list[Any], list_to: list[Any]) -> None:
- list.clear(list_to)
- list.extend(list_to, list_from)
- class SideEffects:
- """
- Maintain records of mutations and provide methods to apply them during code generation.
- Handles tracking and applying side effects during PyTorch Dynamo compilation,
- maintaining Python semantics by managing mutations, attribute modifications,
- and other side effects that occur during program execution.
- Key responsibilities:
- - Tracks mutations to Python objects, lists, and dictionaries that need to be
- applied after an FX graph is run.
- - Manages attribute modifications and deletions
- - Handles tensor hooks and backward pass state
- - Tracks cell variable mutations and global variable changes
- - Ensures correct ordering and application of side effects after graph execution
- This ensures that optimized code behaves identically to the original Python code with
- respect to object mutations and other side effects.
- """
# Tracked Python objects, keyed by ``id(obj)``, mapped to their VariableTracker.
id_to_variable: dict[int, VariableTracker]
# Pending attribute writes: mutated VariableTracker -> {attribute name -> new value}.
store_attr_mutations: dict[VariableTracker, dict[str, VariableTracker]]
# Strong references to the objects whose ids key ``id_to_variable``;
# presumably held so those ids cannot be recycled while tracked.
keepalive: list[Any]
# Maps variable tracker to list of user stacks (StackSummary objects, formatted lazily)
mutation_user_stacks: dict[VariableTracker, list[traceback.StackSummary]]
def __init__(
    self,
    output_graph: "OutputGraph",
    id_to_variable: Optional[dict[int, VariableTracker]] = None,
    store_attr_mutations: Optional[
        dict[VariableTracker, dict[str, VariableTracker]]
    ] = None,
    mutation_user_stacks: dict[VariableTracker, list[traceback.StackSummary]]
    | None = None,
    keepalive: Optional[list[Any]] = None,
    save_for_backward: Optional[
        list[tuple[AutogradFunctionContextVariable, list[VariableTracker]]]
    ] = None,
    tensor_hooks: Optional[
        dict[
            int,
            tuple[
                "variables.TensorVariable",
                VariableTracker,
                "variables.RemovableHandleVariable",
                str,
            ],
        ]
    ] = None,
) -> None:
    """Initialise an (optionally pre-populated) side-effect tracker.

    Only a weak reference to ``output_graph`` is retained. Every other
    argument seeds the attribute of the same name. A falsy argument (``None``
    *or* an empty container) is replaced by a fresh empty container, so an
    empty container passed in is not shared with the caller.
    """
    super().__init__()
    self.output_graph_weakref = weakref.ref(output_graph)
    self.id_to_variable = id_to_variable if id_to_variable else {}
    self.store_attr_mutations = (
        store_attr_mutations if store_attr_mutations else {}
    )
    self.mutation_user_stacks = (
        mutation_user_stacks if mutation_user_stacks else {}
    )
    self.keepalive = keepalive if keepalive else []
    self.save_for_backward = save_for_backward if save_for_backward else []
    self.tensor_hooks = tensor_hooks if tensor_hooks else {}

    # Flipped to True by ``mutation()`` when a sourced dict is mutated;
    # MappingProxyVariable consults it to graph break on any mutated dict.
    self._has_existing_dict_mutation = False

    # Compiled Autograd final callbacks that must run at the end of the
    # Compiled Autograd backward graph. Only applicable when this graph comes
    # from Dynamo tracing inside Compiled Autograd.
    self.ca_final_callbacks_var: Optional[ListVariable] = None

    # VariableTrackers whose mutations are executed during tracing but are
    # not replayed in the generated code (``mutation()`` returns early for
    # them). Used for temporary mutations, e.g. torch.func.functional_call
    # swapping module parameters/buffers that are later restored.
    self.ignore_mutation_on_these_variables: set[VariableTracker] = set()
def ignore_mutations_on(self, var: VariableTracker) -> None:
    """Stop tracking mutations to ``var``.

    Mutations to this variable are still executed but not tracked:
    ``mutation()`` returns early for variables in this set. Typically used
    for temporary mutations that are later restored; undo with
    ``stop_ignoring_mutations_on``.
    """
    self.ignore_mutation_on_these_variables.add(var)
def stop_ignoring_mutations_on(self, var: VariableTracker) -> None:
    """Remove ``var`` from the skip-mutation set, restoring normal tracking.

    This is a no-op if ``var`` is not currently being ignored.
    """
    # set.discard already tolerates missing members, so no membership test
    # is needed before removing.
    self.ignore_mutation_on_these_variables.discard(var)
def _capture_user_stack(self, key: VariableTracker) -> None:
    """Append the current user stack to ``key``'s entry in ``mutation_user_stacks``.

    Nothing is recorded when ``config.side_effect_replay_policy`` is
    ``"silent"``. The stack comes from ``TracingContext.extract_stack()``.
    """
    if config.side_effect_replay_policy != "silent":
        stacks = self.mutation_user_stacks.setdefault(key, [])
        stacks.append(torch._guards.TracingContext.extract_stack())
def __eq__(self, other: object) -> bool:
    """Compare the tracked state of two ``SideEffects`` instances.

    Compares ``id_to_variable``, ``store_attr_mutations``,
    ``save_for_backward`` and ``tensor_hooks``. ``keepalive`` is deliberately
    excluded, and ``mutation_user_stacks`` is not compared either.

    For an operand that is not a ``SideEffects``, ``NotImplemented`` is
    returned so that Python falls back to its default comparison instead of
    failing an assertion.
    """
    if not isinstance(other, SideEffects):
        return NotImplemented
    # NB: do NOT test keepalive
    return (
        self.id_to_variable == other.id_to_variable
        and self.store_attr_mutations == other.store_attr_mutations
        and self.save_for_backward == other.save_for_backward
        and self.tensor_hooks == other.tensor_hooks
    )
def diff(self, other: "SideEffects") -> Optional[str]:
    """Describe the first difference between ``self`` and ``other``.

    The fields are checked in this order: ``id_to_variable``,
    ``store_attr_mutations``, ``save_for_backward``, ``tensor_hooks``. A
    short debugging message for the first mismatching field is returned, or
    ``None`` when they all match.
    """
    if self.id_to_variable != other.id_to_variable:
        mine, theirs = self.id_to_variable.keys(), other.id_to_variable.keys()
        if mine == theirs:
            # Feel free to augment this with more fancy diffing logic
            # if needed for debugging
            return "id_to_variable: unknown diff"
        return f"id_to_variable keys: {mine} != {theirs}"

    if self.store_attr_mutations != other.store_attr_mutations:
        mine, theirs = (
            self.store_attr_mutations.keys(),
            other.store_attr_mutations.keys(),
        )
        if mine == theirs:
            return "store_attr_mutations: unknown diff"
        return f"store_attr_mutations keys: {mine} != {theirs}"

    if self.save_for_backward != other.save_for_backward:
        return "save_for_backward"
    if self.tensor_hooks != other.tensor_hooks:
        return "tensor_hooks"
    return None
def clone(self) -> "SideEffects":
    """Return a shallow copy of this tracker.

    ``id_to_variable``, ``keepalive`` and each per-variable dict inside
    ``store_attr_mutations`` are copied. ``mutation_user_stacks``,
    ``save_for_backward`` and ``tensor_hooks`` are handed to the constructor
    as-is (not copied). The owning output graph must still be alive.
    """
    output_graph = self.output_graph_weakref()
    assert output_graph is not None
    copied_attr_mutations = {
        var: dict(attrs) for var, attrs in self.store_attr_mutations.items()
    }
    return self.__class__(
        output_graph=output_graph,
        id_to_variable=dict(self.id_to_variable),
        store_attr_mutations=copied_attr_mutations,
        mutation_user_stacks=self.mutation_user_stacks,
        keepalive=list(self.keepalive),
        save_for_backward=self.save_for_backward,
        tensor_hooks=self.tensor_hooks,
    )
def __contains__(self, item: Any) -> bool:
    """Return True if the Python object ``item`` is tracked (looked up by ``id``)."""
    key = id(item)
    return key in self.id_to_variable
def __getitem__(self, item: Any) -> VariableTracker:
    """Return the VariableTracker for the Python object ``item``.

    The lookup is by ``id(item)``; raises ``KeyError`` if ``item`` is not tracked.
    """
    key = id(item)
    return self.id_to_variable[key]
def should_allow_externally_visible_side_effects_in_subtracer(self) -> bool:
    """Whether the current tracer permits externally visible side effects.

    Reads ``unsafe_allow_externally_visible_side_effects`` from the current
    tracer of the owning output graph. Returns False if that graph is no
    longer alive (the weak reference yields ``None``).
    """
    output_graph = self.output_graph_weakref()
    if not output_graph:
        return False
    tracer = output_graph.current_tx.output.current_tracer
    return bool(tracer.unsafe_allow_externally_visible_side_effects)
def should_allow_side_effects_in_hop(self) -> bool:
    """Whether the current tracer allows side effects inside a HOP.

    Reads ``allow_side_effects_in_hop`` from the current tracer of the owning
    output graph. Returns False if that graph is no longer alive.
    """
    output_graph = self.output_graph_weakref()
    if not output_graph:
        return False
    tracer = output_graph.current_tx.output.current_tracer
    return bool(tracer.allow_side_effects_in_hop)
def is_reconstructing_generator(self) -> bool:
    """Whether the current tracer is reconstructing a generator.

    Reads ``is_reconstructing_generator`` from the current tracer of the
    owning output graph. Returns False if that graph is no longer alive.
    """
    output_graph = self.output_graph_weakref()
    if not output_graph:
        return False
    tracer = output_graph.current_tx.output.current_tracer
    return bool(tracer.is_reconstructing_generator)
def check_allowed_side_effect(self, item: VariableTracker) -> bool:
    """Validate that mutating ``item`` is permitted in the current context.

    Returns True immediately in any of these cases:
    * ``item`` is an autograd.Function context object.
    * The current tracer allows externally visible side effects.
    * The current tracer allows side effects inside a HOP.

    Otherwise:
    * Mutating anything while reconstructing a generator is reported through
      ``unimplemented``.
    * A mutation type that ``is_side_effect_safe`` rejects is also reported
      through ``unimplemented``.
    * In all remaining cases False is returned.
    """
    from torch._dynamo.variables.misc import AutogradFunctionContextVariable

    # People do things like self.dim = dim inside autograd.Function.
    # These are benign.
    if (
        isinstance(item, AutogradFunctionContextVariable)
        or self.should_allow_externally_visible_side_effects_in_subtracer()
        or self.should_allow_side_effects_in_hop()
    ):
        return True

    if self.is_reconstructing_generator():
        # This is missing the case where one mutates a tensor. See
        # test_generator.py::test_reconstruct_generator_tensor_mutation
        unimplemented(
            gb_type="Generator reconstruction with mutations",
            context=f"mutating object: {item}",
            explanation="Cannot reconstruct a generator with variable mutations. "
            "Dynamo needs to fully exhaust the generator, which may cause "
            "unintended variable modifications.",
            hints=[
                "Remove mutations from the generator.",
                *graph_break_hints.FUNDAMENTAL,
            ],
        )

    assert item.mutation_type is not None
    mutation_is_safe = is_side_effect_safe(item.mutation_type)
    if not mutation_is_safe:
        unimplemented(
            gb_type="HOP: Unsafe side effect",
            context=f"Attempted to mutate {item}",
            explanation="Mutating a variable from outside the scope of this HOP is not supported.",
            hints=[
                "If the HOP is activation checkpointing (torch.utils.checkpoint.checkpoint), this points to a "
                "side effect in forward method. Eager activation checkpointing replays that side-effect while "
                "recomputing the forward in the backward. If you are ok with side-effect not replayed in the "
                "backward, try setting `torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = True`",
            ],
        )
    return False
def store_attr(
    self, item: VariableTracker, name: str, value: VariableTracker
) -> None:
    """Record a pending write of ``value`` to attribute ``name`` of ``item``.

    ``item`` must use attribute-mutation tracking. The write is validated by
    ``check_allowed_side_effect`` and the current user stack is captured.
    A later write to the same attribute replaces the earlier pending value.
    """
    assert self.is_attribute_mutation(item)
    self.check_allowed_side_effect(item)
    pending_for_item = self.store_attr_mutations.setdefault(item, {})
    pending_for_item[name] = value
    # Capture user stack for this mutation
    self._capture_user_stack(item)
def load_attr(
    self,
    item: VariableTracker,
    name: str,
    deleted_ok: bool = False,
    check: bool = False,
) -> VariableTracker:
    """Return the pending value written to attribute ``name`` of ``item``.

    Raises ``KeyError`` if no such write is pending. With ``check=True``,
    ``item`` is first asserted to use attribute-mutation tracking. A pending
    deletion (``DeletedVariable``) is reported through ``unimplemented``
    unless ``deleted_ok`` is True; in that case it is returned as-is.
    """
    if check:
        assert self.is_attribute_mutation(item)
    pending_for_item = self.store_attr_mutations[item]
    value = pending_for_item[name]
    if not deleted_ok and isinstance(value, variables.DeletedVariable):
        unimplemented(
            gb_type="Attempted to read a deleted variable",
            context=f"item: {item}, name: {name}",
            explanation="",
            hints=[*graph_break_hints.USER_ERROR],
        )
    return value
def store_cell(self, cellvar: VariableTracker, value: VariableTracker) -> None:
    """Record a write of ``value`` into the cell tracked by ``cellvar``.

    Immutable (e.g. sourceless) cells are reported through ``unimplemented``.
    Otherwise the write is stored as a pending mutation of the cell's
    ``cell_contents`` attribute.
    """
    if cellvar.is_immutable():
        unimplemented(
            gb_type="Write to immutable cell",
            context=f"cellvar: {cellvar}, value: {value}",
            explanation="Dynamo doesn't support writing to immutable/sourceless cell variables.",
            hints=[*graph_break_hints.DIFFICULT],
        )
    assert isinstance(cellvar, variables.CellVariable)
    assert isinstance(value, variables.VariableTracker)
    self.store_attr(item=cellvar, name="cell_contents", value=value)
def load_cell(self, cellvar: VariableTracker) -> VariableTracker:
    """Return the current contents of the cell tracked by ``cellvar``.

    A pending write (recorded via ``store_cell``) takes precedence. Otherwise
    the contents captured when the cell was first tracked
    (``pre_existing_contents``) are returned. Reading a cell with neither is
    reported through ``unimplemented``.
    """
    assert isinstance(cellvar, variables.CellVariable)
    if self.has_pending_mutation_of_attr(cellvar, "cell_contents"):
        return self.load_attr(cellvar, "cell_contents", check=False)
    # Compare against None explicitly: the contents are a VariableTracker,
    # and a tracker object that happens to be falsy must not be mistaken for
    # an uninitialized cell.
    if cellvar.pre_existing_contents is not None:
        return cellvar.pre_existing_contents
    unimplemented(
        gb_type="Read uninitialized cell",
        context=str(cellvar),
        explanation="Attempted to read a cell variable that has not been populated yet.",
        hints=[*graph_break_hints.USER_ERROR],
    )
def load_global(self, gvar: VariableTracker, name: str) -> VariableTracker:
    """Return the pending write to global ``name`` recorded on ``gvar``."""
    assert isinstance(gvar, variables.VariableTracker)
    return self.load_attr(gvar, name, deleted_ok=False, check=False)
def store_global(
    self, gvar: VariableTracker, name: str, value: VariableTracker
) -> None:
    """Record a pending write of ``value`` to global ``name`` on ``gvar``."""
    for tracker in (gvar, value):
        assert isinstance(tracker, variables.VariableTracker)
    self.store_attr(gvar, name, value)
@staticmethod
def cls_supports_mutation_side_effects(cls: type) -> bool:
    """Whether instances of ``cls`` can have their mutations tracked.

    True only when the ``__getattribute__`` resolved statically on ``cls``
    (via ``inspect.getattr_static``) is one of the builtin implementations
    listed below. Any other implementation, user-defined or not, yields False.
    """
    builtin_getattributes = (
        object.__getattribute__,
        dict.__getattribute__,
        set.__getattribute__,
        frozenset.__getattribute__,
        int.__getattribute__,
        str.__getattribute__,
        list.__getattribute__,
        tuple.__getattribute__,
        BaseException.__getattribute__,
    )
    resolved = inspect.getattr_static(cls, "__getattribute__", None)
    return resolved in builtin_getattributes
def is_attribute_mutation(self, item: VariableTracker) -> bool:
    """True if ``item``'s mutation type is an ``AttributeMutation`` (new or existing)."""
    mutation_type = item.mutation_type
    return isinstance(mutation_type, AttributeMutation)
def has_pending_mutation(self, item: VariableTracker) -> bool:
    """True if ``item`` uses attribute tracking and has at least one pending write."""
    if not self.is_attribute_mutation(item):
        return False
    return bool(self.store_attr_mutations.get(item))
def has_pending_mutation_of_attr(self, item: VariableTracker, name: str) -> bool:
    """True if ``item`` uses attribute tracking and attribute ``name`` has a pending write."""
    if not self.is_attribute_mutation(item):
        return False
    pending_for_item = self.store_attr_mutations.get(item, ())
    return name in pending_for_item
def is_modified(self, item: VariableTracker) -> bool:
    """Whether ``item`` carries modifications that must be accounted for.

    The checks are applied in order:
    * Immutable items: never modified.
    * Objects created during tracing (``AttributeMutationNew`` /
      ``ValueMutationNew``): always modified.
    * User-defined objects: modified if they have pending attribute writes or
      their underlying dict/tuple VT changed.
    * Other attribute-tracked items: modified if they have pending writes.
    * Everything else: reports its mutation type's ``is_modified`` flag.
    """
    if item.is_immutable():
        return False
    if isinstance(item.mutation_type, (AttributeMutationNew, ValueMutationNew)):
        return True
    if isinstance(item, variables.UserDefinedObjectVariable):
        # Checks if the underlying dict or tuple vt has been modified
        has_attr_writes = item in self.store_attr_mutations
        return has_attr_writes or item.is_underlying_vt_modified(self)
    if self.is_attribute_mutation(item):
        return item in self.store_attr_mutations
    mutation_type = item.mutation_type
    assert mutation_type is not None
    return mutation_type.is_modified  # type: ignore[attr-defined]
def _track_obj(
    self,
    item: Any,
    variable: VariableTracker,
    mutation_type_cls: type = ValueMutationExisting,
) -> VariableTracker:
    """Start tracking an existing or new variable for mutation.

    Installs a fresh ``mutation_type_cls()`` on ``variable``, registers it
    under ``id(item)`` and keeps ``item`` alive. Raises ``AssertionError`` if
    ``item`` is already tracked.
    """
    key = id(item)
    if key in self.id_to_variable:
        previous = self.id_to_variable[key]
        raise AssertionError(
            f"{variable} is already tracked for mutation. This could be "
            "because you are not using VariableBuilder to construct "
            "the variable tracker. "
            f"Source of new object: {variable.source}. "
            f"Source of previously tracked object: {previous.source}."
        )
    variable.mutation_type = mutation_type_cls()
    self.id_to_variable[key] = variable
    self.keepalive.append(item)
    return variable

# Public alias of ``_track_obj``.
track_mutable = _track_obj
def track_object_existing(
    self,
    item: Any,
    variable: VariableTracker,
) -> VariableTracker:
    """Track a pre-existing ``item`` whose attributes may be mutated.

    This is ``_track_obj`` with an ``AttributeMutationExisting`` mutation type.
    """
    # TODO: Modify this API so that we preserve type info of
    # variable
    tracked = self._track_obj(
        item, variable, mutation_type_cls=AttributeMutationExisting
    )
    return tracked
def track_object_new(
    self,
    cls_source: Source | None,
    user_cls: Any,
    variable_cls: Any,
    options: dict[str, Any],
) -> VariableTracker:
    """Create and track a brand-new instance of ``user_cls``.

    The example object is chosen as follows:
    * ``FunctionCtx``: ``torch.autograd.Function()``, with warnings captured
      by ``warnings.catch_warnings(record=True)``.
    * Anything else: ``object_new(user_cls)``.

    The object is wrapped in ``variable_cls`` with an
    ``AttributeMutationNew(cls_source)`` mutation type and ``options``, then
    registered and kept alive.
    """
    if user_cls is not torch.autograd.function.FunctionCtx:
        example_obj = object_new(user_cls)
    else:
        with warnings.catch_warnings(record=True):
            example_obj = torch.autograd.Function()

    tracked = variable_cls(
        example_obj,
        mutation_type=AttributeMutationNew(cls_source),
        **options,
    )
    self.id_to_variable[id(example_obj)] = tracked
    self.keepalive.append(example_obj)
    return tracked
def get_variable_cls(self, user_cls: type) -> type:
    """Choose the VariableTracker subclass used to model instances of ``user_cls``.

    The rules below are tried in order and the first match wins; the default
    is ``UserDefinedObjectVariable``. The result is always asserted to be a
    ``UserDefinedObjectVariable`` subclass.
    """
    from torch.overrides import TorchFunctionMode

    from .variables.ctx_manager import GenericContextWrappingVariable
    from .variables.torch_function import TorchFunctionModeVariable
    from .variables.user_defined import is_forbidden_context_manager

    def select() -> type:
        # Order matters, e.g. a dict subclass that is also a context manager
        # is modelled as a context manager.
        if issubclass(
            user_cls, TorchFunctionMode
        ) and TorchFunctionModeVariable.is_supported_torch_function_mode(user_cls):
            return TorchFunctionModeVariable
        if (
            hasattr(user_cls, "__enter__")
            and hasattr(user_cls, "__exit__")
            and not is_forbidden_context_manager(user_cls)
        ):
            return GenericContextWrappingVariable
        if issubclass(user_cls, torch.nn.Module):
            return variables.UnspecializedNNModuleVariable
        if issubclass(user_cls, (dict, collections.OrderedDict)):
            return variables.UserDefinedDictVariable
        if issubclass(user_cls, (set, frozenset)):
            return variables.UserDefinedSetVariable
        if issubclass(user_cls, tuple):
            return variables.UserDefinedTupleVariable
        if issubclass(user_cls, list):
            return variables.UserDefinedListVariable
        if issubclass(user_cls, MutableMapping):
            return variables.MutableMappingVariable
        if is_frozen_dataclass(user_cls):
            return FrozenDataClassVariable
        if issubclass(user_cls, BaseException):
            return variables.UserDefinedExceptionObjectVariable
        if variables.InspectVariable.is_matching_class(user_cls):
            return variables.InspectVariable
        return variables.UserDefinedObjectVariable

    variable_cls = select()
    assert issubclass(variable_cls, variables.UserDefinedObjectVariable)
    return variable_cls
def get_example_value(
    self,
    base_cls_vt: VariableTracker,
    cls_vt: VariableTracker,
    init_args: list[VariableTracker],
) -> Any:
    """Allocate an example (uninitialized) instance of ``cls_vt``'s class.

    nn.Module subclasses are created with ``nn_module_new``. Everything else
    is created with ``base_cls.__new__(user_cls)``, where ``base_cls`` is the
    Python class behind ``base_cls_vt``: a BuiltinVariable's ``fn`` or a
    UserDefinedClassVariable's ``value``. Any other ``base_cls_vt`` raises
    ``RuntimeError``. ``init_args`` is currently unused.
    """
    user_cls = cls_vt.value  # type: ignore[attr-defined]
    if issubclass(user_cls, torch.nn.Module):
        # TODO(anijain2305) - Is it possible to remove this specialization?
        return nn_module_new(user_cls)

    if isinstance(base_cls_vt, variables.BuiltinVariable):
        base_cls = base_cls_vt.fn
    elif isinstance(base_cls_vt, variables.UserDefinedClassVariable):
        base_cls = base_cls_vt.value
    else:
        raise RuntimeError(f"Unexpected base_cls_vt {base_cls_vt}")

    new_method = base_cls.__new__
    assert variables.UserDefinedClassVariable.is_supported_new_method(new_method)
    # TODO(anijain2305) - Consider adding get_example_value method to
    # each VT to get an example value for all args. As we expand the
    # scope to other __new__ methods, we might need to call __new__ with
    # init_args (like functools.partial)
    # init_args = [arg.get_example_value() for arg in init_args]
    # obj = base_cls.__new__(user_cls, *init_args)
    return new_method(user_cls)
def track_new_user_defined_object(
    self,
    base_cls_vt: VariableTracker,
    cls_vt: VariableTracker,
    init_args: list[VariableTracker],
) -> VariableTracker:
    """Create a tracker for a new user-defined object and track its attribute mutations.

    The tracker is a UserDefinedObjectVariable or one of its subclasses
    (chosen by ``get_variable_cls``) wrapping an example value from
    ``get_example_value``. ``base_cls_vt`` and ``init_args`` are recorded so
    that reconstruction can emit, roughly,
    ``base_cls_vt.__new__(user_cls, *init_args)``.
    """
    cls_source = cls_vt.source
    user_cls = cls_vt.value  # type: ignore[attr-defined]
    variable_cls = self.get_variable_cls(user_cls)
    example_obj = self.get_example_value(base_cls_vt, cls_vt, init_args)
    tracked = variable_cls(
        example_obj,
        cls_source=cls_source,
        base_cls_vt=base_cls_vt,
        init_args=init_args,
        mutation_type=AttributeMutationNew(cls_source),
    )
    self.id_to_variable[id(example_obj)] = tracked
    self.keepalive.append(example_obj)
    return tracked
def track_cell_new(
    self,
) -> VariableTracker:
    """Create and track a CellVariable for a cell created during tracing.

    A fresh ``object()`` serves as the identity key in ``id_to_variable`` and
    is kept alive alongside the tracker.
    """
    placeholder = object()
    cell_vt = variables.CellVariable(mutation_type=AttributeMutationNew())
    self.id_to_variable[id(placeholder)] = cell_vt
    self.keepalive.append(placeholder)
    return cell_vt
def track_cell_existing(
    self, source: Optional[Source], cell: CellType, contents: VariableTracker
) -> VariableTracker:
    """Track a cell object that existed before tracing.

    ``contents`` becomes the tracker's ``pre_existing_contents``. Mutations
    are tracked (``AttributeMutationExisting``) only when ``source`` is
    given; a sourceless cell gets no mutation type.
    """
    # We don't support mutation to cell without source because we need
    # source to properly codegen the mutations.
    mutation_type = AttributeMutationExisting() if source is not None else None
    cell_vt = variables.CellVariable(
        mutation_type=mutation_type,
        pre_existing_contents=contents,
        source=source,
    )
    self.id_to_variable[id(cell)] = cell_vt
    self.keepalive.append(cell)
    return cell_vt
def track_global_existing(self, source: Source, item: Any) -> VariableTracker:
    """Track ``item`` as a pre-existing global namespace reached via ``source``.

    The returned NewGlobalVariable uses ``AttributeMutationExisting``, so
    writes to its attributes are recorded through ``store_attr``.
    """
    global_vt = variables.NewGlobalVariable(
        source=source,
        mutation_type=AttributeMutationExisting(),
    )
    self.id_to_variable[id(item)] = global_vt
    self.keepalive.append(item)
    return global_vt
def track_save_for_backward(
    self, ctx: VariableTracker, args: list[VariableTracker]
) -> None:
    """Record a ``ctx.save_for_backward(*args)`` call to be replayed in codegen."""
    assert isinstance(ctx, variables.AutogradFunctionContextVariable)
    entry = (ctx, args)
    self.save_for_backward.append(entry)
def track_runahead_tensor_and_symvar_side_effects(
    self, other: "SideEffects"
) -> None:
    """Adopt tensor / SymNode trackers from ``other`` that ``self`` does not track yet.

    Each object in ``other.keepalive`` is looked up in ``other.id_to_variable``.
    If its tracker is a TensorVariable or SymNodeVariable and ``self`` does
    not already track the object, it is added via ``track_object_existing``.
    """
    # In higher order ops we want to keep track of tensors seen in the
    # speculate_subgraph so that we don't lift them again as a new input in
    # other speculate_subgraph or in the root tracer.
    adoptable = (variables.TensorVariable, variables.SymNodeVariable)
    for obj in other.keepalive:
        obj_id = id(obj)
        vt = other.id_to_variable[obj_id]
        if obj_id in self.id_to_variable:
            continue
        if isinstance(vt, adoptable):
            self.track_object_existing(obj, vt)
def prune_dead_object_new(self, tx: "InstructionTranslatorBase") -> None:
    """Drop trackers for newly created objects (``AttributeMutationNew``) that are unreachable.

    Liveness roots:
    * The stack and symbolic locals of ``tx`` and of every parent translator.
    * The post-prune cells/freevars of non-root translators.
    * All pre-existing (non-``AttributeMutationNew``) tracked variables.
    * ``tx.output.backward_state``.
    * ``self.tensor_hooks``.

    Reachability also follows pending attribute writes in
    ``store_attr_mutations``. Unreachable new objects are removed from
    ``id_to_variable`` and ``store_attr_mutations``; every other entry is kept.

    NOTE(review): ``keepalive`` is not pruned here, so it may still hold
    objects whose ids were just removed from ``id_to_variable``.
    """
    # Avoid VT cycles from e.g., recursive function.
    visited: set[VariableTracker] = set()
    # AttributeMutationNew trackers found reachable from the roots.
    live_new_objects: set[VariableTracker] = set()

    def visit(var: VariableTracker) -> None:
        if var in visited:
            return
        visited.add(var)
        # A reachable newly-created object is live; remember it.
        if isinstance(var.mutation_type, AttributeMutationNew):
            live_new_objects.add(var)
        # It's possible that we have mutated the value of this variable
        # to be another one. The new value is in store_attr_mutations.
        # Also recurse through the new value to detect alive AttributeMutationNew.
        if var in self.store_attr_mutations:
            VariableTracker.visit(
                visit,  # noqa: F821
                self.store_attr_mutations[var],
            )

    def is_live(var: VariableTracker) -> bool:
        # Only newly-created objects can be pruned; everything else is kept.
        if isinstance(var.mutation_type, AttributeMutationNew):
            return var in live_new_objects
        return True

    pre_existing_vars = [
        var
        for var in self.id_to_variable.values()
        if not isinstance(var.mutation_type, AttributeMutationNew)
    ]

    # The only live side effects come from returns (tx.stack), any intermediates
    # during a graph break (tx.symbolic_locals), and mutation on pre-existing variables.
    # Recursively visit Variables and see if any of them have been mutated.
    init_live_vars = []
    # gather stack/symbolic_locals for all tx's up the chain
    cur_tx: Optional[InstructionTranslatorBase] = tx
    while cur_tx is not None:
        init_live_vars.extend([cur_tx.stack, cur_tx.symbolic_locals])
        if cur_tx.parent is not None:
            # for non-root tx'es, also keep the cells/freevars alive so they get codegen'd properly
            # TODO see if we could prune dead cells - cell pruning information needs to be forwarded
            # to the resume function creation as well.
            assert cur_tx.post_prune_cell_and_freevars is not None
            init_live_vars.append(cur_tx.post_prune_cell_and_freevars)
        cur_tx = cur_tx.parent
    VariableTracker.visit(
        visit,
        # TODO track from all possible sources.
        init_live_vars
        + [
            pre_existing_vars,
            tx.output.backward_state,
            self.tensor_hooks,
        ],
    )

    # Manually release the self-referential function, which indirectly
    # captures certain `VariableTracker` and affects parts of PT test/logic
    # that are sensitive to when certain objects get released.
    del visit

    # NB: cell variable handling is tricky.
    # cell variables must stay alive if any NestedUserFunctionVariable
    # are live. "visit"-ing the NestedUserFunctionVariable visits
    # the .closures field, from which we will see if we need to keep
    # any mutations to cell variables alive.

    self.id_to_variable = {
        k: v for k, v in self.id_to_variable.items() if is_live(v)
    }
    self.store_attr_mutations = {
        k: v for k, v in self.store_attr_mutations.items() if is_live(k)
    }
def mutation(self, var: VariableTracker) -> None:
    """Record that ``var`` is being mutated in place.

    Variables registered via ``ignore_mutations_on`` are skipped entirely.
    Otherwise:
    * The mutation is validated with ``check_allowed_side_effect``.
    * The user stack is captured.
    * ``ValueMutationExisting`` variables are flagged as modified.
    * A mutation of a sourced ``ConstDictVariable`` (other than a
      ``SetVariable``) sets the existing-dict-mutation flag.
    """
    if var in self.ignore_mutation_on_these_variables:
        return
    self.check_allowed_side_effect(var)
    # Capture user stack for this mutation
    self._capture_user_stack(var)

    mutation_type = var.mutation_type
    if isinstance(mutation_type, ValueMutationExisting):
        mutation_type.is_modified = True

    if not var.source:
        return
    if isinstance(var, variables.ConstDictVariable) and not isinstance(
        var, variables.SetVariable
    ):
        self._has_existing_dict_mutation = True
def has_existing_dict_mutation(self) -> bool:
    """Whether ``mutation()`` has seen an in-place change to a sourced dict."""
    flag = self._has_existing_dict_mutation
    return flag
def _get_modified_vars(self) -> list[VariableTracker]:
    """Tracked variables, in ``id_to_variable`` order, for which ``is_modified`` is true."""
    return list(filter(self.is_modified, self.id_to_variable.values()))
def codegen_save_tempvars(self, cg: PyCodegen) -> None:
    """Emit bytecode that materializes newly created tracked objects, then replays save_for_backward.

    Pre-existing modified variables must already have a ``source`` and are
    skipped. Each modified ``AttributeMutationNew`` variable is handled as
    follows:

    * Cells without a ``local_name``: built with ``utils.make_cell()`` and
      cached in a temp local, which becomes their source.
    * Cells with a ``local_name`` and no source: given a ``LocalCellSource``.
    * Tensors: left alone, except ``TensorWithTFOverrideVariable``, which is
      reconstructed and cached in a temp local.
    * ``AutogradFunctionContextVariable``: reported through ``unimplemented``.
    * Anything else: built with ``base_cls_vt.__new__`` (user-defined
      objects) or ``utils.object_new``, called with ``cls_source`` and
      ``init_args``, then cached in a temp local.

    Finally, a ``ctx.save_for_backward(*args)`` call (result discarded) is
    emitted for every entry in ``self.save_for_backward``.
    """
    # We must codegen modified VT to their source by default, so that
    # mutation and aliasing are properly accounted for.
    #
    # Since newly constructed objects don't have a source, we manually
    # codegen their construction and store them to a newly assigned local
    # source. Note that `ValueMutationNew` isn't tracked by SideEffects.
    for var in self._get_modified_vars():
        if not isinstance(var.mutation_type, AttributeMutationNew):
            # Pre-existing objects are reachable through their source already.
            assert var.source is not None
            continue

        if isinstance(var, variables.CellVariable):
            # Cells created in the root frame are created either by
            # `MAKE_CELL` or by them being in `co_cellvars`, so we only emit
            # `make_cell` for the non-root-frame cells here.
            # TODO generalize this so we never need to call `make_cell`.
            if var.local_name is None:
                cg.add_push_null(
                    lambda: cg.load_import_from(utils.__name__, "make_cell")
                )
                cg.extend_output(create_call_function(0, False))
                cg.add_cache(var)
                var.source = TempLocalSource(cg.tempvars[var])  # type: ignore[attr-defined]
            elif var.source is None:
                var.source = LocalCellSource(var.local_name)
        elif var.is_tensor():
            # NOTE: for historical reasons we never assigned local sources
            # to newly constructed tensor object, so we keep it that way.
            # They are always loaded from output of the fx graph, so one can
            # think of it as having a "OutputGraphSource" for codegen
            # purposes.
            #
            # However, tensor subclass objects are different, because the
            # reconstruction logic in `PyCodegen` loads the data tensor from
            # graph output and then calls `as_subclass`, meaning we must
            # assign a source to it to ensure we only reconstruct one
            # subclass instance.
            if isinstance(
                var, variables.torch_function.TensorWithTFOverrideVariable
            ):
                # Don't codegen from temp source assigned from the 1st pass.
                cg(var, allow_cache=False)
                cg.add_cache(var)
                # `add_cache` generates STORE and consumes TOS, but we never
                # cleared it. TODO move this call into `add_cache`
                cg.clear_tos()
                var.source = TempLocalSource(cg.tempvars[var])
        elif isinstance(var, variables.AutogradFunctionContextVariable):
            unimplemented(
                gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region",
                context="",
                explanation="We cannot reconstruct a torch.autograd.Function's context object.",
                hints=[],
            )
        else:
            # Reconstruct the bytecode for
            # base_cls.__new__(user_cls, *args)
            if isinstance(var, variables.UserDefinedObjectVariable):

                def load_new_method() -> None:
                    # pyrefly: ignore [missing-attribute]
                    assert var.base_cls_vt is not None
                    cg(var.base_cls_vt)  # type: ignore[attr-defined]
                    cg.extend_output([cg.create_load_attr("__new__")])

                cg.add_push_null(load_new_method)
            else:
                cg.add_push_null(
                    lambda: cg.load_import_from(utils.__name__, "object_new")
                )
            # First positional argument to __new__ / object_new: the class.
            assert var.mutation_type.cls_source is not None
            cg(var.mutation_type.cls_source)

            # Generate the args to the __new__ method
            for arg in var.init_args:  # type: ignore[attr-defined]
                cg(arg)

            # Call the __new__ method
            cg.extend_output(create_call_function(1 + len(var.init_args), False))  # type: ignore[attr-defined]

            # Store the new object into a temp local and make that its source.
            cg.add_cache(var)
            var.source = TempLocalSource(cg.tempvars[var])

    # Replay ctx.save_for_backward(*args); the call's return value is popped.
    for ctx, args in self.save_for_backward:
        cg(ctx.source)
        cg.load_method("save_for_backward")
        for arg in args:
            cg(arg)
        cg.extend_output(
            [
                *create_call_method(len(args)),
                create_instruction("POP_TOP"),
            ]
        )
def register_hook(
    self,
    tensor: "variables.TensorVariable",
    hook: VariableTracker,
    handle: "variables.RemovableHandleVariable",
    name: str,
) -> None:
    """Record a tensor hook registration so ``codegen_hooks`` can replay it.

    The tuple ``(tensor, hook, handle, name)`` is stored in
    ``self.tensor_hooks`` under a fresh integer key, and that key is written
    to ``handle.idx``.

    Args:
        tensor: the tensor VT the hook is attached to.
        hook: the hook's VT.
        handle: mutable ``RemovableHandleVariable`` handed back to user code;
            it must not already carry an index.
        name: name of the ``torch.Tensor`` method used to register the hook.
    """
    assert tensor.is_tensor()
    assert isinstance(hook, variables.VariableTracker)
    assert (
        isinstance(handle, variables.RemovableHandleVariable)
        and handle.is_mutable()
    )
    assert hasattr(torch.Tensor, name)

    # Begin probing at the current size; remove_hook() can leave holes, so
    # that size may already be a live key and we advance past it.
    slot = len(self.tensor_hooks)
    while slot in self.tensor_hooks:
        slot += 1
    self.tensor_hooks[slot] = (tensor, hook, handle, name)

    assert not handle.idx
    handle.idx = slot
def remove_hook(self, idx: int) -> None:
    """Forget the hook stored under ``idx``; raises ``KeyError`` if absent."""
    self.tensor_hooks.pop(idx)
def codegen_hooks(self, cg: PyCodegen) -> None:
    """Emit ``tensor.<name>(hook)`` for every hook recorded via ``register_hook``.

    For each entry of ``self.tensor_hooks``, the tensor is loaded, the
    ``torch.Tensor`` registration method named ``name`` is fetched from it,
    and that method is called with the reconstructed hook. The returned
    handle is then stored via ``cg.add_cache(handle)``. Only tensors with a
    source are supported here (asserted).
    """
    for tensor, hook, handle, name in self.tensor_hooks.values():
        # Note: [On tensor.register_hook]
        #
        # register_hook on a tensor, AKA backward hooks, have slightly nuanced differences in how they are implemented
        # when it comes to hooks on objects with sources (inputs, params) vs objects without sources (intermediaries).
        #
        # For tensors with a source, we bypass direct inclusion of register_hook calls in the graph.
        # Instead, these are tracked and stashed as a global variable, enabling their association with tensors in
        # the residuals. During dynamo's frame creation, these hooks are invoked seamlessly on known reconstructible/fetch-able
        # tensors. Because a source indicates knowledge of this object outside the torch compile region, and
        # because we are running residuals firmly before .backward() can be run, it is sound to invoke
        # `register_hook` on a known tensor.
        #
        # For tensors without a source, we support a limited subset of hooks. Global functions only, and
        # compiled_autograd must be enabled or we will graph break.
        #
        # Handling the Handle: When a user retains the register_hook result in a handle, we intercept the
        # STORE_FAST operation to record the user-designated local variable name. This ensures the reconstructed
        # bytecode retains this name. If no handle is defined, we simply pop the generated value to keep the
        # stack intact.
        #
        # Dynamo Tensor Hooks Workflow:
        # - Functions passed to register_hook are lifted globally.
        # - For tensors with sources:
        #   - In the "side_effects" phase of codegen, we iterate over tensors with hooks to:
        #     - Generate the tensor.
        #     - Issue a register_hook call on the tensor, linking to the globally stored function.
        #     - Incorporate a handle if one was established in the eager phase.
        # - For tensors without sources:
        #   - We don't generate any instructions for registering a hook.
        #   - Handles from intermediary hooks are NYI.
        #   - We produce a call function that utilizes the trace_wrapped higher order op, closing over it.
        #   - We then manually insert the call function above into the graph.
        # - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST.
        assert tensor.source, "Hooks on non input tensors NYI - should not get here"

        def _load_registration_method() -> None:
            # Push the bound method `tensor.<name>`.
            cg(tensor)
            cg.extend_output([cg.create_load_attr(name)])

        cg.add_push_null(_load_registration_method)
        cg(hook)
        cg.extend_output(create_call_function(1, False))
        # Caching the handle ties RemovableHandleVariable().reconstruct() to
        # the value returned by the registration call; the cache STORE
        # consumes that value from the top of the stack.
        cg.add_cache(handle)
def get_ca_final_callbacks_var(self) -> "variables.ListVariable":
    """Return ``self.ca_final_callbacks_var``, creating an empty list VT on first use.

    The list is created with a ``ValueMutationNew`` mutation type and is
    cached on ``self``, so every later call returns the same object.
    """
    from .variables.base import ValueMutationNew

    cached = self.ca_final_callbacks_var
    if cached is not None:
        return cached
    cached = variables.ListVariable([], mutation_type=ValueMutationNew())
    self.ca_final_callbacks_var = cached
    return cached
def _format_side_effect_message(self, var: VariableTracker) -> str:
    """Format a side effect log message with user stack.

    The message names the mutated object's type and describes its source.
    It then lists every distinct user stack recorded for ``var`` in
    ``self.mutation_user_stacks``, each prefixed with the (collapsed) stack
    above Dynamo and separated by ``********`` lines. If no stacks were
    recorded, a short note saying so is appended instead. Must not be called
    when ``config.side_effect_replay_policy`` is ``"silent"`` (asserted).
    """
    assert config.side_effect_replay_policy != "silent"
    header = f"Mutating object of type {var.python_type_name()}"

    src = var.source
    if src is None:
        source_info = " (no source)"
    elif isinstance(src, TempLocalSource):
        source_info = " (source: created in torch.compile region)"
    elif isinstance(var, variables.CellVariable) and var.local_name is not None:
        source_info = f" (source: {var.local_name})"
    elif isinstance(var, variables.torch_function.TorchFunctionModeStackVariable):
        source_info = " (source: torch function mode stack mutation)"
    else:
        # NOTE: NotImplementedError from var.source.name is a bug and must be fixed!
        source_info = f" (source name: {src.name})"

    recorded_stacks = self.mutation_user_stacks.get(var, [])
    if not recorded_stacks:
        return f"{header}{source_info} (unable to find user stacks for mutations)"

    prefix_frames = collapse_resume_frames(get_stack_above_dynamo())
    seen_keys: set[tuple[object, ...]] = set()
    rendered_stacks: list[str] = []
    for frames in recorded_stacks:
        # Hashable fingerprint of the stack. Column/end positions (present on
        # 3.11+ frame summaries) separate several mutations on one line.
        key = tuple(
            (
                frame.filename,
                frame.lineno,
                frame.name,
                frame.line,
                getattr(frame, "colno", None),
                getattr(frame, "end_lineno", None),
                getattr(frame, "end_colno", None),
            )
            for frame in frames
        )
        if key in seen_keys:
            continue
        seen_keys.add(key)
        full_stack = collapse_resume_frames(prefix_frames + frames)
        rendered_stacks.append("".join(traceback.format_list(full_stack)))

    joined = "\n********\n\n".join(rendered_stacks)
    return f"{header}{source_info}\n\n{textwrap.indent(joined, ' ')}"
def codegen_update_mutated(
    self, cg: PyCodegen, log_side_effects: bool = False
) -> None:
    """Emit bytecode that replays recorded mutations onto their target objects.

    Works in two phases for each variable from ``self._get_modified_vars()``.
    First, values and targets are pushed onto the stack immediately. Second,
    the instructions that actually perform each mutation ("suffixes") are
    collected and appended at the very end in reverse order. This way every
    value is computed before any target is mutated.

    Per-type strategies, in branch order:

    * ``ListVariable``: ``old[:] = new``.
    * ``DequeVariable``: ``old.clear()`` then ``old.extend(new)``. A deque
      with a ``maxlen`` is rejected via ``unimplemented``.
    * ``ConstDictVariable``: only when it has new items, ``old.update(new)``,
      preceded by ``old.clear()`` if ``should_reconstruct_all``.
    * The torch-function mode stack: the current stack is saved to a new
      local, then the stack is reset.
    * Cells with a ``local_name``: ``STORE_DEREF``.
    * Attribute mutations: ``STORE_GLOBAL``, ``DELETE_ATTR``,
      ``object_setattr_ignore_descriptor``, ``object.__setattr__`` or
      ``STORE_ATTR``. User-defined dict/list subclasses first get their
      underlying container updated via bytecode templates.
    * List iterators: advanced with ``iter_next`` ``var.index`` times.
    * ``RandomVariable``: ``setstate`` with the traced state.

    When ``config.replay_side_effects`` is false, only variables whose source
    is a ``TempLocalSource`` are updated.

    Args:
        cg: the ``PyCodegen`` that receives the emitted instructions.
        log_side_effects: when true (and ``config.side_effect_replay_policy``
            is not ``"silent"``), each replayed side effect is logged to
            ``side_effects_log``. All messages are also emitted together as a
            ``dynamo_side_effects`` structured-trace artifact.

    Raises:
        AssertionError: if a modified variable's type is not handled here.
    """
    side_effect_messages: list[str] = []

    # NOTE: should only be called once per VT - only if a side effect actually gets codegen'd!
    def _maybe_log_side_effect(var: VariableTracker) -> None:
        if config.side_effect_replay_policy != "silent" and log_side_effects:
            msg = self._format_side_effect_message(var)
            side_effect_messages.append(msg)
            # Log individual side effects for granular debugging
            side_effects_log.debug(msg)

    # Deferred mutation instructions; emitted in reverse after the loop.
    suffixes = []
    for var in self._get_modified_vars():
        # When replay_side_effects=False, only update variables with TempLocalSource
        if not config.replay_side_effects and not isinstance(
            var.source, TempLocalSource
        ):
            continue
        if isinstance(var, variables.ListVariable):
            # old[:] = new
            cg(var, allow_cache=False)  # Don't codegen via source
            cg(var.source)  # type: ignore[attr-defined]
            cg.extend_output(
                [
                    cg.create_load_const(None),
                    cg.create_load_const(None),
                    create_instruction("BUILD_SLICE", arg=2),
                ]
            )
            suffixes.append([create_instruction("STORE_SUBSCR")])
            _maybe_log_side_effect(var)
        elif isinstance(var, variables.lists.DequeVariable):
            # For limited maxlen, the order of operations matter for side
            # effect, but we currently don't track the order, so no support.
            if not var.maxlen.is_constant_none():
                unimplemented(
                    gb_type="Side effect on existing deque with limited maxlen",
                    context="",
                    explanation="This is not supported.",
                    hints=[
                        "Don't use a deque with `maxlen` specified.",
                    ],
                )
            # old.extend(new), this runs last
            cg(var.source)
            cg.load_method("extend")
            cg(var, allow_cache=False)  # Don't codegen via source
            suffixes.append(
                [
                    *create_call_method(1),
                    create_instruction("POP_TOP"),
                ]
            )
            # old.clear(), this runs first
            cg(var.source)
            cg.load_method("clear")
            suffixes.append(
                [
                    *create_call_method(0),
                    create_instruction("POP_TOP"),
                ]
            )
            _maybe_log_side_effect(var)
        elif isinstance(var, variables.ConstDictVariable):
            # Reconstruct works as follow:
            # (1) Skip codegen if there are no new items
            # (2) codegen(...) each pair of key/value
            # (3) create a new dictionary with the pairs of key/values above
            # (4) clear the original dictionary
            #   + only if a key was removed from the input dict
            # (5) update the original dictionary with the dict created in (2)
            if var.has_new_items():
                cg(var.source)  # type: ignore[attr-defined]
                cg.load_method("update")
                cg(var, allow_cache=False)  # Don't codegen via source
                if var.should_reconstruct_all:
                    cg(var.source)  # type: ignore[attr-defined]
                    cg.load_method("clear")
                suffixes.append(
                    [
                        *create_call_method(1),  # update
                        create_instruction("POP_TOP"),
                    ]
                )
                if var.should_reconstruct_all:
                    # clear will appear before "update" as the suffixes are
                    # applied in reverse order.
                    suffixes.append(
                        [
                            *create_call_method(0),  # clear
                            create_instruction("POP_TOP"),
                        ]
                    )
                _maybe_log_side_effect(var)
        elif isinstance(
            var, variables.torch_function.TorchFunctionModeStackVariable
        ):
            # Needed in the finally block for stack restoration
            cg.add_push_null(
                lambda: cg.load_import_from(
                    utils.__name__, "get_torch_function_mode_stack"
                )
            )
            cg.call_function(0, False)
            # Save the pre-existing stack into a freshly registered local.
            name = variables.torch_function.get_prev_stack_var_name()
            cg.code_options["co_varnames"] += (name,)
            cg.append_output(create_instruction("STORE_FAST", argval=name))
            # set_torch_function_mode_stack([*symbolic_stack]); result popped.
            cg.add_push_null(
                lambda: cg.load_import_from(
                    utils.__name__, "set_torch_function_mode_stack"
                )
            )
            cg.foreach(var.symbolic_stack)
            cg.append_output(
                create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
            )
            cg.call_function(1, False)
            cg.append_output(create_instruction("POP_TOP"))
            _maybe_log_side_effect(var)
        elif isinstance(var, variables.CellVariable) and var.local_name is not None:
            # Emit more readable and performant bytecode.
            # TODO generalize this for cells created during inlining.
            if var in self.store_attr_mutations:
                contents_var = self.load_cell(var)
                cg(contents_var)
                suffixes.append([cg.create_store_deref(var.local_name)])
                _maybe_log_side_effect(var)
        elif self.is_attribute_mutation(var):
            if isinstance(
                var,
                variables.UserDefinedDictVariable,
            ) and self.is_modified(var._dict_vt):
                # Do dict related update manually here. The store_attr
                # mutations will be applied later.
                # Allocate a fresh local for each local of the template.
                varname_map = {}
                for name in _manual_dict_setitem.__code__.co_varnames:
                    varname_map[name] = cg.tx.output.new_var()
                # Position of the dict base class in the MRO, so the template
                # can use that base's methods rather than overrides.
                try:
                    mro_index = type(var.value).__mro__.index(
                        collections.OrderedDict
                    )
                except ValueError:
                    mro_index = type(var.value).__mro__.index(dict)
                cg.extend_output(
                    [
                        create_instruction("LOAD_CONST", argval=mro_index),
                        create_instruction(
                            "STORE_FAST", argval=varname_map["mro_index"]
                        ),
                    ]
                )
                cg(var.source)  # type: ignore[attr-defined]
                cg.extend_output(
                    [
                        create_instruction(
                            "STORE_FAST", argval=varname_map["dict_to"]
                        )
                    ]
                )
                cg(var._dict_vt, allow_cache=False)  # Don't codegen via source
                cg.extend_output(
                    [
                        create_instruction(
                            "STORE_FAST", argval=varname_map["dict_from"]
                        )
                    ]
                )
                dict_update_insts = bytecode_from_template(
                    _manual_dict_setitem, varname_map=varname_map
                )
                suffixes.append(
                    [
                        *dict_update_insts,
                        create_instruction("POP_TOP"),
                    ]
                )
                _maybe_log_side_effect(var._dict_vt)
            elif isinstance(
                var,
                variables.UserDefinedListVariable,
            ) and self.is_modified(var._list_vt):
                # Update the list to the updated items. Be careful in
                # calling the list methods and not the overridden methods.
                varname_map = {}
                for name in _manual_list_update.__code__.co_varnames:
                    varname_map[name] = cg.tx.output.new_var()
                cg(var.source)  # type: ignore[attr-defined]
                cg.extend_output(
                    [
                        create_instruction(
                            "STORE_FAST", argval=varname_map["list_to"]
                        )
                    ]
                )
                cg(var._list_vt, allow_cache=False)  # Don't codegen via source
                cg.extend_output(
                    [
                        create_instruction(
                            "STORE_FAST", argval=varname_map["list_from"]
                        )
                    ]
                )
                list_update_insts = bytecode_from_template(
                    _manual_list_update, varname_map=varname_map
                )
                suffixes.append(
                    [
                        *list_update_insts,
                        create_instruction("POP_TOP"),
                    ]
                )
                _maybe_log_side_effect(var._list_vt)
            # Applying mutations involves two steps: 1) Push all
            # reconstructed objects onto the stack. 2) Call STORE_ATTR to
            # apply the mutations.
            #
            # Dynamo must ensure that mutations are applied in the same
            # order as in the original program. Therefore, two reverse
            # operations occur below.
            #
            # The first reverse operation concerns `suffixes`. We apply
            # suffixes in reverse order due to the way Python handles the
            # stack. In Step 1, we push all reconstructed objects onto the
            # stack, but the item at the top of the stack refers to the last
            # attribute in the mutation order. If not fixed, this will apply
            # the mutations of attributes in the reverse order. To account
            # for this reversal, we iterate through the mutable attributes
            # in reverse order.
            # (That second reversal is the `reversed(...)` below.)
            side_effect_occurred = False
            for name, value in reversed(
                self.store_attr_mutations.get(var, {}).items()
            ):
                if isinstance(var, variables.NewGlobalVariable):
                    cg.tx.output.update_co_names(name)
                    cg(value)
                    assert isinstance(var.source, GlobalSource)  # type: ignore[attr-defined]
                    suffixes.append(
                        [create_instruction("STORE_GLOBAL", argval=name)]
                    )
                    side_effect_occurred = True
                elif isinstance(value, variables.DeletedVariable):
                    # Only emit `del obj.name` when the pre-existing object
                    # actually has the attribute; otherwise nothing is emitted.
                    if isinstance(
                        var.mutation_type, AttributeMutationExisting
                    ) and hasattr(getattr(var, "value", None), name):
                        cg.tx.output.update_co_names(name)
                        cg(var.source)
                        suffixes.append(
                            [create_instruction("DELETE_ATTR", argval=name)]
                        )
                        side_effect_occurred = True
                elif isinstance(
                    var, variables.UserDefinedObjectVariable
                ) and var.should_skip_descriptor_setter(name):
                    # object_setattr_ignore_descriptor(obj, name, value)
                    cg.add_push_null(
                        lambda: cg.load_import_from(
                            utils.__name__, "object_setattr_ignore_descriptor"
                        )
                    )
                    cg(var.source)  # type: ignore[attr-defined]
                    cg(variables.ConstantVariable(name))
                    cg(value)
                    suffixes.append(
                        [
                            *create_call_function(3, False),
                            create_instruction("POP_TOP"),
                        ]
                    )
                    side_effect_occurred = True
                elif (
                    isinstance(var, variables.UserDefinedObjectVariable)
                    and var.needs_slow_setattr()
                ):
                    # __setattr__ is defined on this object, so call object.__setattr__ directly
                    cg.load_import_from("builtins", "object")
                    cg.load_method("__setattr__")
                    cg(var.source)  # type: ignore[attr-defined]
                    cg(variables.ConstantVariable(name))
                    cg(value)
                    suffixes.append(
                        [*create_call_method(3), create_instruction("POP_TOP")]
                    )
                    side_effect_occurred = True
                else:
                    # Plain `obj.name = value`.
                    cg.tx.output.update_co_names(name)
                    cg(value)
                    cg(var)
                    suffixes.append([create_instruction("STORE_ATTR", argval=name)])
                    side_effect_occurred = True
            if side_effect_occurred:
                _maybe_log_side_effect(var)
        elif isinstance(var, variables.ListIteratorVariable):
            # Advance the source iterator by the number of items consumed
            # during tracing; these calls are emitted immediately, not deferred.
            for _ in range(var.index):
                cg.add_push_null(
                    lambda: cg.load_import_from(utils.__name__, "iter_next")
                )
                cg(var.source)  # type: ignore[attr-defined]
                cg.call_function(1, False)
                cg.pop_top()
            _maybe_log_side_effect(var)
        elif isinstance(var, variables.RandomVariable):
            # set correct random seed state
            def gen_fn() -> None:
                cg(var.source)  # type: ignore[attr-defined]
                cg.load_attr("setstate")

            cg.add_push_null(gen_fn)
            cg(var.wrap_state(var.random.getstate()))
            suffixes.append(
                [
                    *create_call_function(1, False),  # setstate
                    create_instruction("POP_TOP"),
                ]
            )
            _maybe_log_side_effect(var)
        else:
            raise AssertionError(type(var))
    # do all the actual mutations at the very end to handle dependencies
    for suffix in reversed(suffixes):
        cg.extend_output(suffix)
    # Send batched structured trace for all side effects in this compilation
    if log_side_effects and side_effect_messages:
        combined_msg = "\n\n========================================\n\n".join(
            side_effect_messages
        )
        torch._logging.trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "dynamo_side_effects",
                "encoding": "string",
            },
            payload_fn=lambda: combined_msg,
        )
def is_empty(self) -> bool:
    """Return True when there is nothing for side-effect codegen to replay.

    That is: no tracked variable is modified, no tensor hooks were
    registered, and no ``save_for_backward`` calls were recorded.
    """
    # ``tensor_hooks`` used to be tested twice here; one check is enough.
    # The original evaluation order is otherwise preserved.
    return not (
        any(map(self.is_modified, self.id_to_variable.values()))
        or self.tensor_hooks
        or self.save_for_backward
    )
def clear(self) -> None:
    """Empty ``self.keepalive`` and ``self.id_to_variable``.

    Other bookkeeping on ``self`` (for example ``tensor_hooks``) is not
    touched here.
    """
    for container in (self.keepalive, self.id_to_variable):
        container.clear()
@contextlib.contextmanager
def allow_side_effects_in_hop(
    tx: "InstructionTranslatorBase",
) -> Generator[None, None, None]:
    """Context manager to temporarily allow side effects with extra outputs.

    This is used for special cases (like FSDP functions) that need to perform
    side effects even when the general policy is to disallow them.

    Sets ``allow_side_effects_in_hop = True`` on the tracer that is current
    on entry. On exit, that same tracer's previous value is restored, even if
    the body raises or replaces ``tx.output.current_tracer``.
    """
    # Capture the tracer once so ``finally`` restores the object that was
    # actually modified, not whichever tracer happens to be current on exit.
    tracer = tx.output.current_tracer
    orig_val = tracer.allow_side_effects_in_hop
    try:
        tracer.allow_side_effects_in_hop = True
        yield
    finally:
        tracer.allow_side_effects_in_hop = orig_val
@contextlib.contextmanager
def allow_externally_visible_side_effects_in_subtracer(
    tx: "InstructionTranslatorBase",
) -> Generator[None, None, None]:
    """Temporarily permit externally visible side effects on the current tracer.

    On entry, both ``unsafe_allow_externally_visible_side_effects`` and
    ``traced_with_externally_visible_side_effects`` are set to True on the
    tracer that is current at that moment. On exit, only the ``unsafe_allow``
    flag is restored, on that same tracer. ``traced_with...`` is left True,
    so afterwards the tracer still records that such side effects were
    permitted.
    """
    # Capture the tracer once so ``finally`` restores the object that was
    # actually modified, even if ``current_tracer`` changes inside the block.
    tracer = tx.output.current_tracer
    orig_val = tracer.unsafe_allow_externally_visible_side_effects
    try:
        tracer.unsafe_allow_externally_visible_side_effects = True
        tracer.traced_with_externally_visible_side_effects = True
        yield
    finally:
        tracer.unsafe_allow_externally_visible_side_effects = orig_val
@contextlib.contextmanager
def disallow_side_effects_in_generator(
    tx: "InstructionTranslatorBase",
) -> Generator[None, None, None]:
    """Mark the current tracer as reconstructing a generator for the block's duration.

    Sets ``is_reconstructing_generator = True`` on the tracer that is current
    on entry and restores that same tracer's previous value on exit. Judging
    by the name, code elsewhere uses this flag to reject side effects; that
    check is not in this function.
    """
    # Capture the tracer once so ``finally`` restores the object that was
    # actually modified, even if ``current_tracer`` changes inside the block.
    tracer = tx.output.current_tracer
    orig_val = tracer.is_reconstructing_generator
    try:
        tracer.is_reconstructing_generator = True
        yield
    finally:
        tracer.is_reconstructing_generator = orig_val
|