codegen.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732
  1. """
  2. This module provides utilities for generating Python bytecode in PyTorch's Dynamo system.
  3. It includes functionality for:
  4. - Constructing bytecode sequences for Python operations
  5. - Managing stack operations and variable tracking
  6. - Handling graph outputs and their conversions
  7. - Supporting different Python versions (3.11+, 3.12+, 3.13+)
  8. - Converting high-level operations to low-level bytecode instructions
  9. - Managing constant loading and attribute access
  10. - Supporting function creation and closure handling
  11. """
  12. import collections
  13. import dataclasses
  14. import re
  15. import sys
  16. import types
  17. from collections import Counter, deque
  18. from collections.abc import Callable, Iterable
  19. from typing import Any, Optional, TYPE_CHECKING, Union
  20. import torch.nn
  21. from torch.utils._ordered_set import OrderedSet
  22. from . import config, graph_break_hints, utils
  23. from .bytecode_transformation import (
  24. add_push_null,
  25. add_push_null_call_function_ex,
  26. create_binary_subscr,
  27. create_build_tuple,
  28. create_call_function,
  29. create_call_function_ex,
  30. create_call_method,
  31. create_dup_top,
  32. create_instruction,
  33. create_load_const,
  34. create_load_method,
  35. create_rot_n,
  36. Instruction,
  37. )
  38. from .exc import unimplemented
  39. from .source import AttrSource, ChainedSource, DictGetItemSource, Source
  40. from .utils import is_safe_constant, rot_n_helper
  41. from .variables.base import ValueMutationExisting, VariableTracker
  42. from .variables.functions import (
  43. ContextlibContextManagerLocalGeneratorObjectVariable,
  44. LocalGeneratorObjectVariable,
  45. )
  46. from .variables.nn_module import NNModuleVariable
  47. from .variables.tensor import (
  48. NumpyNdarrayVariable,
  49. SymNodeVariable,
  50. TensorVariable,
  51. UnspecializedPythonVariable,
  52. )
  53. from .variables.torch_function import TensorWithTFOverrideVariable
  54. if TYPE_CHECKING:
  55. from torch._dynamo.variables.builder import GraphArg
  56. from .symbolic_convert import InstructionTranslatorBase
  @dataclasses.dataclass
  class GraphOutputEntry:
      """One value registered as a graph output by `PyCodegen.add_graph_output`."""

      # Position assigned at registration time (the number of entries that
      # were already registered); `PyCodegen.load_graph_output` subscripts the
      # graph output variable with it.
      index: int
      # The VariableTracker that was registered for this slot.
      variable: VariableTracker
  61. class PyCodegen:
  62. """
  63. Helper class used for constructing Python bytecode
  64. """
  def __init__(
      self,
      tx: "InstructionTranslatorBase",
      root: Optional[torch.nn.Module] = None,
      graph_output_var: Optional[str] = None,
      tempvars: Optional[dict[Union[VariableTracker, Source], Any]] = None,
      overridden_sources: Optional[dict[Source, Source]] = None,
  ) -> None:
      """Set up an empty instruction buffer bound to translator `tx`.

      Args:
          tx: translator whose `output.code_options`, `cell_and_freevars`
              and `output.new_var` are borrowed by this codegen.
          root: module loaded as a constant when an `NNModuleVariable` key
              does not start with a local variable name.
          graph_output_var: name of the local holding the graph outputs.
          tempvars: VariableTracker/Source -> temporary local name (or None
              if the name has not been assigned yet).
          overridden_sources: Source -> replacement Source used for codegen.
      """
      # Caller-provided configuration.
      self.tx = tx
      self.root = root
      self.graph_output_var = graph_output_var

      # Shortcuts into the translator / output graph.
      self.code_options = self.tx.output.code_options
      self.cell_and_freevars = self.tx.cell_and_freevars
      self.new_var = self.tx.output.new_var

      # Mutable codegen state.
      self._output: list[Instruction] = []
      self.top_of_stack: Optional[Union[VariableTracker, Source]] = None
      self.uses: Counter[Union[VariableTracker, Source]] = collections.Counter()
      self.graph_outputs: dict[int, GraphOutputEntry] = {}
      self.value_from_source: bool = True

      # This determines which VariableTracker/Source should be stored as
      # locals, and maps the VariableTracker/Source to the local variable
      # name. Note that it could map to None initially, in which case we'll
      # overwrite it to map to real temporary names via `add_cache`.
      # NOTE: a falsy (None or empty) argument is replaced by a fresh dict.
      self.tempvars: dict[Union[VariableTracker, Source], Any] = (
          tempvars if tempvars else {}
      )

      # This serves as a way for codegen to use a different source; we need
      # this because sometimes we can't easily modify the original source
      # without affecting other components, e.g., guards.
      self.overridden_sources: dict[Source, Source] = (
          overridden_sources if overridden_sources else {}
      )
  def restore_stack(
      self, stack_values: list[Any], *, value_from_source: bool = True
  ) -> None:
      """Codegen every entry of `stack_values` via `foreach`.

      While the values are generated, `self.value_from_source` is AND-ed with
      the `value_from_source` argument; the previous setting is put back
      afterwards, even when `foreach` raises.
      """
      saved = self.value_from_source
      self.value_from_source = saved & value_from_source
      try:
          self.foreach(stack_values)
      finally:
          self.value_from_source = saved
  def graph_output_vars(self) -> list[VariableTracker]:
      """Return the `variable` of every registered graph output, in dict order."""
      variables: list[VariableTracker] = []
      for entry in self.graph_outputs.values():
          variables.append(entry.variable)
      return variables
  def call_reconstruct(
      self, value: Union[VariableTracker, Source, "GraphArg"]
  ) -> None:
      """Invoke `value.reconstruct(self)` and check that it returned None."""
      result = value.reconstruct(self)
      assert result is None, f"reconstruct!=None {value}"
  def add_push_null(
      self, gen_fn: Callable[[], None], call_function_ex: bool = False
  ) -> None:
      """
      `gen_fn` generates instructions via PyCodegen methods
      that push a single callable to the stack.

      `add_push_null` pushes a NULL to the stack before or after the
      instructions generated by `gen_fn`, depending on Python version.

      Will attempt to use the NULL push bit for instructions
      with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR).

      When `call_function_ex` is true, the generated instructions are passed
      through `add_push_null_call_function_ex` instead of `add_push_null`.
      """
      before_313 = sys.version_info < (3, 13)
      start = len(self._output)

      if before_313:
          # gen_fn may DUP_TOP instead if TOS is not cleared.
          # Will cause problems since NULL will be pushed right
          # before the generated instructions in <= 3.12
          self.clear_tos()

      gen_fn()

      # Cut the freshly generated instructions out of self._output (in place)
      # and splice back the NULL-augmented sequence.
      generated = self._output[start:]
      del self._output[start:]
      if call_function_ex:
          wrapped = add_push_null_call_function_ex(generated)
      else:
          wrapped = add_push_null(generated)
      self._output.extend(wrapped)

      if not before_313:
          # NULL will be at top of stack
          self.clear_tos()
  def __call__(
      self, value: Union[VariableTracker, Source, None], allow_cache: bool = True
  ) -> None:
      """
      Generate code such that top-of-stack (TOS) is set to value.

      `allow_cache` controls the behavior in the following manner. `value` can
      either be a VariableTracker or a Source.

      If `value` is a `Source`, `allow_cache` must be True (invariant asserted
      below). If the source was reconstructed earlier, we will reuse the
      generated code by loading from top of stack or tempvars.

      If `value` is a `VariableTracker`, we have the following cases:

      1) `allow_cache=True`
          a) If the value.source is not None, we will emit the code based on
          `value.source` to handle aliasing.
          b) If value.source is None (example reconstructing a local list
          returned by the compiled function), we will reconstruct the variable
          tracker (w/o any source) to emit bytecode that generates a new
          python object.

          In both cases of value.source being None or not, if the value was
          reconstructed earlier, we will reuse the generated code by loading from
          top of stack or tempvars.

      2) `allow_cache=False` - This is a special case (allow_cache defaults to
      True).
          a) If the value.source is not None, we reconstruct the variable
          tracker and emit a new python object. You might wonder what about
          aliasing? The place where we use this config also has the followup
          code where the original python object is assigned to this new python
          value to handle aliasing (check side_effects.py and search for
          allow_cache=False).

          b) If value.source is None, this is not allowed

      Notable effects:
      1. `self.top_of_stack` will be set to `value`, if we don't codegen
         `value` based on source.
      2. `self.uses[value]` will increment, unless (a). we codegen via
         `top_of_stack` or cached `tempvars`, or (b). `value` has special VT
         types like `NNModuleVariable`, etc.
      """
      assert value is not None
      if isinstance(value, Source):
          # If the source needs to be overridden, use the new one.
          source = self.overridden_sources.get(value, value)
          assert allow_cache is True, "allow_cache must be True for Source"
          if self.top_of_stack is value:
              self._output.append(create_dup_top())
              return

          # A non-None tempvars entry is the name of a local that already
          # holds this source's value.
          if self.tempvars.get(source) is not None:
              self._output.append(self.create_load(self.tempvars[source]))
              self.top_of_stack = source
              return

          self.uses[source] += 1
          try:
              self.call_reconstruct(source)
          except NotImplementedError:
              unimplemented(
                  gb_type="Reconstruction failure: source.reconstruct not implemented",
                  context=str(source),
                  explanation=f"Dynamo has no bytecode reconstruction implemented for {type(source)} variable {source}.",
                  hints=[*graph_break_hints.DYNAMO_BUG],
              )
          # The source was marked as a temp but has no local yet: duplicate
          # the value and store one copy into a new temporary.
          if source in self.tempvars:
              self._output.append(create_dup_top())
              self.add_cache(source)
          self.top_of_stack = source

          return

      assert isinstance(value, VariableTracker)
      output = self._output
      graph_outputs = self.graph_outputs

      if allow_cache:
          if self.top_of_stack is value:
              output.append(create_dup_top())
              return

          if self.tempvars.get(value) is not None:
              output.append(self.create_load(self.tempvars[value]))
              self.top_of_stack = value
              return

      if value.is_realized() and isinstance(
          value, ContextlibContextManagerLocalGeneratorObjectVariable
      ):
          unimplemented(
              gb_type="reconstructing @contextmanager object",
              context=f"object: {value}",
              explanation="Returning a @contextmanager object from a compiled function is not supported.",
              hints=[
                  *graph_break_hints.SUPPORTABLE,
              ],
          )

      # Dynamo normally prefers codegen from source to account for aliasing.
      if (
          value.source is not None
          and allow_cache
          and not (
              value.is_realized() and isinstance(value, LocalGeneratorObjectVariable)
          )
      ):
          # There's a corner case for export: for instance, if the computation
          # graph is just identity on an input tensor, Dynamo would just emit
          # a `LOAD_FAST` from the input source, rather than generating an
          # identity FX graph.
          #
          # However, export wants to maximize graph capture; in the case
          # above, export _wants to_ obtain an identity FX graph (despite it
          # appears unnecessarily expensive for `torch.compile`), so we have
          # the following option to override Dynamo's preference for codegen
          # from source. Moreover, this option applies recursively, for cases
          # like input tensor being returned in a new dictionary.
          #
          # And why the `ValueMutationExisting` check? Not sure, so leaving it
          # to keep the old behavior, as when `value_from_source` was
          # introduced. TODO sort out the invariants among side effect,
          # codegen and export.
          if (
              isinstance(value.mutation_type, ValueMutationExisting)
              or self.value_from_source
          ):
              return self(value.source)

      if value.is_python_constant() and is_safe_constant(value.as_python_constant()):
          output.append(self.create_load_const(value.as_python_constant()))
      elif isinstance(value, TensorWithTFOverrideVariable):
          graph_outputs_key = self.add_graph_output(value)

          self.add_push_null(
              lambda: self.load_import_from(utils.__name__, "to_subclass")
          )
          self.load_graph_output(graph_outputs[graph_outputs_key].index)
          output.append(
              self.create_load_global(
                  value.global_mangled_class_name(self.tx),  # type: ignore[arg-type]
                  add=True,
              )
          )
          output.extend(create_call_function(2, False))
      elif (
          isinstance(value, SymNodeVariable)
          and value.python_type() is float
          and not self.tx.export
      ):
          # This is a little unusual; force the output convention to be a
          # Tensor here. Don't do this for export because this is
          # apparently load bearing for export tests (but I am a bit
          # doubtful it actually works in the real world)
          # NB: It works to add_graph_output on a computed expression
          # as_tensor here, because we memoize as_tensor calls on
          # SymNodeVariable!
          graph_outputs_key = self.add_graph_output(
              value.as_tensor(self.tx, torch.float64)
          )

          def gen_fn() -> None:
              self.load_graph_output(graph_outputs[graph_outputs_key].index)
              output.append(self.create_load_attr("item"))

          self.add_push_null(gen_fn)
          output.extend(create_call_function(0, False))
      elif isinstance(
          value,
          (
              TensorVariable,
              SymNodeVariable,
              UnspecializedPythonVariable,
              NumpyNdarrayVariable,
          ),
      ):
          graph_outputs_key = self.add_graph_output(value)

          if isinstance(value, NumpyNdarrayVariable):
              self.add_push_null(
                  lambda: self.load_import_from(utils.__name__, "to_numpy_helper")
              )
              self.load_graph_output(graph_outputs[graph_outputs_key].index)
              output.extend(create_call_function(1, False))
          elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:

              def gen_fn() -> None:
                  self.load_graph_output(graph_outputs[graph_outputs_key].index)
                  output.append(self.create_load_attr("item"))

              self.add_push_null(gen_fn)
              output.extend(create_call_function(0, False))
          else:
              self.load_graph_output(graph_outputs[graph_outputs_key].index)
      elif isinstance(value, NNModuleVariable):
          parts = value.module_key.split(".")
          # Start the attribute chain from a local when the first key
          # component names one; otherwise start from the root module.
          if parts[0] in self.code_options["co_varnames"]:
              output.append(self.create_load(parts[0]))
              parts = parts[1:]
          else:
              assert self.root is not None
              output.append(self.create_load_const_unchecked(self.root))
          for part in parts:
              output.append(self.create_load_attr(part))
      else:
          self.uses[value] += 1
          try:
              self.call_reconstruct(value)
          except NotImplementedError as e:
              unimplemented(
                  gb_type="Reconstruction failure",
                  context=str(value),
                  explanation=f"Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.",
                  hints=[
                      "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable "
                      "that Dynamo cannot reconstruct, then remove it from the return statement.",
                      *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK,
                      "Report an issue to PyTorch if you need reconstruction support. Note that objects that don't have "
                      "reconstruction rules may be fundamentally unreconstructable.",
                  ],
                  from_exc=e,
              )

          if allow_cache and value in self.tempvars:
              self._output.append(create_dup_top())
              self.add_cache(value)

      self.top_of_stack = value
  def add_graph_output(self, value: VariableTracker) -> int:
      """Register `value` as a graph output and return its lookup key.

      The key is `id(value.as_proxy())`. A key that is already registered is
      returned without creating a second entry; a new entry receives the next
      free index.
      """
      key = id(value.as_proxy())
      if key in self.graph_outputs:
          return key
      next_index = len(self.graph_outputs)
      self.graph_outputs[key] = GraphOutputEntry(next_index, value)
      return key
  def load_graph_output(self, index: int) -> None:
      """Emit `<graph_output_var>[index]`: load the local, load `index`, subscript."""
      emit = self._output.append
      assert self.graph_output_var is not None
      emit(self.create_load(self.graph_output_var))
      emit(self.create_load_const(index))
      emit(self.create_binary_subscr())
  def add_cache(self, value: Union[VariableTracker, Source]) -> None:
      """Allocate a new temp via `new_var`, map `value` to it, and emit a store."""
      temp_name = self.new_var()
      self.tempvars[value] = temp_name
      store_inst = self.create_store(temp_name)
      self._output.append(store_inst)
  def foreach(self, items: Iterable[Union[VariableTracker, Source]]) -> None:
      """Codegen each element of `items`, in order, by calling `self(item)`."""
      for item in items:
          self(item)
  def create_binary_subscr(self) -> Instruction:
      """Return the instruction built by the module-level `create_binary_subscr`."""
      subscr_inst = create_binary_subscr()
      return subscr_inst
  def setup_globally_cached(self, name: str, value: Any) -> list[Instruction]:
      """Store value in a new global

      Every run of characters in `name` outside `[a-zA-Z0-9_]` is replaced by
      a single underscore. If that global already exists it must be the very
      same object as `value` (asserted). Returns a one-element list holding
      the LOAD_GLOBAL for the sanitized name.
      """
      global_name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
      frame_globals = self.tx.f_globals
      if global_name not in frame_globals:
          frame_globals[global_name] = value
      else:
          assert frame_globals[global_name] is value
      return [self.create_load_global(global_name, add=True)]
  def clear_tos(self) -> None:
      """Forget which value is tracked as top-of-stack."""
      self.top_of_stack = None
  def append_output(self, inst: Instruction) -> None:
      """Append one `Instruction` to the buffer and reset the tracked TOS."""
      assert isinstance(inst, Instruction)
      buffer = self._output
      buffer.append(inst)
      self.clear_tos()
  def extend_output(self, insts: list[Instruction]) -> None:
      """Append several `Instruction`s to the buffer and reset the tracked TOS."""
      assert not any(not isinstance(inst, Instruction) for inst in insts)
      buffer = self._output
      buffer.extend(insts)
      self.clear_tos()
  def get_instructions(self) -> list[Instruction]:
      """Return the internal instruction list itself (not a copy)."""
      instructions = self._output
      return instructions
  def create_load(self, name: str) -> Instruction:
      """Build a LOAD_FAST for `name`, which must already be in `co_varnames`."""
      known_locals = self.code_options["co_varnames"]
      assert name in known_locals, f"{name} missing"
      return create_instruction("LOAD_FAST", argval=name)
  def create_load_closure(self, name: str) -> Instruction:
      """Build the instruction that loads the cell `name`.

      `name` must be one of `cell_and_freevars()`. The opcode is LOAD_FAST on
      Python 3.13+ and LOAD_CLOSURE on older versions.
      """
      assert name in self.cell_and_freevars()
      if sys.version_info >= (3, 13):
          opname = "LOAD_FAST"
      else:
          opname = "LOAD_CLOSURE"
      return create_instruction(opname, argval=name)
  def create_load_deref(self, name: str) -> Instruction:
      """Build a LOAD_DEREF for `name`, which must be in `cell_and_freevars()`."""
      closure_names = self.cell_and_freevars()
      assert name in closure_names
      return create_instruction("LOAD_DEREF", argval=name)
  def create_store(self, name: str) -> Instruction:
      """Build a STORE_FAST for `name`, which must already be in `co_varnames`."""
      known_locals = self.code_options["co_varnames"]
      assert name in known_locals, f"{name} missing"
      return create_instruction("STORE_FAST", argval=name)
  def create_store_deref(self, name: str) -> Instruction:
      """Build a STORE_DEREF for `name`, which must be in `cell_and_freevars()`."""
      closure_names = self.cell_and_freevars()
      assert name in closure_names
      return create_instruction("STORE_DEREF", argval=name)
  def create_load_global(self, name: str, add: bool = False) -> Instruction:
      """Build a LOAD_GLOBAL for `name`.

      With `add=True` the name is first passed to
      `tx.output.update_co_names`; in either case it must be present in
      `co_names` afterwards (asserted).
      """
      if add:
          self.tx.output.update_co_names(name)
      registered = self.code_options["co_names"]
      assert name in registered, f"{name} not in co_names"
      return create_instruction("LOAD_GLOBAL", argval=name)
  def create_load_const(self, value: Any) -> Instruction:
      """Return `create_load_const(value)` with its default checking."""
      const_inst = create_load_const(value)
      return const_inst
  def create_load_const_unchecked(self, value: Any) -> Instruction:
      """Return `create_load_const(value, checked=False)`."""
      const_inst = create_load_const(value, checked=False)
      return const_inst
  def load_method(self, name: str) -> None:
      """Register `name` in co_names, then emit the load-method instruction for it."""
      self.tx.output.update_co_names(name)
      method_inst = create_load_method(name)
      self.append_output(method_inst)
  def call_method(self, nargs: int) -> None:
      """Emit the instructions from `create_call_method(nargs)`."""
      call_insts = create_call_method(nargs)
      self.extend_output(call_insts)
  def create_load_attr(self, name: str) -> Instruction:
      """Build a LOAD_ATTR for `name`, adding it to `co_names` when missing."""
      already_known = name in self.code_options["co_names"]
      if not already_known:
          self.code_options["co_names"] += (name,)
      return create_instruction("LOAD_ATTR", argval=name)
  def load_attr(self, name: str) -> None:
      """Emit the instruction from `create_load_attr(name)`."""
      attr_inst = self.create_load_attr(name)
      self.append_output(attr_inst)
  def create_load_attrs(self, names: str) -> list[Instruction]:
      """Build one attribute load per component of the dotted path `names`."""
      components = names.split(".")
      return list(map(self.create_load_attr, components))
  def create_store_attr(self, name: str) -> Instruction:
      """Build a STORE_ATTR for `name`, adding it to `co_names` when missing."""
      already_known = name in self.code_options["co_names"]
      if not already_known:
          self.code_options["co_names"] += (name,)
      return create_instruction("STORE_ATTR", argval=name)
  def store_attr(self, name: str) -> None:
      """Emit the instruction from `create_store_attr(name)`."""
      attr_inst = self.create_store_attr(name)
      self.append_output(attr_inst)
  def load_function_name(
      self, fn_name: str, push_null: bool, num_on_stack: int = 0
  ) -> list[Instruction]:
      """Load the global fn_name on the stack num_on_stack down

      Returns the instructions instead of emitting them. When `push_null` is
      set on Python 3.11+, the load goes through `add_push_null` and, if
      `num_on_stack > 0`, is followed by two rotations of
      `num_on_stack + 2`. Otherwise a plain LOAD_GLOBAL is followed by one
      rotation of `num_on_stack + 1`.
      """
      insts: list[Instruction] = []
      with_null = push_null and sys.version_info >= (3, 11)

      if not with_null:
          insts.append(self.create_load_global(fn_name, add=True))
          insts.extend(self.rot_n(num_on_stack + 1))
          return insts

      insts.extend(add_push_null(self.create_load_global(fn_name, add=True)))
      if num_on_stack > 0:
          first_rotation = self.rot_n(num_on_stack + 2)
          second_rotation = self.rot_n(num_on_stack + 2)
          insts.extend(first_rotation)
          insts.extend(second_rotation)
      return insts
  def rot_n(self, n: int) -> list[Instruction]:
      """Return instructions that rotate the top `n` stack entries.

      Uses `create_rot_n(n)` directly; if that raises AttributeError, an
      equivalent sequence is built from a tuple pack, a helper call and an
      unpack.
      """
      try:
          return create_rot_n(n)
      except AttributeError:
          # desired rotate bytecode doesn't exist, generate equivalent bytecode
          fallback: list[Instruction] = [create_build_tuple(n)]
          fallback.append(self.create_load_const_unchecked(rot_n_helper(n)))
          fallback.extend(create_rot_n(2))
          fallback.extend(create_call_function_ex(False, False))
          fallback.append(create_instruction("UNPACK_SEQUENCE", arg=n))
          return fallback
  def pop_null(self) -> list[Instruction]:
      """Return instructions that remove a NULL from the top of the stack.

      Requires Python 3.11+ (asserted).
      """
      # POP_TOP doesn't work for null, so we pop nulls by pushing in a
      # nop function, calling it (which consumes the null), and popping the result.
      assert sys.version_info >= (3, 11)
      insts: list[Instruction] = [self.create_load_const_unchecked(lambda: None)]
      if sys.version_info >= (3, 13):
          # 3.13 swapped NULL and callable
          insts.append(create_instruction("SWAP", arg=2))
      insts.extend(create_call_function(0, False))
      insts.append(create_instruction("POP_TOP"))
      return insts
  def pop_top(self) -> None:
      """Emit a POP_TOP instruction."""
      pop_inst = create_instruction("POP_TOP")
      self.append_output(pop_inst)
  def call_function(self, nargs: int, push_null: bool) -> None:
      """Emit the instructions from `create_call_function(nargs, push_null=...)`."""
      call_insts = create_call_function(nargs, push_null=push_null)
      self.extend_output(call_insts)
  def dup_top(self) -> None:
      """Emit the instruction from `create_dup_top()`."""
      dup_inst = create_dup_top()
      self.append_output(dup_inst)
  def store(self, varname: str) -> None:
      """Emit the instruction from `create_store(varname)`."""
      store_inst = self.create_store(varname)
      self.append_output(store_inst)
  def load_deref(self, varname: str) -> None:
      """Emit the instruction from `create_load_deref(varname)`."""
      deref_inst = self.create_load_deref(varname)
      self.append_output(deref_inst)
  def make_function_with_closure(
      self,
      fn_name: str,
      code: types.CodeType,
  ) -> None:
      """Creates a closure with code object `code`.

      Expects the TOS to be the tuple of cells to use for this closure.
      TOS will be popped to create the closure.

      Args:
          - fn_name: name of the function
          - code: code object of the function

      (does not include the tuple of cells on the TOS)
      """
      buffer = self._output
      buffer.append(self.create_load_const(code))

      # Only pre-3.11 loads the function name as a separate constant.
      if sys.version_info < (3, 11):
          buffer.append(self.create_load_const(fn_name))

      if sys.version_info < (3, 13):
          buffer.append(create_instruction("MAKE_FUNCTION", arg=0x08))
      else:
          # 3.13+ uses a bare MAKE_FUNCTION followed by
          # SET_FUNCTION_ATTRIBUTE with the same 0x08 flag.
          buffer.extend(
              [
                  create_instruction("MAKE_FUNCTION"),
                  create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
              ]
          )
      self.clear_tos()
  def create_load_python_module(self, mod: types.ModuleType) -> Instruction:
      """
      Generate a LOAD_GLOBAL instruction to fetch a given python module.

      If the global scope already binds the module's last dotted name
      component to `mod`, that name is loaded directly. Otherwise the module
      is installed via `install_global_by_id` under a `___module_<name>`
      prefix and the returned name is loaded.
      """
      graph_output = self.tx.output
      short_name = re.sub(r"^.*[.]", "", mod.__name__)
      if graph_output.global_scope.get(short_name, None) is mod:
          return self.create_load_global(short_name, add=True)
      installed_name = self.tx.output.install_global_by_id(
          f"___module_{short_name}", mod
      )
      return self.create_load_global(installed_name, add=True)
  def mark_source_temp(self, source: Source) -> None:
      """
      Mark a source as a temp variable, so that it can be reused.

      A source with no entry yet is mapped to None; an existing entry is
      left untouched.
      """
      self.tempvars.setdefault(source, None)
  def make_call_generated_code(self, fn_name: str) -> None:
      """Call the generated code function stored in fn_name

      Emits, in order: the load of `fn_name` (with NULL), the reconstruction
      of every graph argument, and a call with `len(graphargs)` arguments.
      When `config.record_runtime_overhead` is set, the argument
      reconstruction is bracketed by calls to the
      `record_pregraph_bytecode_enter` / `record_pregraph_bytecode_exit`
      helpers.
      """
      self.extend_output(self.load_function_name(fn_name, True))

      graphargs = self.tx.output.graphargs

      def extract_nested_sources(source: Source) -> list[Source]:
          # Sources directly referenced by `source`: its base (for chained
          # sources) and its index (for dict-getitem sources indexed by a
          # Source).
          children: list[Source] = []
          if isinstance(source, ChainedSource):
              children.append(source.base)
          if isinstance(source, DictGetItemSource) and isinstance(
              source.index, Source
          ):
              children.append(source.index)
          return children

      def collect_temp_sources(sources: deque[Source], codegen: PyCodegen) -> None:
          visited: OrderedSet[Source] = OrderedSet()
          while sources:
              candidate = sources.popleft()
              if candidate not in visited:
                  visited.add(candidate)
                  sources.extend(extract_nested_sources(candidate))
              else:
                  # This source is used at least twice, so it can be reused.
                  # Don't trace the source further. This prevents us from
                  # marking too many nodes as temp sources.
                  codegen.mark_source_temp(candidate)

      # Collect all the sources that are used more than once, so that we can
      # generate tmp variables in the generated pre-graph bytecode. This
      # essentially implements CSE.
      initial_sources = deque(
          [arg.source for arg in graphargs if arg.source is not None]
      )
      collect_temp_sources(initial_sources, self)

      def load_enter_hook() -> None:
          self.load_import_from(utils.__name__, "record_pregraph_bytecode_enter")

      def load_exit_hook() -> None:
          self.load_import_from(utils.__name__, "record_pregraph_bytecode_exit")

      def load_as_tensor_fullprec() -> None:
          self.extend_output(
              [
                  self.create_load_python_module(torch),
                  self.create_load_attr("_as_tensor_fullprec"),
              ]
          )

      cm_var = None
      if config.record_runtime_overhead:
          # Record the pregraph bytecode start
          self.add_push_null(load_enter_hook)
          self.extend_output(create_call_function(0, False))
          cm_var = self.new_var()
          self.store(cm_var)

      for arg in graphargs:
          if not arg.pass_arg_as_tensor:
              self.call_reconstruct(arg)
              continue
          # Wrap the reconstructed argument in torch._as_tensor_fullprec(...).
          self.add_push_null(load_as_tensor_fullprec)
          self.call_reconstruct(arg)
          self.extend_output(create_call_function(1, False))

      if config.record_runtime_overhead:
          # Record the pregraph bytecode end
          self.add_push_null(load_exit_hook)
          assert cm_var is not None
          self.extend_output([self.create_load(cm_var)])
          self.extend_output(create_call_function(1, False))
          self.pop_top()

      self.extend_output(create_call_function(len(graphargs), False))
  def create_import_name(self, module_name: str) -> Instruction:
      """Build an IMPORT_NAME instruction for `module_name`."""
      import_inst = create_instruction("IMPORT_NAME", argval=module_name)
      return import_inst
  def load_import_from(self, module_name: str, object_name: str) -> None:
      """Codegen a load of `object_name` from the module `module_name`.

      The load is expressed as an `AttrSource` on the translator's import
      source for the module, marked as a temp, and then generated.
      """
      module_source = self.tx.import_source(module_name)
      attr_source = AttrSource(module_source, object_name)
      # Note: This approach is somewhat aggressive because typically, a source is marked
      # as a tempvar only when it is used more than once. In this case, we're marking it
      # as a tempvar without performing that analysis. However, this is a simple solution,
      # and in many cases, load imports are reused multiple times.
      self.mark_source_temp(attr_source)
      self(attr_source)
  def create_call_function_kw(
      self, nargs: int, kw_names: Iterable[str], push_null: bool
  ) -> list[Instruction]:
      """Return the instructions for a call with keyword arguments.

      The shape depends on the running Python version:
      - 3.13+: the instructions from `create_call_function`, with a constant
        load of `kw_names` inserted before the final CALL, which is replaced
        by CALL_KW.
      - 3.11 / 3.12: the instructions from `create_call_function`, with a
        KW_NAMES inserted before CALL (3.12) or before PRECALL (3.11).
      - older: a constant load of `kw_names` followed by CALL_FUNCTION_KW.
      """
      version = sys.version_info

      if version >= (3, 13):
          insts = create_call_function(nargs, push_null)
          assert insts[-1].opname == "CALL"
          insts.insert(-1, self.create_load_const(kw_names))
          insts[-1] = create_instruction("CALL_KW", arg=nargs)
          return insts

      if version >= (3, 11):
          insts = create_call_function(nargs, push_null)
          if version >= (3, 12):
              position, expected_opname = -1, "CALL"
          else:
              position, expected_opname = -2, "PRECALL"
          assert insts[position].opname == expected_opname
          insts.insert(position, create_instruction("KW_NAMES", argval=kw_names))
          return insts

      return [
          self.create_load_const(kw_names),
          create_instruction("CALL_FUNCTION_KW", arg=nargs),
      ]
  def create_delete(self, value: object) -> Instruction:
      """Build a DELETE_FAST instruction whose argval is `value`."""
      delete_inst = create_instruction("DELETE_FAST", argval=value)
      return delete_inst