side_effects.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374
  1. """
  2. Side effect tracking and management for TorchDynamo's compilation system.
  3. This module provides infrastructure for tracking and managing side effects that occur
  4. during symbolic execution, including:
  5. - Tracking mutations to objects, attributes, and variables
  6. - Managing context changes (cell variables, global namespace modifications)
  7. - Handling aliasing and object identity preservation
  8. - Managing stack frame state and local variable changes
  9. - Tracking function calls with side effects
  10. Key classes:
  11. - SideEffects: Main container for tracking all side effects during execution
  12. - MutableSideEffects: Specialization for mutable object tracking
  13. - AttributeMutation/ValueMutation: Track specific types of mutations
  14. - Various specialized side effect classes for different scenarios
  15. The side effect system ensures that mutations performed during symbolic execution
  16. are properly replayed during runtime, maintaining the correctness of compiled code
  17. while enabling optimizations where safe.
  18. """
  19. import collections
  20. import contextlib
  21. import inspect
  22. import textwrap
  23. import traceback
  24. import warnings
  25. import weakref
  26. from collections.abc import Generator, MutableMapping
  27. from types import CellType
  28. from typing import Any, Optional, TYPE_CHECKING
  29. import torch
  30. import torch.nn
  31. from torch._dynamo.variables.misc import AutogradFunctionContextVariable
  32. from . import config, graph_break_hints, utils, variables
  33. from .bytecode_transformation import (
  34. bytecode_from_template,
  35. create_call_function,
  36. create_call_method,
  37. create_instruction,
  38. )
  39. from .codegen import PyCodegen
  40. from .exc import collapse_resume_frames, get_stack_above_dynamo, unimplemented
  41. from .source import GlobalSource, LocalCellSource, Source, TempLocalSource
  42. from .utils import is_frozen_dataclass, nn_module_new, object_new
  43. from .variables.base import (
  44. AttributeMutation,
  45. AttributeMutationExisting,
  46. AttributeMutationNew,
  47. is_side_effect_safe,
  48. ValueMutationExisting,
  49. ValueMutationNew,
  50. VariableTracker,
  51. )
  52. from .variables.user_defined import FrozenDataClassVariable
  53. if TYPE_CHECKING:
  54. from torch._dynamo.output_graph import OutputGraph
  55. from torch._dynamo.symbolic_convert import InstructionTranslatorBase
  56. from torch._dynamo.variables.lists import ListVariable
  57. side_effects_log = torch._logging.getArtifactLogger(__name__, "side_effects")
  58. def _manual_dict_setitem(
  59. dict_from: dict[Any, Any], dict_to: dict[Any, Any], mro_index: int
  60. ) -> None:
  61. # Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have
  62. # to be careful because we don't want to trigger the user defined object
  63. # setitem or clear. The mro_index is used to find the dict/OrderedDict from
  64. # the class mro.
  65. dict_class = type(dict_to).__mro__[mro_index]
  66. dict_class.clear(dict_to) # type: ignore[attr-defined]
  67. for k, v in dict_from.items():
  68. dict_class.__setitem__(dict_to, k, v) # type: ignore[index]
  69. def _manual_list_update(list_from: list[Any], list_to: list[Any]) -> None:
  70. list.clear(list_to)
  71. list.extend(list_to, list_from)
  72. class SideEffects:
  73. """
  74. Maintain records of mutations and provide methods to apply them during code generation.
  75. Handles tracking and applying side effects during PyTorch Dynamo compilation,
  76. maintaining Python semantics by managing mutations, attribute modifications,
  77. and other side effects that occur during program execution.
  78. Key responsibilities:
  79. - Tracks mutations to Python objects, lists, and dictionaries that need to be
  80. applied after an FX graph is run.
  81. - Manages attribute modifications and deletions
  82. - Handles tensor hooks and backward pass state
  83. - Tracks cell variable mutations and global variable changes
  84. - Ensures correct ordering and application of side effects after graph execution
  85. This ensures that optimized code behaves identically to the original Python code with
  86. respect to object mutations and other side effects.
  87. """
  88. id_to_variable: dict[int, VariableTracker]
  89. store_attr_mutations: dict[VariableTracker, dict[str, VariableTracker]]
  90. keepalive: list[Any]
  91. # Maps variable tracker to list of user stacks (StackSummary objects, formatted lazily)
  92. mutation_user_stacks: dict[VariableTracker, list[traceback.StackSummary]]
  93. def __init__(
  94. self,
  95. output_graph: "OutputGraph",
  96. id_to_variable: Optional[dict[int, VariableTracker]] = None,
  97. store_attr_mutations: Optional[
  98. dict[VariableTracker, dict[str, VariableTracker]]
  99. ] = None,
  100. mutation_user_stacks: dict[VariableTracker, list[traceback.StackSummary]]
  101. | None = None,
  102. keepalive: Optional[list[Any]] = None,
  103. save_for_backward: Optional[
  104. list[tuple[AutogradFunctionContextVariable, list[VariableTracker]]]
  105. ] = None,
  106. tensor_hooks: Optional[
  107. dict[
  108. int,
  109. tuple[
  110. "variables.TensorVariable",
  111. VariableTracker,
  112. "variables.RemovableHandleVariable",
  113. str,
  114. ],
  115. ]
  116. ] = None,
  117. ) -> None:
  118. super().__init__()
  119. self.output_graph_weakref = weakref.ref(output_graph)
  120. self.id_to_variable = id_to_variable or {}
  121. self.store_attr_mutations = store_attr_mutations or {}
  122. self.mutation_user_stacks = mutation_user_stacks or {}
  123. self.keepalive = keepalive or []
  124. self.save_for_backward = save_for_backward or []
  125. self.tensor_hooks = tensor_hooks or {}
  126. # Used by MappingProxyVariable to graph break in case of any mutated
  127. # dict
  128. self._has_existing_dict_mutation = False
  129. # Track Compiled Autograd final callbacks that must be called at the end of Compiled Autograd backward graph.
  130. # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd.
  131. self.ca_final_callbacks_var: Optional[ListVariable] = None
  132. # Tracks VariableTracker objects whose mutations can be skipped.
  133. # For normal mutated variables, Dynamo generates code to replay/reconstruct
  134. # the mutations after graph execution. However, variables in this set have
  135. # their mutations ignored - the mutations happen during
  136. # execution but don't need to be replayed in the generated code.
  137. # Used for temporary mutations in contexts like torch.func.functional_call,
  138. # where module parameters/buffers are modified but later restored.
  139. self.ignore_mutation_on_these_variables: set[VariableTracker] = set()
  140. def ignore_mutations_on(self, var: VariableTracker) -> None:
  141. """Mutations to this variable will be executed but not not tracked,
  142. typically used for temporary mutations that are later restored."""
  143. self.ignore_mutation_on_these_variables.add(var)
  144. def stop_ignoring_mutations_on(self, var: VariableTracker) -> None:
  145. """Remove a variable from the skip mutation set, restoring normal mutation tracking."""
  146. if var in self.ignore_mutation_on_these_variables:
  147. self.ignore_mutation_on_these_variables.remove(var)
  148. def _capture_user_stack(self, key: VariableTracker) -> None:
  149. """Capture the current user stack from the instruction translator."""
  150. if config.side_effect_replay_policy == "silent":
  151. return
  152. if key not in self.mutation_user_stacks:
  153. self.mutation_user_stacks[key] = []
  154. self.mutation_user_stacks[key].append(
  155. torch._guards.TracingContext.extract_stack()
  156. )
  157. def __eq__(self, other: object) -> bool:
  158. assert isinstance(other, SideEffects)
  159. # NB: do NOT test keepalive
  160. return (
  161. self.id_to_variable == other.id_to_variable
  162. and self.store_attr_mutations == other.store_attr_mutations
  163. and self.save_for_backward == other.save_for_backward
  164. and self.tensor_hooks == other.tensor_hooks
  165. )
  166. def diff(self, other: "SideEffects") -> Optional[str]:
  167. if self.id_to_variable != other.id_to_variable:
  168. sk_itv = self.id_to_variable.keys()
  169. ok_itv = other.id_to_variable.keys()
  170. if sk_itv != ok_itv:
  171. return f"id_to_variable keys: {sk_itv} != {ok_itv}"
  172. # Feel free to augment this with more fancy diffing logic
  173. # if needed for debugging
  174. return "id_to_variable: unknown diff"
  175. elif self.store_attr_mutations != other.store_attr_mutations:
  176. sk_sam = self.store_attr_mutations.keys()
  177. ok_sam = other.store_attr_mutations.keys()
  178. if sk_sam != ok_sam:
  179. return f"store_attr_mutations keys: {sk_sam} != {ok_sam}"
  180. return "store_attr_mutations: unknown diff"
  181. elif self.save_for_backward != other.save_for_backward:
  182. return "save_for_backward"
  183. elif self.tensor_hooks != other.tensor_hooks:
  184. return "tensor_hooks"
  185. else:
  186. return None
  187. def clone(self) -> "SideEffects":
  188. """Create a shallow copy"""
  189. ref = self.output_graph_weakref()
  190. assert ref is not None
  191. return self.__class__(
  192. output_graph=ref,
  193. id_to_variable=dict(self.id_to_variable),
  194. store_attr_mutations={
  195. k: dict(v) for k, v in self.store_attr_mutations.items()
  196. },
  197. mutation_user_stacks=self.mutation_user_stacks,
  198. keepalive=list(self.keepalive),
  199. save_for_backward=self.save_for_backward,
  200. tensor_hooks=self.tensor_hooks,
  201. )
  202. def __contains__(self, item: Any) -> bool:
  203. return id(item) in self.id_to_variable
  204. def __getitem__(self, item: Any) -> VariableTracker:
  205. return self.id_to_variable[id(item)]
  206. def should_allow_externally_visible_side_effects_in_subtracer(self) -> bool:
  207. output_graph = self.output_graph_weakref()
  208. return bool(
  209. output_graph
  210. and output_graph.current_tx.output.current_tracer.unsafe_allow_externally_visible_side_effects
  211. )
  212. def should_allow_side_effects_in_hop(self) -> bool:
  213. output_graph = self.output_graph_weakref()
  214. return bool(
  215. output_graph
  216. and output_graph.current_tx.output.current_tracer.allow_side_effects_in_hop
  217. )
  218. def is_reconstructing_generator(self) -> bool:
  219. output_graph = self.output_graph_weakref()
  220. return bool(
  221. output_graph
  222. and output_graph.current_tx.output.current_tracer.is_reconstructing_generator
  223. )
  224. def check_allowed_side_effect(self, item: VariableTracker) -> bool:
  225. from torch._dynamo.variables.misc import AutogradFunctionContextVariable
  226. # People do things like self.dim = dim inside autograd.Function.
  227. # These are benign.
  228. if isinstance(item, AutogradFunctionContextVariable):
  229. return True
  230. if self.should_allow_externally_visible_side_effects_in_subtracer():
  231. return True
  232. if self.should_allow_side_effects_in_hop():
  233. return True
  234. if self.is_reconstructing_generator():
  235. # This is missing the case where one mutates a tensor. See
  236. # test_generator.py::test_reconstruct_generator_tensor_mutation
  237. unimplemented(
  238. gb_type="Generator reconstruction with mutations",
  239. context=f"mutating object: {item}",
  240. explanation="Cannot reconstruct a generator with variable mutations. "
  241. "Dynamo needs to fully exhaust the generator, which may cause "
  242. "unintended variable modifications.",
  243. hints=[
  244. "Remove mutations from the generator.",
  245. *graph_break_hints.FUNDAMENTAL,
  246. ],
  247. )
  248. assert item.mutation_type is not None
  249. if not is_side_effect_safe(item.mutation_type):
  250. unimplemented(
  251. gb_type="HOP: Unsafe side effect",
  252. context=f"Attempted to mutate {item}",
  253. explanation="Mutating a variable from outside the scope of this HOP is not supported.",
  254. hints=[
  255. "If the HOP is activation checkpointing (torch.utils.checkpoint.checkpoint), this points to a "
  256. "side effect in forward method. Eager activation checkpointing replays that side-effect while "
  257. "recomputing the forward in the backward. If you are ok with side-effect not replayed in the "
  258. "backward, try setting `torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = True`",
  259. ],
  260. )
  261. return False
  262. def store_attr(
  263. self, item: VariableTracker, name: str, value: VariableTracker
  264. ) -> None:
  265. assert self.is_attribute_mutation(item)
  266. self.check_allowed_side_effect(item)
  267. if item not in self.store_attr_mutations:
  268. self.store_attr_mutations[item] = {}
  269. self.store_attr_mutations[item][name] = value
  270. # Capture user stack for this mutation
  271. self._capture_user_stack(item)
  272. def load_attr(
  273. self,
  274. item: VariableTracker,
  275. name: str,
  276. deleted_ok: bool = False,
  277. check: bool = False,
  278. ) -> VariableTracker:
  279. if check:
  280. assert self.is_attribute_mutation(item)
  281. result = self.store_attr_mutations[item][name]
  282. if not deleted_ok and isinstance(result, variables.DeletedVariable):
  283. unimplemented(
  284. gb_type="Attempted to read a deleted variable",
  285. context=f"item: {item}, name: {name}",
  286. explanation="",
  287. hints=[*graph_break_hints.USER_ERROR],
  288. )
  289. return result
  290. def store_cell(self, cellvar: VariableTracker, value: VariableTracker) -> None:
  291. if cellvar.is_immutable():
  292. unimplemented(
  293. gb_type="Write to immutable cell",
  294. context=f"cellvar: {cellvar}, value: {value}",
  295. explanation="Dynamo doesn't support writing to immutable/sourceless cell variables.",
  296. hints=[*graph_break_hints.DIFFICULT],
  297. )
  298. assert isinstance(cellvar, variables.CellVariable)
  299. assert isinstance(value, variables.VariableTracker)
  300. self.store_attr(cellvar, "cell_contents", value)
  301. def load_cell(self, cellvar: VariableTracker) -> VariableTracker:
  302. assert isinstance(cellvar, variables.CellVariable)
  303. if self.has_pending_mutation_of_attr(cellvar, "cell_contents"):
  304. return self.load_attr(cellvar, "cell_contents", check=False)
  305. if cellvar.pre_existing_contents:
  306. return cellvar.pre_existing_contents
  307. unimplemented(
  308. gb_type="Read uninitialized cell",
  309. context=str(cellvar),
  310. explanation="Attempted to read a cell variable that has not been populated yet.",
  311. hints=[*graph_break_hints.USER_ERROR],
  312. )
  313. def load_global(self, gvar: VariableTracker, name: str) -> VariableTracker:
  314. assert isinstance(gvar, variables.VariableTracker)
  315. return self.load_attr(gvar, name)
  316. def store_global(
  317. self, gvar: VariableTracker, name: str, value: VariableTracker
  318. ) -> None:
  319. assert isinstance(gvar, variables.VariableTracker)
  320. assert isinstance(value, variables.VariableTracker)
  321. self.store_attr(gvar, name, value)
  322. @staticmethod
  323. def cls_supports_mutation_side_effects(cls: type) -> bool:
  324. return inspect.getattr_static(cls, "__getattribute__", None) in (
  325. object.__getattribute__,
  326. dict.__getattribute__,
  327. set.__getattribute__,
  328. frozenset.__getattribute__,
  329. int.__getattribute__,
  330. str.__getattribute__,
  331. list.__getattribute__,
  332. tuple.__getattribute__,
  333. BaseException.__getattribute__,
  334. )
  335. def is_attribute_mutation(self, item: VariableTracker) -> bool:
  336. return isinstance(item.mutation_type, AttributeMutation)
  337. def has_pending_mutation(self, item: VariableTracker) -> bool:
  338. return self.is_attribute_mutation(item) and bool(
  339. self.store_attr_mutations.get(item)
  340. )
  341. def has_pending_mutation_of_attr(self, item: VariableTracker, name: str) -> bool:
  342. return self.is_attribute_mutation(
  343. item
  344. ) and name in self.store_attr_mutations.get(item, ())
  345. def is_modified(self, item: VariableTracker) -> bool:
  346. if item.is_immutable():
  347. return False
  348. if isinstance(item.mutation_type, (AttributeMutationNew, ValueMutationNew)):
  349. return True
  350. if isinstance(item, variables.UserDefinedObjectVariable):
  351. # Checks if the underlying dict or tuple vt has been modified
  352. return item in self.store_attr_mutations or item.is_underlying_vt_modified(
  353. self
  354. )
  355. if self.is_attribute_mutation(item):
  356. return item in self.store_attr_mutations
  357. assert item.mutation_type is not None
  358. return item.mutation_type.is_modified # type: ignore[attr-defined]
  359. def _track_obj(
  360. self,
  361. item: Any,
  362. variable: VariableTracker,
  363. mutation_type_cls: type = ValueMutationExisting,
  364. ) -> VariableTracker:
  365. """Start tracking an existing or new variable for mutation"""
  366. if id(item) in self.id_to_variable:
  367. raise AssertionError(
  368. f"{variable} is already tracked for mutation. This could be "
  369. "because you are not using VariableBuilder to construct "
  370. "the variable tracker. "
  371. f"Source of new object: {variable.source}. "
  372. f"Source of previously tracked object: {self.id_to_variable[id(item)].source}."
  373. )
  374. variable.mutation_type = mutation_type_cls()
  375. self.id_to_variable[id(item)] = variable
  376. self.keepalive.append(item)
  377. return variable
  378. track_mutable = _track_obj
  379. def track_object_existing(
  380. self,
  381. item: Any,
  382. variable: VariableTracker,
  383. ) -> VariableTracker:
  384. # TODO: Modify this API so that we preserve type info of
  385. # variable
  386. return self._track_obj(
  387. item,
  388. variable,
  389. mutation_type_cls=AttributeMutationExisting,
  390. )
  391. def track_object_new(
  392. self,
  393. cls_source: Source | None,
  394. user_cls: Any,
  395. variable_cls: Any,
  396. options: dict[str, Any],
  397. ) -> VariableTracker:
  398. if user_cls is torch.autograd.function.FunctionCtx:
  399. with warnings.catch_warnings(record=True):
  400. obj = torch.autograd.Function()
  401. else:
  402. obj = object_new(user_cls)
  403. variable = variable_cls(
  404. obj,
  405. mutation_type=AttributeMutationNew(cls_source),
  406. **options,
  407. )
  408. self.id_to_variable[id(obj)] = variable
  409. self.keepalive.append(obj)
  410. return variable
  411. def get_variable_cls(self, user_cls: type) -> type:
  412. from torch.overrides import TorchFunctionMode
  413. from .variables.ctx_manager import GenericContextWrappingVariable
  414. from .variables.torch_function import TorchFunctionModeVariable
  415. from .variables.user_defined import is_forbidden_context_manager
  416. variable_cls: type[variables.UserDefinedObjectVariable] = (
  417. variables.UserDefinedObjectVariable
  418. )
  419. if issubclass(
  420. user_cls, TorchFunctionMode
  421. ) and TorchFunctionModeVariable.is_supported_torch_function_mode(user_cls):
  422. variable_cls = TorchFunctionModeVariable
  423. elif (
  424. hasattr(user_cls, "__enter__")
  425. and hasattr(user_cls, "__exit__")
  426. and not is_forbidden_context_manager(user_cls)
  427. ):
  428. variable_cls = GenericContextWrappingVariable
  429. elif issubclass(user_cls, torch.nn.Module):
  430. variable_cls = variables.UnspecializedNNModuleVariable
  431. elif issubclass(user_cls, (dict, collections.OrderedDict)):
  432. variable_cls = variables.UserDefinedDictVariable
  433. elif issubclass(user_cls, (set, frozenset)):
  434. variable_cls = variables.UserDefinedSetVariable
  435. elif issubclass(user_cls, tuple):
  436. variable_cls = variables.UserDefinedTupleVariable
  437. elif issubclass(user_cls, list):
  438. variable_cls = variables.UserDefinedListVariable
  439. elif issubclass(user_cls, MutableMapping):
  440. variable_cls = variables.MutableMappingVariable
  441. elif is_frozen_dataclass(user_cls):
  442. variable_cls = FrozenDataClassVariable
  443. elif issubclass(user_cls, BaseException):
  444. variable_cls = variables.UserDefinedExceptionObjectVariable
  445. elif variables.InspectVariable.is_matching_class(user_cls):
  446. variable_cls = variables.InspectVariable
  447. assert issubclass(variable_cls, variables.UserDefinedObjectVariable)
  448. return variable_cls
  449. def get_example_value(
  450. self,
  451. base_cls_vt: VariableTracker,
  452. cls_vt: VariableTracker,
  453. init_args: list[VariableTracker],
  454. ) -> Any:
  455. user_cls = cls_vt.value # type: ignore[attr-defined]
  456. if issubclass(user_cls, torch.nn.Module):
  457. # TODO(anijain2305) - Is it possible to remove this specialization?
  458. obj = nn_module_new(user_cls)
  459. else:
  460. if isinstance(base_cls_vt, variables.BuiltinVariable):
  461. base_cls = base_cls_vt.fn
  462. elif isinstance(base_cls_vt, variables.UserDefinedClassVariable):
  463. base_cls = base_cls_vt.value
  464. else:
  465. raise RuntimeError(f"Unexpected base_cls_vt {base_cls_vt}")
  466. assert variables.UserDefinedClassVariable.is_supported_new_method(
  467. base_cls.__new__
  468. )
  469. # TODO(anijain2305) - Consider adding get_example_value method to
  470. # each VT to get an example value for all args. As we expand the
  471. # scope to other __new__ methods, we might need to call __new__ with
  472. # init_args (like functools.partial)
  473. # init_args = [arg.get_example_value() for arg in init_args]
  474. # obj = base_cls.__new__(user_cls, *init_args)
  475. obj = base_cls.__new__(user_cls)
  476. return obj
  477. def track_new_user_defined_object(
  478. self,
  479. base_cls_vt: VariableTracker,
  480. cls_vt: VariableTracker,
  481. init_args: list[VariableTracker],
  482. ) -> VariableTracker:
  483. """
  484. Creates a UserDefinedObjectVariable (or its subclass) variable tracker
  485. and mark it for attribute mutation tracking.
  486. Also records the variable trackers to call __new__ method on
  487. reconstruction. Roughly, the reconstruction looks like this
  488. base_cls_vt.__new__(user_cls, *init_args)
  489. """
  490. cls_source = cls_vt.source
  491. user_cls = cls_vt.value # type: ignore[attr-defined]
  492. variable_cls = self.get_variable_cls(user_cls)
  493. obj = self.get_example_value(base_cls_vt, cls_vt, init_args)
  494. variable = variable_cls(
  495. obj,
  496. cls_source=cls_vt.source,
  497. base_cls_vt=base_cls_vt,
  498. init_args=init_args,
  499. mutation_type=AttributeMutationNew(cls_source),
  500. )
  501. self.id_to_variable[id(obj)] = variable
  502. self.keepalive.append(obj)
  503. return variable
  504. def track_cell_new(
  505. self,
  506. ) -> VariableTracker:
  507. obj = object()
  508. variable = variables.CellVariable(
  509. mutation_type=AttributeMutationNew(),
  510. )
  511. self.id_to_variable[id(obj)] = variable
  512. self.keepalive.append(obj)
  513. return variable
  514. def track_cell_existing(
  515. self, source: Optional[Source], cell: CellType, contents: VariableTracker
  516. ) -> VariableTracker:
  517. variable = variables.CellVariable(
  518. # We don't support mutation to cell without source because we need
  519. # source to properly codegen the mutations.
  520. mutation_type=None if source is None else AttributeMutationExisting(),
  521. pre_existing_contents=contents,
  522. source=source,
  523. )
  524. self.id_to_variable[id(cell)] = variable
  525. self.keepalive.append(cell)
  526. return variable
  527. def track_global_existing(self, source: Source, item: Any) -> VariableTracker:
  528. variable = variables.NewGlobalVariable(
  529. mutation_type=AttributeMutationExisting(),
  530. source=source,
  531. )
  532. self.id_to_variable[id(item)] = variable
  533. self.keepalive.append(item)
  534. return variable
  535. def track_save_for_backward(
  536. self, ctx: VariableTracker, args: list[VariableTracker]
  537. ) -> None:
  538. assert isinstance(ctx, variables.AutogradFunctionContextVariable)
  539. self.save_for_backward.append((ctx, args))
  540. def track_runahead_tensor_and_symvar_side_effects(
  541. self, other: "SideEffects"
  542. ) -> None:
  543. # In higher order ops we want to keep track of tensors seen in the
  544. # speculate_subgraph so that we don't lift them again as a new input in
  545. # other speculate_subgraph or in the root tracer.
  546. for other_item in other.keepalive:
  547. other_id = id(other_item)
  548. other_variable = other.id_to_variable[other_id]
  549. if other_id not in self.id_to_variable and isinstance(
  550. other_variable, (variables.TensorVariable, variables.SymNodeVariable)
  551. ):
  552. self.track_object_existing(other_item, other_variable)
  553. def prune_dead_object_new(self, tx: "InstructionTranslatorBase") -> None:
  554. # Avoid VT cycles from e.g., recursive function.
  555. visited: set[VariableTracker] = set()
  556. live_new_objects: set[VariableTracker] = set()
  557. def visit(var: VariableTracker) -> None:
  558. if var in visited:
  559. return
  560. visited.add(var)
  561. # Object may have been mutated, store this mutation.
  562. if isinstance(var.mutation_type, AttributeMutationNew):
  563. live_new_objects.add(var)
  564. # It's possible that we have mutated the value of this variable
  565. # to be another one. The new value is in store_attr_mutations.
  566. # Also recurse through the new value to detect alive AttributeMutationNew.
  567. if var in self.store_attr_mutations:
  568. VariableTracker.visit(
  569. visit, # noqa: F821
  570. self.store_attr_mutations[var],
  571. )
  572. def is_live(var: VariableTracker) -> bool:
  573. if isinstance(var.mutation_type, AttributeMutationNew):
  574. return var in live_new_objects
  575. return True
  576. pre_existing_vars = [
  577. var
  578. for var in self.id_to_variable.values()
  579. if not isinstance(var.mutation_type, AttributeMutationNew)
  580. ]
  581. # The only live side effects come from returns (tx.stack), any intermediates
  582. # during a graph break (tx.symbolic_locals), and mutation on pre-existing variables.
  583. # Recursively visit Variables and see if any of them have been mutated.
  584. init_live_vars = []
  585. # gather stack/symbolic_locals for all tx's up the chain
  586. cur_tx: Optional[InstructionTranslatorBase] = tx
  587. while cur_tx is not None:
  588. init_live_vars.extend([cur_tx.stack, cur_tx.symbolic_locals])
  589. if cur_tx.parent is not None:
  590. # for non-root tx'es, also keep the cells/freevars alive so they get codegen'd properly
  591. # TODO see if we could prune dead cells - cell pruning information needs to be forwarded
  592. # to the resume function creation as well.
  593. assert cur_tx.post_prune_cell_and_freevars is not None
  594. init_live_vars.append(cur_tx.post_prune_cell_and_freevars)
  595. cur_tx = cur_tx.parent
  596. VariableTracker.visit(
  597. visit,
  598. # TODO track from all possible sources.
  599. init_live_vars
  600. + [
  601. pre_existing_vars,
  602. tx.output.backward_state,
  603. self.tensor_hooks,
  604. ],
  605. )
  606. # Manually release the self-referential function, which indirectly
  607. # captures certain `VariableTracker` and affects parts of PT test/logic
  608. # that are sensitive to when certain objects get released.
  609. del visit
  610. # NB: cell variable handling.is tricky.
  611. # cell variables must stay alive if any NestedUserFunctionVariable
  612. # are live. "visit"-ing the NestedUserFunctionVariable visits
  613. # the .closures field, from which we will see if we need to keep
  614. # any mutations to cell variables alive.
  615. self.id_to_variable = {
  616. k: v for k, v in self.id_to_variable.items() if is_live(v)
  617. }
  618. self.store_attr_mutations = {
  619. k: v for k, v in self.store_attr_mutations.items() if is_live(k)
  620. }
  621. def mutation(self, var: VariableTracker) -> None:
  622. if var in self.ignore_mutation_on_these_variables:
  623. return
  624. self.check_allowed_side_effect(var)
  625. # Capture user stack for this mutation
  626. self._capture_user_stack(var)
  627. if isinstance(var.mutation_type, ValueMutationExisting):
  628. var.mutation_type.is_modified = True
  629. if (
  630. var.source
  631. and isinstance(var, variables.ConstDictVariable)
  632. and not isinstance(var, variables.SetVariable)
  633. ):
  634. self._has_existing_dict_mutation = True
  635. def has_existing_dict_mutation(self) -> bool:
  636. return self._has_existing_dict_mutation
  637. def _get_modified_vars(self) -> list[VariableTracker]:
  638. return [var for var in self.id_to_variable.values() if self.is_modified(var)]
    def codegen_save_tempvars(self, cg: PyCodegen) -> None:
        """Assign sources to modified objects that were created during tracing.

        Modified variables that are not ``AttributeMutationNew`` must already
        have a source and are skipped.

        For ``AttributeMutationNew`` variables, the handling depends on the kind:

        * Non-root-frame cells are built with ``make_cell()`` and cached.
        * Root-frame cells get a ``LocalCellSource``.
        * ``TensorWithTFOverrideVariable`` values are reconstructed and cached.
          Other tensors are left untouched.
        * Autograd Function contexts graph-break via ``unimplemented``.
        * Everything else is rebuilt as ``base_cls.__new__(user_cls, *init_args)``
          and cached.

        "Cached" means ``cg.add_cache`` stores the object in a temp local, and
        ``var.source`` is then set to the matching ``TempLocalSource``. After
        that, the recorded ``ctx.save_for_backward(*args)`` calls in
        ``self.save_for_backward`` are emitted.

        Args:
            cg: bytecode generator that instructions are appended to.
        """
        # We must codegen modified VT to their source by default, so that
        # mutation and aliasing are properly accounted for.
        #
        # Since newly constructed objects don't have a source, we manually
        # codegen their construction and store them to a newly assigned local
        # source. Note that `ValueMutationNew` isn't tracked by SideEffects.
        for var in self._get_modified_vars():
            if not isinstance(var.mutation_type, AttributeMutationNew):
                # Pre-existing objects are codegen'd through their source.
                assert var.source is not None
                continue

            if isinstance(var, variables.CellVariable):
                # Cells created in the root frame are created either by
                # `MAKE_CELL` or by them being in `co_cellvars`, so we only emit
                # `make_cell` for the non-root-frame cells here.
                # TODO generalize this so we never need to call `make_cell`.
                if var.local_name is None:
                    cg.add_push_null(
                        lambda: cg.load_import_from(utils.__name__, "make_cell")
                    )
                    cg.extend_output(create_call_function(0, False))
                    cg.add_cache(var)
                    var.source = TempLocalSource(cg.tempvars[var])  # type: ignore[attr-defined]
                elif var.source is None:
                    var.source = LocalCellSource(var.local_name)
            elif var.is_tensor():
                # NOTE: for historical reasons we never assigned local sources
                # to newly constructed tensor object, so we keep it that way.
                # They are always loaded from output of the fx graph, so one can
                # think of it as having a "OutputGraphSource" for codegen
                # purposes.
                #
                # However, tensor subclass objects are different, because the
                # reconstruction logic in `PyCodegen` loads the data tensor from
                # graph output and then calls `as_subclass`, meaning we must
                # assign a source to it to ensure we only reconstruct one
                # subclass instance.
                if isinstance(
                    var, variables.torch_function.TensorWithTFOverrideVariable
                ):
                    # Don't codegen from temp source assigned from the 1st pass.
                    cg(var, allow_cache=False)
                    cg.add_cache(var)
                    # `add_cache` generates STORE and consumes TOS, but we never
                    # cleared it. TODO move this call into `add_cache`
                    cg.clear_tos()
                    var.source = TempLocalSource(cg.tempvars[var])
            elif isinstance(var, variables.AutogradFunctionContextVariable):
                unimplemented(
                    gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region",
                    context="",
                    explanation="We cannot reconstruct a torch.autograd.Function's context object.",
                    hints=[],
                )
            else:
                # Reconstruct the bytecode for
                # base_cls.__new__(user_cls, *args)
                if isinstance(var, variables.UserDefinedObjectVariable):

                    def load_new_method() -> None:
                        # Push `base_cls.__new__` for user-defined objects.
                        # pyrefly: ignore [missing-attribute]
                        assert var.base_cls_vt is not None
                        cg(var.base_cls_vt)  # type: ignore[attr-defined]
                        cg.extend_output([cg.create_load_attr("__new__")])

                    cg.add_push_null(load_new_method)
                else:
                    # Other objects go through the `object_new` helper in utils.
                    cg.add_push_null(
                        lambda: cg.load_import_from(utils.__name__, "object_new")
                    )
                # First positional argument to __new__: the user class.
                assert var.mutation_type.cls_source is not None
                cg(var.mutation_type.cls_source)

                # Generate the args to the __new__ method
                for arg in var.init_args:  # type: ignore[attr-defined]
                    cg(arg)

                # Call the __new__ method (class + init_args positional args).
                cg.extend_output(create_call_function(1 + len(var.init_args), False))  # type: ignore[attr-defined]

                cg.add_cache(var)
                var.source = TempLocalSource(cg.tempvars[var])

        # Replay ctx.save_for_backward(*args) for each recorded call; the
        # result is discarded with POP_TOP.
        for ctx, args in self.save_for_backward:
            cg(ctx.source)
            cg.load_method("save_for_backward")
            for arg in args:
                cg(arg)
            cg.extend_output(
                [
                    *create_call_method(len(args)),
                    create_instruction("POP_TOP"),
                ]
            )
  727. def register_hook(
  728. self,
  729. tensor: "variables.TensorVariable",
  730. hook: VariableTracker,
  731. handle: "variables.RemovableHandleVariable",
  732. name: str,
  733. ) -> None:
  734. assert tensor.is_tensor()
  735. assert isinstance(hook, variables.VariableTracker)
  736. assert (
  737. isinstance(handle, variables.RemovableHandleVariable)
  738. and handle.is_mutable()
  739. )
  740. assert hasattr(torch.Tensor, name)
  741. idx = len(self.tensor_hooks.keys())
  742. # duplicate index possible because of self.remove_hook()
  743. while idx in self.tensor_hooks:
  744. idx += 1
  745. self.tensor_hooks[idx] = (tensor, hook, handle, name)
  746. assert not handle.idx
  747. handle.idx = idx
  748. def remove_hook(self, idx: int) -> None:
  749. del self.tensor_hooks[idx]
    def codegen_hooks(self, cg: PyCodegen) -> None:
        """Emit a ``tensor.<name>(hook)`` call for every recorded tensor hook.

        For each ``(tensor, hook, handle, name)`` entry in ``self.tensor_hooks``,
        this loads ``<name>`` from the tensor, pushes ``hook``, and calls the
        method with one argument. It then caches the call's result as
        ``handle``, which consumes the top of stack.

        Args:
            cg: bytecode generator that instructions are appended to.
        """
        for (
            tensor,
            hook,
            handle,
            name,
        ) in self.tensor_hooks.values():
            # Note: [On tensor.register_hook]
            #
            # register_hook on a tensor, AKA backward hooks, have slightly nuanced differences in how they are implemented
            # when it comes to hooks on objects with sources (inputs, params) vs objects without sources (intermediaries).
            #
            # For tensors with a source, we bypass direct inclusion of register_hook calls in the graph.
            # Instead, these are tracked and stashed as a global variable, enabling their association with tensors in
            # the residuals. During dynamo's frame creation, these hooks are invoked seamlessly on known reconstructible/fetch-able
            # tensors. Because a source indicates knowledge of this object outside the torch compile region, and
            # because we are running residuals firmly before .backward() can be run, it is sound to invoke
            # `register_hook` on a known tensor.
            #
            # For tensors without a source, we support a limited subset of hooks. Global functions only, and
            # compiled_autograd must be enabled or we will graph break.
            #
            # Handling the Handle: When a user retains the register_hook result in a handle, we intercept the
            # STORE_FAST operation to record the user-designated local variable name. This ensures the reconstructed
            # bytecode retains this name. If no handle is defined, we simply pop the generated value to keep the
            # stack intact.
            #
            # Dynamo Tensor Hooks Workflow:
            # - Functions passed to register_hook are lifted globally.
            # - For tensors with sources:
            #   - In the "side_effects" phase of codegen, we iterate over tensors with hooks to:
            #     - Generate the tensor.
            #     - Issue a register_hook call on the tensor, linking to the globally stored function.
            #     - Incorporate a handle if one was established in the eager phase.
            # - For tensors without sources:
            #   - We don't generate any instructions for registering a hook.
            #   - Handles from intermediary hooks are NYI.
            #   - We produce a call function that utilizes the trace_wrapped higher order op, closing over it.
            #   - We then manually insert the call function above into the graph.
            # - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST.
            assert tensor.source, "Hooks on non input tensors NYI - should not get here"

            def gen_fn() -> None:
                # Push the bound method `tensor.<name>`.
                # NOTE(review): this closes over the loop variables; it is safe
                # only if add_push_null invokes gen_fn immediately — confirm.
                cg(tensor)
                cg.extend_output([cg.create_load_attr(name)])

            cg.add_push_null(gen_fn)
            cg(hook)
            # Call tensor.<name>(hook) with the single hook argument.
            cg.extend_output(create_call_function(1, False))

            # Adding the handle to the cache means RemovableHandleVariable().reconstruct() will
            # be associated with the return value of register_hook(). This consumes the top of stack.
            cg.add_cache(handle)
  800. def get_ca_final_callbacks_var(self) -> "variables.ListVariable":
  801. from .variables.base import ValueMutationNew
  802. if self.ca_final_callbacks_var is None:
  803. self.ca_final_callbacks_var = variables.ListVariable(
  804. [], mutation_type=ValueMutationNew()
  805. )
  806. return self.ca_final_callbacks_var
  807. def _format_side_effect_message(self, var: VariableTracker) -> str:
  808. """Format a side effect log message with user stack."""
  809. assert config.side_effect_replay_policy != "silent"
  810. locations = self.mutation_user_stacks.get(var, [])
  811. description = f"Mutating object of type {var.python_type_name()}"
  812. source_info = " (no source)"
  813. if var.source is not None:
  814. if isinstance(var.source, TempLocalSource):
  815. source_info = " (source: created in torch.compile region)"
  816. elif isinstance(var, variables.CellVariable) and var.local_name is not None:
  817. source_info = f" (source: {var.local_name})"
  818. elif isinstance(
  819. var, variables.torch_function.TorchFunctionModeStackVariable
  820. ):
  821. source_info = " (source: torch function mode stack mutation)"
  822. else:
  823. # NOTE: NotImplementedError from var.source.name is a bug and must be fixed!
  824. source_info = f" (source name: {var.source.name})"
  825. if locations:
  826. # Format and dedupe stacks using tuple representation for efficiency
  827. seen = set()
  828. unique_formatted_stacks: list[str] = []
  829. stack_above_dynamo = collapse_resume_frames(get_stack_above_dynamo())
  830. for stack in locations:
  831. # Use tuple of frame info for fast deduplication
  832. # Include position info (colno, end_lineno, end_colno) to distinguish
  833. # multiple mutations on the same line (when available in Python 3.11+)
  834. stack_tuple = tuple(
  835. (
  836. f.filename,
  837. f.lineno,
  838. f.name,
  839. f.line,
  840. getattr(f, "colno", None),
  841. getattr(f, "end_lineno", None),
  842. getattr(f, "end_colno", None),
  843. )
  844. for f in stack
  845. )
  846. if stack_tuple not in seen:
  847. seen.add(stack_tuple)
  848. stack_augmented = collapse_resume_frames(stack_above_dynamo + stack)
  849. unique_formatted_stacks.append(
  850. "".join(traceback.format_list(stack_augmented))
  851. )
  852. formatted_lines: str = "\n********\n\n".join(unique_formatted_stacks)
  853. log_str = f"{description}{source_info}\n\n{textwrap.indent(formatted_lines, ' ')}"
  854. else:
  855. log_str = (
  856. f"{description}{source_info} (unable to find user stacks for mutations)"
  857. )
  858. return log_str
    def codegen_update_mutated(
        self, cg: PyCodegen, log_side_effects: bool = False
    ) -> None:
        """Emit bytecode that applies traced mutations to the real objects.

        For each modified variable, the branch matching its kind pushes the
        operands it needs onto the stack. It then appends a "suffix" to
        ``suffixes``: the instructions that consume those operands and perform
        the mutation.

        Once every variable is processed, the suffixes are emitted in reverse
        order, so each suffix pops exactly the operands its own branch pushed.
        Two branches (torch function mode stack, list iterators) emit
        stack-neutral instructions immediately instead of using a suffix.

        Args:
            cg: bytecode generator that instructions are appended to.
            log_side_effects: when True and ``config.side_effect_replay_policy``
                is not ``"silent"``, each codegen'd side effect is formatted
                and logged at debug level on ``side_effects_log``. All messages
                are also emitted together as one ``dynamo_side_effects``
                structured-trace artifact.

        Raises:
            AssertionError: for a modified variable of an unhandled type.
        """
        side_effect_messages: list[str] = []

        # NOTE: should only be called once per VT - only if a side effect actually gets codegen'd!
        def _maybe_log_side_effect(var: VariableTracker) -> None:
            if config.side_effect_replay_policy != "silent" and log_side_effects:
                msg = self._format_side_effect_message(var)
                side_effect_messages.append(msg)
                # Log individual side effects for granular debugging
                side_effects_log.debug(msg)

        # Deferred instruction lists; applied in reverse at the end of this method.
        suffixes = []
        for var in self._get_modified_vars():
            # When replay_side_effects=False, only update variables with TempLocalSource
            if not config.replay_side_effects and not isinstance(
                var.source, TempLocalSource
            ):
                continue

            if isinstance(var, variables.ListVariable):
                # old[:] = new
                cg(var, allow_cache=False)  # Don't codegen via source
                cg(var.source)  # type: ignore[attr-defined]
                cg.extend_output(
                    [
                        cg.create_load_const(None),
                        cg.create_load_const(None),
                        create_instruction("BUILD_SLICE", arg=2),
                    ]
                )
                # Stack is now [new, old, slice(None, None)]; STORE_SUBSCR
                # performs old[slice] = new.
                suffixes.append([create_instruction("STORE_SUBSCR")])
                _maybe_log_side_effect(var)
            elif isinstance(var, variables.lists.DequeVariable):
                # For limited maxlen, the order of operations matter for side
                # effect, but we currently don't track the order, so no support.
                if not var.maxlen.is_constant_none():
                    unimplemented(
                        gb_type="Side effect on existing deque with limited maxlen",
                        context="",
                        explanation="This is not supported.",
                        hints=[
                            "Don't use a deque with `maxlen` specified.",
                        ],
                    )

                # old.extend(new), this runs last
                cg(var.source)
                cg.load_method("extend")
                cg(var, allow_cache=False)  # Don't codegen via source
                suffixes.append(
                    [
                        *create_call_method(1),
                        create_instruction("POP_TOP"),
                    ]
                )

                # old.clear(), this runs first
                cg(var.source)
                cg.load_method("clear")
                suffixes.append(
                    [
                        *create_call_method(0),
                        create_instruction("POP_TOP"),
                    ]
                )
                _maybe_log_side_effect(var)

            elif isinstance(var, variables.ConstDictVariable):
                # Reconstruct works as follow:
                # (1) Skip codegen if there are no new items
                # (2) codegen(...) each pair of key/value
                # (3) create a new dictionary with the pairs of key/values above
                # (4) clear the original dictionary
                #     + only if a key was removed from the input dict
                # (5) update the original dictionary with the dict created in (2)

                if var.has_new_items():
                    cg(var.source)  # type: ignore[attr-defined]
                    cg.load_method("update")
                    cg(var, allow_cache=False)  # Don't codegen via source

                    if var.should_reconstruct_all:
                        cg(var.source)  # type: ignore[attr-defined]
                        cg.load_method("clear")

                    suffixes.append(
                        [
                            *create_call_method(1),  # update
                            create_instruction("POP_TOP"),
                        ]
                    )

                    if var.should_reconstruct_all:
                        # clear will appear before "update" as the suffixes are
                        # applied in reverse order.
                        suffixes.append(
                            [
                                *create_call_method(0),  # clear
                                create_instruction("POP_TOP"),
                            ]
                        )
                    _maybe_log_side_effect(var)

            elif isinstance(
                var, variables.torch_function.TorchFunctionModeStackVariable
            ):
                # Needed in the finally block for stack restoration
                # Save the current mode stack into a new local named by
                # get_prev_stack_var_name().
                cg.add_push_null(
                    lambda: cg.load_import_from(
                        utils.__name__, "get_torch_function_mode_stack"
                    )
                )
                cg.call_function(0, False)
                name = variables.torch_function.get_prev_stack_var_name()
                cg.code_options["co_varnames"] += (name,)
                cg.append_output(create_instruction("STORE_FAST", argval=name))
                # Then install the traced stack: set_torch_function_mode_stack([...]).
                cg.add_push_null(
                    lambda: cg.load_import_from(
                        utils.__name__, "set_torch_function_mode_stack"
                    )
                )

                cg.foreach(var.symbolic_stack)
                cg.append_output(
                    create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
                )
                cg.call_function(1, False)
                cg.append_output(create_instruction("POP_TOP"))
                _maybe_log_side_effect(var)

            elif isinstance(var, variables.CellVariable) and var.local_name is not None:
                # Emit more readable and performant bytecode.
                # TODO generalize this for cells created during inlining.
                if var in self.store_attr_mutations:
                    contents_var = self.load_cell(var)
                    cg(contents_var)
                    suffixes.append([cg.create_store_deref(var.local_name)])
                    _maybe_log_side_effect(var)

            elif self.is_attribute_mutation(var):
                if isinstance(
                    var,
                    variables.UserDefinedDictVariable,
                ) and self.is_modified(var._dict_vt):
                    # Do dict related update manually here. The store_attr
                    # mutations will be applied later.
                    # Each local name used by the _manual_dict_setitem template
                    # is mapped to a fresh variable name.
                    varname_map = {}
                    for name in _manual_dict_setitem.__code__.co_varnames:
                        varname_map[name] = cg.tx.output.new_var()

                    # Position of OrderedDict (or else dict) in the user class's
                    # MRO, stored for use by the template.
                    try:
                        mro_index = type(var.value).__mro__.index(
                            collections.OrderedDict
                        )
                    except ValueError:
                        mro_index = type(var.value).__mro__.index(dict)

                    cg.extend_output(
                        [
                            create_instruction("LOAD_CONST", argval=mro_index),
                            create_instruction(
                                "STORE_FAST", argval=varname_map["mro_index"]
                            ),
                        ]
                    )

                    # dict_to: the existing object (via its source).
                    cg(var.source)  # type: ignore[attr-defined]
                    cg.extend_output(
                        [
                            create_instruction(
                                "STORE_FAST", argval=varname_map["dict_to"]
                            )
                        ]
                    )

                    # dict_from: the traced dict contents.
                    cg(var._dict_vt, allow_cache=False)  # Don't codegen via source
                    cg.extend_output(
                        [
                            create_instruction(
                                "STORE_FAST", argval=varname_map["dict_from"]
                            )
                        ]
                    )

                    dict_update_insts = bytecode_from_template(
                        _manual_dict_setitem, varname_map=varname_map
                    )

                    suffixes.append(
                        [
                            *dict_update_insts,
                            create_instruction("POP_TOP"),
                        ]
                    )
                    _maybe_log_side_effect(var._dict_vt)
                elif isinstance(
                    var,
                    variables.UserDefinedListVariable,
                ) and self.is_modified(var._list_vt):
                    # Update the list to the updated items. Be careful in
                    # calling the list methods and not the overridden methods.
                    varname_map = {}
                    for name in _manual_list_update.__code__.co_varnames:
                        varname_map[name] = cg.tx.output.new_var()

                    # list_to: the existing object (via its source).
                    cg(var.source)  # type: ignore[attr-defined]
                    cg.extend_output(
                        [
                            create_instruction(
                                "STORE_FAST", argval=varname_map["list_to"]
                            )
                        ]
                    )

                    # list_from: the traced list contents.
                    cg(var._list_vt, allow_cache=False)  # Don't codegen via source
                    cg.extend_output(
                        [
                            create_instruction(
                                "STORE_FAST", argval=varname_map["list_from"]
                            )
                        ]
                    )

                    list_update_insts = bytecode_from_template(
                        _manual_list_update, varname_map=varname_map
                    )

                    suffixes.append(
                        [
                            *list_update_insts,
                            create_instruction("POP_TOP"),
                        ]
                    )
                    _maybe_log_side_effect(var._list_vt)

                # Applying mutations involves two steps: 1) Push all
                # reconstructed objects onto the stack. 2) Call STORE_ATTR to
                # apply the mutations.
                #
                # Dynamo must ensure that mutations are applied in the same
                # order as in the original program. Therefore, two reverse
                # operations occur below.
                #
                # The first reverse operation concerns `suffixes`. We apply
                # suffixes in reverse order due to the way Python handles the
                # stack. In Step 1, we push all reconstructed objects onto the
                # stack, but the item at the top of the stack refers to the last
                # attribute in the mutation order. If not fixed, this will apply
                # the mutations of attributes in the reverse order. To account
                # for this reversal, we iterate through the mutable attributes
                # in reverse order.
                #
                # The second reverse operation is the `reversed(suffixes)` loop
                # at the end of this method. Together, the two reversals restore
                # the original mutation order.
                side_effect_occurred = False
                for name, value in reversed(
                    self.store_attr_mutations.get(var, {}).items()
                ):
                    if isinstance(var, variables.NewGlobalVariable):
                        # Stack: [value] -> STORE_GLOBAL name
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        assert isinstance(var.source, GlobalSource)  # type: ignore[attr-defined]
                        suffixes.append(
                            [create_instruction("STORE_GLOBAL", argval=name)]
                        )
                        side_effect_occurred = True
                    elif isinstance(value, variables.DeletedVariable):
                        # Only emit `del obj.name` when the real object
                        # currently has that attribute.
                        if isinstance(
                            var.mutation_type, AttributeMutationExisting
                        ) and hasattr(getattr(var, "value", None), name):
                            cg.tx.output.update_co_names(name)
                            cg(var.source)
                            suffixes.append(
                                [create_instruction("DELETE_ATTR", argval=name)]
                            )
                            side_effect_occurred = True
                    elif isinstance(
                        var, variables.UserDefinedObjectVariable
                    ) and var.should_skip_descriptor_setter(name):
                        # object_setattr_ignore_descriptor(obj, name, value)
                        cg.add_push_null(
                            lambda: cg.load_import_from(
                                utils.__name__, "object_setattr_ignore_descriptor"
                            )
                        )
                        cg(var.source)  # type: ignore[attr-defined]
                        cg(variables.ConstantVariable(name))
                        cg(value)
                        suffixes.append(
                            [
                                *create_call_function(3, False),
                                create_instruction("POP_TOP"),
                            ]
                        )
                        side_effect_occurred = True
                    elif (
                        isinstance(var, variables.UserDefinedObjectVariable)
                        and var.needs_slow_setattr()
                    ):
                        # __setattr__ is defined on this object, so call object.__setattr__ directly
                        cg.load_import_from("builtins", "object")
                        cg.load_method("__setattr__")
                        cg(var.source)  # type: ignore[attr-defined]
                        cg(variables.ConstantVariable(name))
                        cg(value)
                        suffixes.append(
                            [*create_call_method(3), create_instruction("POP_TOP")]
                        )
                        side_effect_occurred = True
                    else:
                        # Stack: [value, obj] -> STORE_ATTR name (obj.name = value)
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        cg(var)
                        suffixes.append([create_instruction("STORE_ATTR", argval=name)])
                        side_effect_occurred = True
                if side_effect_occurred:
                    _maybe_log_side_effect(var)
            elif isinstance(var, variables.ListIteratorVariable):
                # Advance the real iterator `var.index` times via utils.iter_next,
                # discarding each result.
                for _ in range(var.index):
                    cg.add_push_null(
                        lambda: cg.load_import_from(utils.__name__, "iter_next")
                    )
                    cg(var.source)  # type: ignore[attr-defined]
                    cg.call_function(1, False)
                    cg.pop_top()
                _maybe_log_side_effect(var)
            elif isinstance(var, variables.RandomVariable):
                # set correct random seed state
                def gen_fn() -> None:
                    # Push `rng.setstate`.
                    cg(var.source)  # type: ignore[attr-defined]
                    cg.load_attr("setstate")

                cg.add_push_null(gen_fn)
                cg(var.wrap_state(var.random.getstate()))

                suffixes.append(
                    [
                        *create_call_function(1, False),  # setstate
                        create_instruction("POP_TOP"),
                    ]
                )
                _maybe_log_side_effect(var)
            else:
                raise AssertionError(type(var))

        # do all the actual mutations at the very end to handle dependencies
        for suffix in reversed(suffixes):
            cg.extend_output(suffix)

        # Send batched structured trace for all side effects in this compilation
        if log_side_effects and side_effect_messages:
            combined_msg = "\n\n========================================\n\n".join(
                side_effect_messages
            )
            torch._logging.trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "dynamo_side_effects",
                    "encoding": "string",
                },
                payload_fn=lambda: combined_msg,
            )
  1190. def is_empty(self) -> bool:
  1191. return not (
  1192. any(map(self.is_modified, self.id_to_variable.values()))
  1193. or self.tensor_hooks
  1194. or self.save_for_backward
  1195. or self.tensor_hooks
  1196. )
  1197. def clear(self) -> None:
  1198. self.keepalive.clear()
  1199. self.id_to_variable.clear()
  1200. @contextlib.contextmanager
  1201. def allow_side_effects_in_hop(
  1202. tx: "InstructionTranslatorBase",
  1203. ) -> Generator[None, None, None]:
  1204. """Context manager to temporarily allow side effects with extra outputs.
  1205. This is used for special cases (like FSDP functions) that need to perform
  1206. side effects even when the general policy is to disallow them.
  1207. """
  1208. orig_val = tx.output.current_tracer.allow_side_effects_in_hop
  1209. try:
  1210. tx.output.current_tracer.allow_side_effects_in_hop = True
  1211. yield
  1212. finally:
  1213. tx.output.current_tracer.allow_side_effects_in_hop = orig_val
  1214. @contextlib.contextmanager
  1215. def allow_externally_visible_side_effects_in_subtracer(
  1216. tx: "InstructionTranslatorBase",
  1217. ) -> Generator[None, None, None]:
  1218. orig_val = tx.output.current_tracer.unsafe_allow_externally_visible_side_effects
  1219. try:
  1220. tx.output.current_tracer.unsafe_allow_externally_visible_side_effects = True
  1221. tx.output.current_tracer.traced_with_externally_visible_side_effects = True
  1222. yield
  1223. finally:
  1224. tx.output.current_tracer.unsafe_allow_externally_visible_side_effects = orig_val
  1225. @contextlib.contextmanager
  1226. def disallow_side_effects_in_generator(
  1227. tx: "InstructionTranslatorBase",
  1228. ) -> Generator[None, None, None]:
  1229. orig_val = tx.output.current_tracer.is_reconstructing_generator
  1230. try:
  1231. tx.output.current_tracer.is_reconstructing_generator = True
  1232. yield
  1233. finally:
  1234. tx.output.current_tracer.is_reconstructing_generator = orig_val