eval_frame.py 100 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593
  1. # mypy: disable-error-code="method-assign"
  2. """
  3. This module implements the core frame evaluation handler for TorchDynamo's compilation system.
  4. The eval frame handler intercepts Python bytecode execution at runtime to enable dynamic
  5. compilation and optimization of PyTorch code.
  6. Key components defined here:
  7. - Frame evaluation handlers that intercept and analyze Python execution frames
  8. - Guards management for tracking dependencies and invalidating compiled code
  9. - Optimization contexts and decorators (optimize, run_once, disable, etc.)
  10. - Export functionality for saving optimized graphs
  11. - Backend compiler integrations and callback management
  12. Functions in this file are responsible for modifying the eval frame handler at RUNTIME.
  13. Therefore, all functions in this file are hot and performance-critical. Functions that
  14. only execute at compile time should be placed in torch._dynamo.convert_frame.
  15. The eval frame handler is the core mechanism that enables TorchDynamo to dynamically
  16. intercept, analyze and optimize PyTorch code during execution. It works by registering
  17. a custom frame evaluation function that gets called for every Python frame, allowing
  18. us to detect PyTorch operations and trigger compilation as needed.
  19. """
  20. from __future__ import annotations
  21. import atexit
  22. import contextlib
  23. import functools
  24. import inspect
  25. import logging
  26. import os
  27. import sys
  28. import sysconfig
  29. import textwrap
  30. import threading
  31. import traceback
  32. import types
  33. import unittest
  34. import warnings
  35. import weakref
  36. from collections.abc import Generator, Sized
  37. from dataclasses import dataclass
  38. from enum import Enum
  39. from os.path import dirname, join
  40. from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union
  41. from unittest.mock import patch
  42. import sympy
  43. import torch
  44. import torch.fx
  45. import torch.utils._pytree as pytree
  46. import torch.utils.checkpoint
  47. from torch import _guards
  48. # see discussion at https://github.com/pytorch/pytorch/issues/120699
  49. from torch._C._dynamo.eval_frame import ( # noqa: F401
  50. _EvalFrameOverride,
  51. reset_code,
  52. set_code_exec_strategy,
  53. set_eval_frame,
  54. set_eval_frame_override,
  55. set_guard_complete_hook,
  56. set_guard_error_hook,
  57. set_skip_guard_eval_unsafe,
  58. unsupported,
  59. )
  60. from torch._dispatch.python import enable_python_dispatcher
  61. from torch._dynamo.types import ConvertFrameReturn, FrameAction, FrameExecStrategy
  62. from torch._export.utils import _compiling_state_context
  63. from torch._library.opaque_object import is_opaque_type
  64. from torch._subclasses.fake_tensor import unset_fake_temporarily
  65. from torch._utils_internal import DISABLE_JUSTKNOBS, justknobs_check, log_export_usage
  66. from torch.export.dynamic_shapes import (
  67. _combine_args,
  68. _DimHint,
  69. _DimHintType,
  70. _IntWrapper,
  71. _process_dynamic_shapes,
  72. _RelaxedConstraint,
  73. Constraint,
  74. )
  75. from torch.fx import GraphModule, traceback as fx_traceback
  76. from torch.fx.experimental._dynamism import (
  77. clone_and_convert_to_meta,
  78. track_dynamism_across_examples,
  79. )
  80. from torch.fx.experimental.proxy_tensor import make_fx
  81. from torch.fx.experimental.symbolic_shapes import (
  82. ConstraintViolationError,
  83. DimDynamic,
  84. ShapeEnv,
  85. StatelessSymbolicContext,
  86. )
  87. from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
  88. from . import config, convert_frame, distributed, external_utils, trace_rules, utils
  89. from .backends.registry import CompilerFn, lookup_backend
  90. from .code_context import code_context
  91. from .exc import (
  92. CondOpArgsMismatchError,
  93. ShortenTraceback,
  94. UncapturedHigherOrderOpError,
  95. Unsupported,
  96. UserError,
  97. UserErrorType,
  98. )
  99. from .hooks import Hooks
  100. from .mutation_guard import install_generation_tagging_init
  101. from .utils import (
  102. _get_error_on_graph_break,
  103. _set_error_on_graph_break,
  104. common_constant_types,
  105. compile_times,
  106. )
  107. if TYPE_CHECKING:
  108. from collections.abc import Callable, Iterable, Sequence
  109. from torch._dynamo.package import CompilePackage
  110. from torch._dynamo.repro.after_dynamo import WrapBackendDebug
  111. from torch._subclasses import fake_tensor
  112. from torch.fx.node import Argument, Node, Target
  113. from .types import (
  114. CacheEntry,
  115. DynamoCallback,
  116. DynamoFrameType,
  117. GuardFail,
  118. GuardFilterEntry,
  119. )
  120. log = logging.getLogger(__name__)
  121. always_optimize_code_objects = utils.ExactWeakKeyDictionary()
  122. null_context = contextlib.nullcontext
  123. # See https://github.com/python/typing/pull/240
  124. class Unset(Enum):
  125. token = 0
  126. cached_backends: dict[int, CompilerFn] = {}
  127. unset = Unset.token
  128. _in_optimized_module = False
  129. if DISABLE_JUSTKNOBS:
  130. _maybe_set_eval_frame = set_eval_frame
  131. else:
  132. def _maybe_set_eval_frame(callback: DynamoCallback) -> DynamoCallback:
  133. # A wrapper on set_eval_frame that is guarded by a Justknob.
  134. # Users can disable torchDynamo by setting the JK to False.
  135. if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"):
  136. torch._dynamo.utils.warn_once(
  137. "Dynamo disabled by Justknob: enable_compiler_set_eval_frame, skipping set_eval_frame"
  138. )
  139. return callback
  140. else:
  141. return set_eval_frame(callback)
  142. @dataclass
  143. class DynamoStance:
  144. stance: str = "default"
  145. skip_guard_eval_unsafe: bool = False
  146. backend: Union[str, Callable[..., Any], None] = None
  147. _stance = DynamoStance()
  148. def _set_stance(stance: DynamoStance) -> DynamoStance:
  149. global _stance
  150. from torch._C._dynamo.eval_frame import get_eval_frame_callback
  151. callback = get_eval_frame_callback()
  152. if callback is not False and callback is not None:
  153. raise RuntimeError("attempted to set_stance in a torch.compile region")
  154. prior = _stance
  155. _stance = stance
  156. return prior
  157. _set_stance._dynamo_forbidden = True # type: ignore[attr-defined]
  158. _EXAMPLE_INPUTS: Optional[dict[str, list[Any]]] = None
  159. def get_example_inputs(key: str) -> list[Any]:
  160. global _EXAMPLE_INPUTS
  161. if _EXAMPLE_INPUTS is None:
  162. _EXAMPLE_INPUTS = {}
  163. if key not in _EXAMPLE_INPUTS:
  164. _EXAMPLE_INPUTS[key] = []
  165. return _EXAMPLE_INPUTS[key]
  166. @contextlib.contextmanager
  167. def _set_in_optimized_module() -> Generator[None, None, None]:
  168. # Set in dynamo's OptimizedModule forward, to have better coverage than is_compiling().
  169. # Prevents graph-breaking forward hooks from being registered & traced.
  170. # TODO(pianpwk): subsume this flag with better is_compiling() coverage
  171. global _in_optimized_module
  172. _old_in_optimized_module = (
  173. _in_optimized_module # do we need this? can we just set it to False after
  174. )
  175. _in_optimized_module = True
  176. try:
  177. yield
  178. finally:
  179. _in_optimized_module = _old_in_optimized_module
  180. def _is_in_optimized_module() -> bool:
  181. global _in_optimized_module
  182. return _in_optimized_module
  183. def _callback_from_stance(callback: DynamoCallback) -> DynamoCallback:
  184. if _stance.stance == "default":
  185. # force_backend
  186. if _stance.backend is not None and callback not in (False, None):
  187. # pyrefly: ignore [bad-argument-type]
  188. callback = _create_wrapped_callback(get_compiler_fn(_stance.backend))
  189. return callback
  190. elif _stance.stance == "eager_then_compile":
  191. if callback not in (False, None):
  192. return _create_delayed_compile_callback(callback, _stance.stance)
  193. return callback
  194. elif _stance.stance == "aot_eager_then_compile":
  195. if callback not in (False, None):
  196. return _create_delayed_compile_callback(callback, _stance.stance)
  197. return callback
  198. elif _stance.stance == "force_eager":
  199. # disable
  200. return None
  201. elif _stance.stance == "eager_on_recompile":
  202. # run mode
  203. return False
  204. elif _stance.stance == "fail_on_recompile":
  205. if callback in (False, None):
  206. return callback
  207. def fail_callback(
  208. frame: DynamoFrameType, *args: Any, **kwargs: Any
  209. ) -> ConvertFrameReturn:
  210. if trace_rules.check(frame.f_code):
  211. return ConvertFrameReturn()
  212. if not convert_frame.has_tensor_in_frame(frame):
  213. return ConvertFrameReturn()
  214. from torch._C._dynamo.eval_frame import (
  215. _debug_get_cache_entry_list,
  216. _debug_get_precompile_entries,
  217. )
  218. from torch._dynamo.guards import get_and_maybe_log_recompilation_reasons
  219. message = (
  220. "Detected recompile when torch.compile stance is 'fail_on_recompile'. "
  221. + f"filename: '{frame.f_code.co_filename}', "
  222. + f"function name: '{frame.f_code.co_name}', "
  223. + f"line number: {frame.f_lineno}"
  224. )
  225. cache_entries = _debug_get_cache_entry_list(frame.f_code)
  226. if cache_entries:
  227. reasons = get_and_maybe_log_recompilation_reasons(
  228. cache_entries[0], frame, innermost_fn(callback), skip_logging=True
  229. )
  230. if reasons:
  231. failures = textwrap.indent("\n".join(reasons), "- ")
  232. guard_failure_details = (
  233. f"triggered by the following guard failure(s):\n{failures}"
  234. )
  235. message += f"\n{textwrap.indent(guard_failure_details, ' ')}"
  236. precompile_entries = _debug_get_precompile_entries(frame.f_code)
  237. if len(precompile_entries) > 0:
  238. message += "\nFailed on the following precompiled guards: "
  239. for entry in precompile_entries:
  240. message += f"\n{entry.guard_manager}{entry.guard_manager.check_verbose(frame.f_locals)}" # type: ignore[attr-defined]
  241. raise RuntimeError(message)
  242. # to prevent cache miss due to different backend
  243. fail_callback._torchdynamo_orig_backend = callback # type: ignore[attr-defined]
  244. return fail_callback
  245. else:
  246. raise RuntimeError(f"invalid torch.compile stance '{_stance}'")
  247. def _create_wrapped_callback(
  248. compiler_fn: CompilerFn,
  249. ) -> convert_frame.CatchErrorsWrapper:
  250. hooks = Hooks()
  251. return convert_frame.catch_errors_wrapper(
  252. convert_frame.convert_frame( # type: ignore[arg-type]
  253. compiler_fn,
  254. hooks,
  255. ),
  256. hooks,
  257. )
  258. def _get_or_add_example_inputs(frame: DynamoFrameType) -> list[Any]:
  259. key = frame.f_code.co_filename + str(frame.f_code.co_firstlineno)
  260. example_inputs = get_example_inputs(key)
  261. if len(example_inputs) < 2:
  262. example_inputs.append(clone_and_convert_to_meta(frame.f_locals))
  263. return example_inputs
  264. def _create_delayed_compile_callback(
  265. callback: DynamoCallback, stance: str
  266. ) -> Callable[..., Any]:
  267. def callback_fn(*args: Any, **kwargs: Any) -> convert_frame.ConvertFrameReturn:
  268. frame = args[0]
  269. example_inputs = _get_or_add_example_inputs(frame)
  270. if len(example_inputs) == 1:
  271. if stance == "eager_then_compile":
  272. return ConvertFrameReturn(
  273. frame_exec_strategy=FrameExecStrategy(
  274. FrameAction.DEFAULT, FrameAction.DEFAULT
  275. )
  276. )
  277. elif stance == "aot_eager_then_compile":
  278. aot_eager_fn = get_compiler_fn("aot_eager")
  279. # pyrefly: ignore [bad-argument-type]
  280. return _create_wrapped_callback(aot_eager_fn)(*args, **kwargs)
  281. dynamism = track_dynamism_across_examples(example_inputs)
  282. code_context.get_context(frame.f_code)["dynamism"] = dynamism
  283. compiler_fn = callback._torchdynamo_orig_backend._torchdynamo_orig_backend # type: ignore[union-attr]
  284. return _create_wrapped_callback(compiler_fn)(*args, **kwargs)
  285. # to prevent cache miss due to different backend
  286. callback_fn._torchdynamo_orig_backend = callback # type: ignore[attr-defined]
  287. return callback_fn
  288. def _is_skip_guard_eval_unsafe_stance() -> bool:
  289. return _stance.skip_guard_eval_unsafe
  290. def _reset_guarded_backend_cache() -> None:
  291. global cached_backends
  292. for backend in cached_backends.values():
  293. if hasattr(backend, "reset"):
  294. backend.reset()
  295. cached_backends.clear()
  296. DONT_WRAP_FILES = {
  297. # For tracing into fx modules
  298. inspect.getsourcefile(GraphModule),
  299. join(dirname(dirname(__file__)), "onnx/_internal/fx/dynamo_graph_extractor.py"),
  300. }
  301. def _debug_get_cache_entry_list(
  302. code: Union[types.CodeType, Callable[..., Any]],
  303. ) -> list[CacheEntry]:
  304. """
  305. Given a code object or a callable object, retrieve the cache entries
  306. stored in this code.
  307. """
  308. if callable(code):
  309. code = code.__code__
  310. return torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)
class OptimizedModule(torch.nn.Module):
    """
    Wraps the original nn.Module object and later patches its
    forward method to optimized self.forward method.

    Attribute access is proxied to the wrapped module (``_orig_mod``):
    reads that fail normal lookup go through ``__getattr__`` to it, and
    writes/deletes are forwarded to it unless the name exists on this class
    or is listed in ``_opt_mod_attributes``. ``training`` is a property that
    mirrors ``_orig_mod.training``.
    """

    # Declared for type checkers; no value is assigned in this class body.
    _torchdynamo_orig_callable: Callable[..., Any]
    get_compiler_config: Callable[[], Any]

    # Names stored on the wrapper itself instead of being forwarded to
    # `_orig_mod` by __setattr__ / __delattr__.
    _opt_mod_attributes = {
        "_orig_mod",
        "dynamo_ctx",
        "_torchdynamo_orig_callable",
        "_torchdynamo_wrapper_id",
        "get_compiler_config",
        "forward",
        "_forward",
        "__dict__",
        "named_children_walk",
        "_super_module_initialized",
    }

    def __init__(self, mod: torch.nn.Module, dynamo_ctx: _TorchDynamoContext) -> None:
        """Wrap ``mod`` and build the optimized ``forward`` from ``dynamo_ctx``.

        Args:
            mod: the module being wrapped; kept as ``self._orig_mod``.
            dynamo_ctx: stored as ``self.dynamo_ctx`` and applied to the
                module's call in ``_initialize``.
        """
        # NOTE: this must go first, because attribute reads/writes of `self`
        # uses `_orig_mod`, and sometimes users override `Module.__init__` to
        # do attribute reads/writes on `self`.
        #
        # We also can't use regular setattr because `super().__setattr__` will
        # complain for module value before `super().__init__()`
        object.__setattr__(self, "_orig_mod", mod)
        # While False, the `training` setter ignores writes (nn.Module.__init__
        # assigns a default `training`).
        self._super_module_initialized = False
        super().__init__()
        self._super_module_initialized = True
        # Installs the params/buffer
        self._orig_mod = mod  # `super().__setattr__` will register this module
        self.dynamo_ctx = dynamo_ctx
        self._initialize()
        # Routed through the `training` property setter, so this writes the
        # wrapped module's own value back to it.
        self.training = self._orig_mod.training

    def __len__(self) -> int:
        """Return ``len(self._orig_mod)``; raise TypeError if it is not Sized."""
        # Proxy the len call to the original module
        # pyrefly: ignore [invalid-argument, unsafe-overlap]
        if isinstance(self._orig_mod, Sized):
            return len(self._orig_mod)
        # Mimic python's default behavior for objects without a length
        raise TypeError(f"{type(self._orig_mod).__name__} does not support len()")

    def _initialize(self) -> None:
        """Set ``self.forward`` to a ``dynamo_ctx``-wrapped callable.

        - ``DisableContext``: wraps ``_orig_mod.__call__``.
        - ``config.wrap_top_frame``, or ``_orig_mod.forward`` is a bound method
          that ``trace_rules.check`` flags: wraps
          ``external_utils.wrap_inline(_orig_mod)``.
        - otherwise: wraps ``_orig_mod.__call__``.

        If ``_orig_mod`` has ``_initialize_hook`` (a lazy module), the wrapped
        callable is kept in ``_forward`` and ``forward`` becomes
        ``_call_lazy_check``. Also called from ``__setstate__``.
        """
        # Do this stuff in constructor to lower overhead slightly
        if isinstance(self.dynamo_ctx, DisableContext):
            # No need to check trace rules
            # pyrefly: ignore [bad-argument-type]
            self.forward = self.dynamo_ctx(self._orig_mod.__call__)
        elif config.wrap_top_frame or (
            isinstance(self._orig_mod.forward, types.MethodType)
            and (trace_rules.check(self._orig_mod.forward))
        ):
            # This may be a torch.nn.* instance in trace_rules.py which
            # won't trigger a frame evaluation workaround to add an extra
            # frame we can capture
            # pyrefly: ignore [bad-argument-type]
            self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod))
        else:
            # Invoke hooks outside of dynamo then pickup the inner frame
            self.forward = self.dynamo_ctx(self._orig_mod.__call__)
        if hasattr(self._orig_mod, "_initialize_hook"):
            self._forward = self.forward
            self.forward = self._call_lazy_check

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``nn.Module.__call__`` with the optimized-module flag set.

        Emits a warning first if any global module hooks are registered,
        because they will also fire for this wrapper.
        """
        if torch.nn.modules.module._has_any_global_hook():
            warnings.warn(
                "Using `torch.compile(module)` when there are global hooks on "
                "modules (e.g., from `register_module_forward_hook`); this will"
                " cause the hooks to fire an extra time for the "
                "`OptimizedModule` created by `torch.compile(module)`. If this "
                "causes undesired behavior, please try using `module.compile()`"
                ", or use the per-module hooks instead",
                stacklevel=2,
            )
        with _set_in_optimized_module():
            return super().__call__(*args, **kwargs)

    def _aot_compile(self, inputs: list[torch._dynamo.aot_compile.ModelInput]) -> None:
        """
        Experimental: AOT Compile a set of inputs and use that as the forward function

        Raises:
            AssertionError: if ``dynamo_ctx._hooks`` is None (checked before
                the config checks below).
            RuntimeError: if ``config.enable_aot_compile`` is off, the context
                is not fullgraph, or its callback is not callable.
        """
        model = self._orig_mod
        hooks = self.dynamo_ctx._hooks
        assert hooks is not None
        if not config.enable_aot_compile:
            raise RuntimeError(
                "AOT Compile is not enabled, please set torch._dynamo.config.enable_aot_compile=True"
            )
        if not self.dynamo_ctx.fullgraph:
            raise RuntimeError(
                "Graph breaks are not supported with aot compile. Please use torch.compile(fullgraph=True)."
            )
        if not callable(self.dynamo_ctx.callback):
            raise RuntimeError("aot compile requires a callable dynamo callback.")
        backend = innermost_backend(self.dynamo_ctx.callback)
        from torch._dynamo.aot_compile import aot_compile_module

        self.forward = aot_compile_module(model, inputs, hooks, backend)

    def _save_aot_compiled_module(self, path: Optional[str] = None) -> bytes:
        """Serialize the AOT-compiled ``forward``; also write it to ``path`` if given.

        Returns the serialized bytes. Requires ``forward`` to be an
        ``AOTCompiledModel`` (asserted) and ``config.enable_aot_compile``
        (RuntimeError otherwise).
        """
        if not config.enable_aot_compile:
            raise RuntimeError(
                "AOT Compile is not enabled, please set torch._dynamo.config.enable_aot_compile=True"
            )
        from torch._dynamo.aot_compile import AOTCompiledModel

        assert isinstance(self.forward, AOTCompiledModel)
        result: bytes = self.forward.serialize()
        if path is not None:
            with open(path, "wb") as f:
                f.write(result)
        return result

    def _load_aot_compiled_module(self, data: bytes) -> None:
        """Deserialize ``data`` against ``_orig_mod`` and install it as ``forward``."""
        if not config.enable_aot_compile:
            raise RuntimeError(
                "AOT Compile is not enabled, please set torch._dynamo.config.enable_aot_compile=True"
            )
        from torch._dynamo.aot_compile import AOTCompiledModel

        compiled_forward = AOTCompiledModel.deserialize(self._orig_mod, data)
        assert isinstance(compiled_forward, AOTCompiledModel)
        self.forward = compiled_forward

    def __reduce__(
        self,
    ) -> tuple[type[OptimizedModule], tuple[torch.nn.Module, _TorchDynamoContext]]:
        """Pickle as a re-construction from ``(_orig_mod, dynamo_ctx)``."""
        # NOTE(review): no state item is returned here, so pickle presumably
        # never calls __getstate__/__setstate__ below — confirm who uses them.
        return (self.__class__, (self._orig_mod, self.dynamo_ctx))

    def __getstate__(self) -> dict[str, Any]:
        """Return a copy of ``__dict__`` without ``forward``/``__call__``."""
        state = dict(self.__dict__)
        state.pop("forward", None)
        state.pop("__call__", None)
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore ``__dict__`` and rebuild ``forward`` via ``_initialize``."""
        self.__dict__ = state
        self._initialize()

    @property
    # pyrefly: ignore [bad-override]
    def training(self) -> bool:
        """Mirror of ``self._orig_mod.training``."""
        return self._orig_mod.training

    @training.setter
    def training(self, value: bool) -> None:
        # Ignore the `training` mutation in `super().__init__()`, since that's
        # setting the default on `nn.Module`, but we are mirroring the
        # `training` attr in `self._orig_mod`.
        if self._super_module_initialized:
            self._orig_mod.training = value

    def __getattr__(self, name: str) -> Any:
        """Fallback lookup, called only when normal attribute lookup fails.

        ``_orig_mod`` is read from ``self._modules`` (where ``__init__``'s
        ``self._orig_mod = mod`` registered it); every other name is
        forwarded to the wrapped module.
        """
        if name == "_orig_mod":
            return self._modules["_orig_mod"]
        return getattr(self._orig_mod, name)

    def __setattr__(self, name: str, value: Any) -> None:
        """Store class-level and ``_opt_mod_attributes`` names on the wrapper;
        forward every other assignment to the wrapped module."""
        # Allow patching over class attributes
        if hasattr(type(self), name):
            return super().__setattr__(name, value)
        if name in OptimizedModule._opt_mod_attributes:
            return super().__setattr__(name, value)
        return setattr(self._orig_mod, name, value)

    def __delattr__(self, name: str) -> None:
        """Delete from the wrapper or the wrapped module, mirroring ``__setattr__``."""
        # This mirrors `__setattr__`
        if hasattr(type(self), name):
            return super().__delattr__(name)
        if name in OptimizedModule._opt_mod_attributes:
            return super().__delattr__(name)
        return delattr(self._orig_mod, name)

    def _call_lazy_check(self, *args: Any, **kwargs: Any) -> Any:
        """``forward`` used for lazy modules: infer parameters if still needed,
        then call the compiled ``_forward``."""
        if (
            hasattr(self._orig_mod, "_initialize_hook")
            and hasattr(self._orig_mod, "_infer_parameters")
            and callable(self._orig_mod._infer_parameters)
        ):
            # In the case of a lazy module, we want to run
            # the pre-hooks which initialize it.
            # Afterwards, lazy module deletes its pre-hooks
            # to avoid treating it as lazy on subsequent recompile.
            self._orig_mod._infer_parameters(self._orig_mod, args, kwargs)
        return self._forward(*args, **kwargs)

    def __dir__(self) -> list[str]:
        """Wrapped module's attributes first, then wrapper-only attributes."""
        orig_mod_attrs = self._orig_mod.__dir__()
        return orig_mod_attrs + [
            attr for attr in super().__dir__() if attr not in orig_mod_attrs
        ]
  486. def remove_from_cache(f: Any) -> None:
  487. """
  488. Make sure f.__code__ is not cached to force a recompile
  489. """
  490. if isinstance(f, types.CodeType):
  491. reset_code(f)
  492. elif hasattr(f, "__code__"):
  493. reset_code(f.__code__)
  494. elif hasattr(getattr(f, "forward", None), "__code__"):
  495. reset_code(f.forward.__code__)
  496. else:
  497. from . import reset # type: ignore[attr-defined]
  498. reset()
  499. log.warning("could not determine __code__ for %s", f)
  500. def nothing() -> None:
  501. pass
  502. def always_false() -> bool:
  503. return False
  504. def innermost_fn(fn: Callable[..., Any]) -> Callable[..., Any]:
  505. """
  506. In case of nesting of _TorchDynamoContext calls, find the innermost
  507. function. TorchDynamo caches on fn.__code__ object, so its necessary to find
  508. the innermost function to pass on the optimize, run, disable etc.
  509. """
  510. unaltered_fn = fn
  511. while (
  512. hasattr(unaltered_fn, "_torchdynamo_orig_callable")
  513. # Only follow the chain if _torchdynamo_wrapper_id matches id(fn).
  514. # This prevents following chains in two cases:
  515. # 1. Bound methods: id(bound_method) != id(wrapper_function), so we
  516. # won't unwrap through __func__ and lose the self binding.
  517. # 2. functools.wraps copies: When functools.wraps copies
  518. # _torchdynamo_orig_callable from a wrapped function, the copied
  519. # _torchdynamo_wrapper_id won't match the outer wrapper's id.
  520. and getattr(unaltered_fn, "_torchdynamo_wrapper_id", None) == id(unaltered_fn)
  521. ):
  522. unaltered_fn = unaltered_fn._torchdynamo_orig_callable
  523. assert callable(unaltered_fn), (
  524. f"A callable function is expected, but {type(unaltered_fn)} is provided."
  525. )
  526. return unaltered_fn
  527. def innermost_backend(fn: Callable[..., Any]) -> Callable[..., Any]:
  528. """
  529. Unwrap backend wrapper chain via _torchdynamo_orig_backend to find the
  530. innermost backend function.
  531. """
  532. while hasattr(fn, "_torchdynamo_orig_backend"):
  533. fn = fn._torchdynamo_orig_backend
  534. assert callable(fn), (
  535. f"A callable function is expected, but {type(fn)} is provided."
  536. )
  537. return fn
  538. def make_set_enable_dynamic(enable: bool) -> Any:
  539. assert isinstance(enable, bool)
  540. if enable:
  541. # Assume everything is dynamic by default
  542. return config._make_closure_patcher(assume_static_by_default=False)
  543. else:
  544. return config._make_closure_patcher(
  545. automatic_dynamic_shapes=False, assume_static_by_default=True
  546. )
  547. @contextlib.contextmanager
  548. def set_enable_dynamic(enable: bool) -> Generator[None, None, None]:
  549. cleanup = make_set_enable_dynamic(enable)()
  550. try:
  551. yield
  552. finally:
  553. cleanup()
# A thread local storage that serves to store information as Dynamo traces
# through a user provided function.
class DynamoTLS(threading.local):
    # Each string is a summary of a frame Dynamo attempted to trace, stored in
    # temporal order.
    # NOTE(review): this list is a *class* attribute. Class attributes of a
    # ``threading.local`` subclass are shared by every thread; only attributes
    # assigned on the instance are per-thread. So every thread that appends
    # through ``dynamo_tls`` mutates the same list, and clear_dynamo_tls()
    # clears it for all threads. The atexit logger in this file relies on
    # reading that one list. Confirm this sharing is intended before making it
    # truly per-thread.
    traced_frame_infos: list[str] = []


# Module-wide instance read and written via ``dynamo_tls.traced_frame_infos``.
dynamo_tls = DynamoTLS()
  561. def clear_dynamo_tls() -> None:
  562. dynamo_tls.traced_frame_infos.clear()
  563. @atexit.register
  564. def _log_traced_frames() -> None:
  565. """
  566. At program exit, log all of the frames Dynamo has attempted to trace from,
  567. excluding the continuation frames generated by Dynamo.
  568. """
  569. msg = "\n".join(dynamo_tls.traced_frame_infos)
  570. msg = textwrap.indent(msg, " * ")
  571. msg = f"TorchDynamo attempted to trace the following frames: [\n{msg}\n]"
  572. log.info(msg)
  573. def guard_collectives_hook(guard_eval_result: bool) -> bool:
  574. import torch.distributed as dist
  575. from torch._dynamo.utils import dynamo_timed
  576. # guard_eval_result == True ==> cache hit
  577. if pg := distributed.get_guard_pg():
  578. with dynamo_timed(
  579. "guard_collective", log_pt2_compile_event=False, log_waitcounter=True
  580. ):
  581. log.debug("guard_collective %s", guard_eval_result)
  582. # TODO: a bit awkward to time, this isn't inside of the dynamo compile region
  583. all_results = [None] * pg.size()
  584. dist.all_gather_object(all_results, guard_eval_result, group=pg)
  585. # True = everyone hit, OK to run
  586. # False = someone missed, force recompile everywhere
  587. res = all(all_results)
  588. log.debug("guard_collective %s -> %s", guard_eval_result, res)
  589. return res
  590. return guard_eval_result
# Unique sentinel: a fresh ``object()`` that can be told apart by identity from
# any caller-supplied value, including ``None``.
_not_set = object()
  592. def _get_eval_frame_override() -> _EvalFrameOverride:
  593. if torch._dynamo.config.error_on_dynamo_callback_in_fullgraph_compiled_code:
  594. return _EvalFrameOverride.ERROR
  595. return _EvalFrameOverride.SKIP
  596. class _TorchDynamoContext:
  597. def __init__(
  598. self,
  599. callback: DynamoCallback,
  600. on_enter: Callable[[], Any] = nothing,
  601. backend_ctx_ctor: Callable[
  602. [], contextlib.AbstractContextManager[Any]
  603. ] = null_context,
  604. patch_fn: Callable[[], Any] = nothing,
  605. first_ctx: bool = False,
  606. *,
  607. fullgraph: bool = False,
  608. error_on_graph_break: Optional[bool] = None,
  609. export: bool = False,
  610. dynamic: Optional[bool] = None,
  611. compiler_config: Optional[Any] = None,
  612. package: Optional[CompilePackage] = None,
  613. hooks: Optional[Hooks] = None,
  614. ) -> None:
  615. super().__init__()
  616. assert callable(callback) or callback is False or callback is None
  617. self.callback: DynamoCallback = callback
  618. self._backend_ctx_ctor = backend_ctx_ctor
  619. self.prior: Union[Unset, DynamoCallback] = unset
  620. self.first_ctx = first_ctx
  621. self.fullgraph = fullgraph
  622. self.error_on_graph_break = error_on_graph_break
  623. self.export = export
  624. self._dynamic = dynamic
  625. self.compiler_config = compiler_config
  626. self.cleanup_fns: list[Callable[[], Any]] = []
  627. self.enter_exit_hooks = []
  628. self._package = package
  629. self._hooks = hooks
  630. patch_fn()
  631. # Save the backends so that we can reset them during torch._dynamo.reset
  632. backend = innermost_backend(callback) # type: ignore[arg-type]
  633. cached_backends.setdefault(id(backend), backend) # type: ignore[arg-type]
  634. if dynamic is not None:
  635. self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic))
  636. if on_enter is not nothing:
  637. # this case is not common
  638. def call_on_enter() -> Callable[[], None]:
  639. on_enter()
  640. return nothing
  641. self.enter_exit_hooks.append(call_on_enter)
  642. if backend_ctx_ctor is not contextlib.nullcontext:
  643. # this case is not common
  644. def call_backend_ctx() -> functools.partial[Optional[bool]]:
  645. ctx = backend_ctx_ctor()
  646. ctx.__enter__()
  647. return functools.partial(ctx.__exit__, None, None, None)
  648. self.enter_exit_hooks.append(call_backend_ctx)
  649. def __enter__(self) -> None:
  650. if config.raise_on_ctx_manager_usage:
  651. raise RuntimeError(
  652. "torch._dynamo.optimize(...) is used with a context manager. "
  653. "Please refer to https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html "
  654. "to use torch._dynamo.optimize(...) as an annotation/decorator. "
  655. )
  656. self.prior = set_eval_frame(None)
  657. self.cleanup_fns = [enter() for enter in self.enter_exit_hooks]
  658. self.prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe(
  659. _is_skip_guard_eval_unsafe_stance()
  660. )
  661. _maybe_set_eval_frame(_callback_from_stance(self.callback))
  662. def __exit__(
  663. self,
  664. exc_type: Optional[type[BaseException]],
  665. exc_val: Optional[BaseException],
  666. exc_tb: Optional[types.TracebackType],
  667. ) -> Optional[bool]:
  668. assert self.prior is not unset
  669. set_eval_frame(None)
  670. set_skip_guard_eval_unsafe(self.prior_skip_guard_eval_unsafe)
  671. for cleanup in self.cleanup_fns:
  672. cleanup()
  673. self.cleanup_fns.clear()
  674. _maybe_set_eval_frame(_callback_from_stance(self.prior))
  675. self.prior = unset
  676. return None
  677. def __call__(self, fn: Any) -> Any:
  678. # public api for compiler config/options
  679. def get_compiler_config() -> Any:
  680. return self.compiler_config
  681. from .package import DynamoCache
  682. # If self._package is lazily initialized, we should check the dynamo cache now
  683. if config.caching_precompile:
  684. if self._package is not None and not self._package.is_initialized():
  685. fn_key = fn.forward if isinstance(fn, torch.nn.Module) else fn
  686. result = DynamoCache.load(fn_key)
  687. if result is None:
  688. # Create a fresh CompilePackage
  689. self._package.initialize(fn_key, None, ignore_inlined_sources=False)
  690. else:
  691. try:
  692. self._package.initialize(
  693. fn_key, result.dynamo, ignore_inlined_sources=False
  694. )
  695. self._package.install(result.backends)
  696. except RuntimeError:
  697. log.warning(
  698. "Failed to load entry from dynamo cache", exc_info=True
  699. )
  700. self._package.initialize(
  701. fn_key, None, ignore_inlined_sources=False
  702. )
  703. fn = innermost_fn(fn)
  704. def aot_compile(example_inputs: tuple[tuple[Any, ...], dict[str, Any]]) -> Any:
  705. from torch._dynamo.aot_compile import aot_compile_fullgraph
  706. if torch._inductor.config.force_disable_caches:
  707. raise RuntimeError(
  708. "Cannot precompile with torch._inductor.config.force_disable_caches=True; caching is required."
  709. )
  710. if not self.fullgraph:
  711. raise RuntimeError(
  712. "Graph breaks are not supported with aot compile. Please use torch.compile(fullgraph=True)."
  713. )
  714. if not callable(self.callback):
  715. raise RuntimeError("aot compile requires a callable dynamo callback.")
  716. assert self._hooks is not None
  717. return aot_compile_fullgraph(
  718. fn,
  719. example_inputs,
  720. hooks=self._hooks,
  721. backend=innermost_backend(self.callback),
  722. dynamic=self._dynamic,
  723. )
  724. # add context containing GraphModule to any GraphModule forward functions
  725. if isinstance(fn, GraphModule):
  726. # add context containing GraphModule to any GraphModule forward functions
  727. code_context.get_context(fn.forward.__code__)["orig_graphmodule"] = (
  728. weakref.ref(fn)
  729. )
  730. # Optimize the forward method of torch.nn.Module object
  731. if isinstance(fn, torch.nn.Module):
  732. mod = fn
  733. new_mod = OptimizedModule(mod, self)
  734. # Save the function pointer to find the original callable while nesting
  735. # of decorators.
  736. new_mod._torchdynamo_orig_callable = mod.forward
  737. new_mod._torchdynamo_wrapper_id = id(new_mod)
  738. # when compiling torch.nn.Module,
  739. # provide public api OptimizedModule.get_compiler_config()
  740. assert not hasattr(new_mod, "get_compiler_config")
  741. new_mod.get_compiler_config = get_compiler_config
  742. return new_mod
  743. if inspect.isclass(fn):
  744. # User has wrapped the class with compile/disable decorator. Apply
  745. # disable to init/call method.
  746. cls_obj = fn
  747. cls_obj.__call__ = self(cls_obj.__call__)
  748. if issubclass(cls_obj, torch.nn.Module):
  749. # NN module variable tracker directly inlines the _call_impl.
  750. cls_obj._call_impl = self(cls_obj._call_impl)
  751. return cls_obj
  752. assert callable(fn), (
  753. f"A callable function is expected, but {type(fn)} is provided."
  754. )
  755. # NOTE [Top-level TorchInGraph functions]
  756. # Some callables (e.g. torch.exp) are represented as TorchInGraphFunctionVariable
  757. # when traced inside a frame. When such a function is passed directly to
  758. # torch.compile, we detect it here so we can force it through wrap_inline.
  759. from .variables import TorchInGraphFunctionVariable
  760. rule = trace_rules.lookup(fn)
  761. top_level_in_graph = isinstance(rule, type) and issubclass(
  762. rule, TorchInGraphFunctionVariable
  763. )
  764. try:
  765. filename = inspect.getsourcefile(fn)
  766. except TypeError:
  767. filename = None
  768. if config.debug_force_nested_calls:
  769. fn = external_utils.wrap_inline(fn)
  770. elif config.wrap_top_frame or (
  771. (filename is None or trace_rules.check(fn) or top_level_in_graph)
  772. and (
  773. getattr(fn, "__name__", "")
  774. not in ["_call_impl", "_wrapped_call_impl", "_lazy_forward"]
  775. )
  776. and filename not in DONT_WRAP_FILES
  777. ):
  778. # call to a builtin without a frame for us to capture
  779. fn = external_utils.wrap_inline(fn)
  780. def do_nothing(*arg: Any, **kwargs: Any) -> None:
  781. pass
  782. callback: Callable[..., Any] = do_nothing
  783. if hasattr(self, "callback"):
  784. callback = self.callback # type: ignore[assignment]
  785. is_jit_tracing = torch._C._is_tracing
  786. is_fx_symbolic_tracing = torch.fx._symbolic_trace.is_fx_symbolic_tracing
  787. @functools.wraps(fn)
  788. def compile_wrapper(*args: Any, **kwargs: Any) -> Any:
  789. prior = set_eval_frame(None)
  790. prior_eval_frame_override: _EvalFrameOverride | None = None
  791. if self.fullgraph:
  792. prior_eval_frame_override = set_eval_frame_override(
  793. _get_eval_frame_override()
  794. )
  795. try:
  796. # We shouldn't compile inside kernel invocation.
  797. if tracing_context := torch._guards.TracingContext.try_get():
  798. if (
  799. tracing_context.fake_mode is not None
  800. and tracing_context.fake_mode.in_kernel_invocation
  801. ):
  802. return fn(*args, **kwargs)
  803. # Skip nested compile during export (but not HOP internal compile)
  804. # Only skip if there's an active TracingContext (nested), not for top-level export
  805. if (
  806. torch.compiler.is_exporting()
  807. and not config.force_compile_during_fx_trace
  808. ):
  809. from torch._higher_order_ops.utils import _in_hop_compile
  810. if not _in_hop_compile():
  811. if torch._guards.TracingContext.try_get() is not None:
  812. return fn(*args, **kwargs)
  813. # Skip nested compile - just inline the function
  814. if (
  815. is_fx_symbolic_tracing()
  816. and not config.force_compile_during_fx_trace
  817. ):
  818. if config.error_on_nested_fx_trace:
  819. raise RuntimeError(
  820. "Detected that you are using FX to symbolically trace "
  821. "a dynamo-optimized function. This is not supported at the moment."
  822. )
  823. else:
  824. return fn(*args, **kwargs)
  825. if is_jit_tracing():
  826. raise RuntimeError(
  827. "Detected that you are using FX to torch.jit.trace "
  828. "a dynamo-optimized function. This is not supported at the moment."
  829. )
  830. cleanups = [enter() for enter in self.enter_exit_hooks]
  831. prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe(
  832. _is_skip_guard_eval_unsafe_stance()
  833. )
  834. prior_error_on_graph_break = None
  835. if not self.fullgraph and self.error_on_graph_break is not None:
  836. prior_error_on_graph_break = _get_error_on_graph_break()
  837. _set_error_on_graph_break(self.error_on_graph_break)
  838. # Ensure that if an assertion occurs after graph pushes
  839. # something onto the DynamicLayerStack then we pop it off (the
  840. # constructed graph code isn't guarded with try/finally).
  841. #
  842. # This used to be a context but putting a `with` here is a noticeable
  843. # perf regression (#126293)
  844. saved_dynamic_layer_stack_depth = (
  845. torch._C._functorch.get_dynamic_layer_stack_depth()
  846. )
  847. _maybe_set_eval_frame(_callback_from_stance(callback))
  848. try:
  849. return fn(*args, **kwargs)
  850. except (Unsupported, UncapturedHigherOrderOpError) as e:
  851. if config.verbose:
  852. raise
  853. # strip internal tracebacks from causes
  854. cur_exn: BaseException = e
  855. while cur_exn.__cause__ is not None:
  856. cur_exn.__cause__.with_traceback(None)
  857. cur_exn = cur_exn.__cause__
  858. raise e.with_traceback(None) from e.__cause__ # User compiler error
  859. except ShortenTraceback as e:
  860. # Failures in the backend likely don't have useful
  861. # data in the TorchDynamo frames, so we strip them out.
  862. raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
  863. finally:
  864. # Restore the dynamic layer stack depth if necessary.
  865. set_eval_frame(None)
  866. if prior_error_on_graph_break is not None:
  867. _set_error_on_graph_break(prior_error_on_graph_break)
  868. if prior_eval_frame_override is not None:
  869. set_eval_frame_override(prior_eval_frame_override)
  870. torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
  871. saved_dynamic_layer_stack_depth
  872. )
  873. set_skip_guard_eval_unsafe(prior_skip_guard_eval_unsafe)
  874. for cleanup in cleanups:
  875. cleanup()
  876. finally:
  877. _maybe_set_eval_frame(prior)
  878. # hooks to properly handle inlining
  879. if self.error_on_graph_break is not None:
  880. compile_wrapper._torchdynamo_inline = ( # type: ignore[attr-defined]
  881. external_utils.wrap_inline_with_error_on_graph_break(
  882. fn, self.error_on_graph_break
  883. )
  884. )
  885. else:
  886. compile_wrapper._torchdynamo_inline = fn # type: ignore[attr-defined]
  887. # Save the function pointer to find the original callable while nesting
  888. # of decorators.
  889. compile_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
  890. compile_wrapper._torchdynamo_wrapper_id = id(compile_wrapper) # type: ignore[attr-defined]
  891. # when compiling user function instead of nn.Module
  892. # provide public api _fn.get_compiler_config()
  893. assert not hasattr(compile_wrapper, "get_compiler_config")
  894. compile_wrapper.get_compiler_config = get_compiler_config # type: ignore[attr-defined]
  895. if torch._dynamo.config.enable_aot_compile:
  896. compile_wrapper.aot_compile = aot_compile # type: ignore[attr-defined]
  897. # If the function is called using torch._dynamo.optimize decorator, we
  898. # should prevent any type of skipping.
  899. if callback not in (None, False):
  900. if not hasattr(fn, "__code__"):
  901. raise RuntimeError(
  902. textwrap.dedent(
  903. """
  904. torch._dynamo.optimize is called on a non function object.
  905. If this is a callable class, please wrap the relevant code into a function and optimize the
  906. wrapper function.
  907. >> class CallableClass:
  908. >> def __init__(self) -> None:
  909. >> super().__init__()
  910. >> self.relu = torch.nn.ReLU()
  911. >>
  912. >> def __call__(self, x):
  913. >> return self.relu(torch.sin(x))
  914. >>
  915. >> def print_hello(self):
  916. >> print("Hello world")
  917. >>
  918. >> mod = CallableClass()
  919. If you want to optimize the __call__ function and other code, wrap that up in a function
  920. >> def wrapper_fn(x):
  921. >> y = mod(x)
  922. >> return y.sum()
  923. and then optimize the wrapper_fn
  924. >> opt_wrapper_fn = torch._dynamo.optimize(wrapper_fn)
  925. """
  926. )
  927. )
  928. always_optimize_code_objects[fn.__code__] = True
  929. return compile_wrapper
  930. class OptimizeContext(_TorchDynamoContext):
  931. def __init__(
  932. self,
  933. callback: DynamoCallback,
  934. backend_ctx_ctor: Callable[[], contextlib.AbstractContextManager[Any]],
  935. first_ctx: bool = False,
  936. *,
  937. fullgraph: bool = False,
  938. error_on_graph_break: Optional[bool] = None,
  939. export: bool = False,
  940. dynamic: Optional[bool] = None,
  941. compiler_config: Optional[Any] = None,
  942. rebuild_ctx: Optional[
  943. Callable[[], Union[OptimizeContext, _NullDecorator]]
  944. ] = None,
  945. package: Optional[CompilePackage] = None,
  946. hooks: Optional[Hooks] = None,
  947. ) -> None:
  948. def on_enter() -> None:
  949. install_generation_tagging_init()
  950. super().__init__(
  951. callback=callback,
  952. on_enter=on_enter,
  953. backend_ctx_ctor=backend_ctx_ctor,
  954. patch_fn=TorchPatcher.patch,
  955. first_ctx=first_ctx,
  956. fullgraph=fullgraph,
  957. error_on_graph_break=error_on_graph_break,
  958. export=export,
  959. dynamic=dynamic,
  960. compiler_config=compiler_config,
  961. package=package,
  962. hooks=hooks,
  963. )
  964. if config.compiled_autograd:
  965. _dynamic = self._dynamic
  966. if _dynamic is None:
  967. _dynamic = not torch._dynamo.config.assume_static_by_default
  968. def call_compiled_autograd() -> functools.partial[Optional[bool]]:
  969. assert rebuild_ctx is not None
  970. compiler_fn = rebuild_ctx()
  971. ctx = torch._dynamo.compiled_autograd._enable(
  972. compiler_fn,
  973. # pyrefly: ignore [bad-argument-type]
  974. dynamic=_dynamic,
  975. ignore_active_disable_ctx=False,
  976. )
  977. ctx.__enter__()
  978. return functools.partial(ctx.__exit__, None, None, None)
  979. self.enter_exit_hooks.append(call_compiled_autograd)
  980. def __reduce__(
  981. self,
  982. ) -> tuple[type[OptimizeContext], tuple[Any, ...], dict[str, Any]]:
  983. return (
  984. self.__class__,
  985. (self.callback, self._backend_ctx_ctor, self.first_ctx),
  986. {
  987. "export": self.export,
  988. "dynamic": self._dynamic,
  989. "compiler_config": self.compiler_config,
  990. },
  991. )
  992. class RunOnlyContext(_TorchDynamoContext):
  993. def __init__(self) -> None:
  994. # cudagraph trees relies on generation increment
  995. def on_enter() -> None:
  996. torch._dynamo.mutation_guard.GenerationTracker.generation += 1
  997. super().__init__(callback=False, on_enter=on_enter)
  998. def __reduce__(self) -> tuple[type[RunOnlyContext], tuple[Any, ...]]:
  999. return (self.__class__, ())
  1000. class DisableContext(_TorchDynamoContext):
  1001. def __init__(self, msg: Optional[str] = None, wrapping: bool = True) -> None:
  1002. super().__init__(callback=None)
  1003. self.msg = msg
  1004. self.wrapping = wrapping
  1005. def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]:
  1006. # Earlier this code was in the base class _TorchDynamoContext. But we
  1007. # moved it here to have better code organization. For disable, we just
  1008. # want the callback to be None. We don't have to check trace_rules or
  1009. # create any wrapper.
  1010. fn = innermost_fn(fn)
  1011. if isinstance(fn, torch.nn.Module):
  1012. mod = fn
  1013. new_mod = OptimizedModule(mod, self)
  1014. new_mod._torchdynamo_orig_callable = mod.forward
  1015. new_mod._torchdynamo_wrapper_id = id(new_mod)
  1016. return new_mod
  1017. if isinstance(fn, type):
  1018. # User has wrapped the class with compile/disable decorator. Apply
  1019. # disable to init/call method.
  1020. cls_obj = fn
  1021. # Disable on init is useful for reconstruction of bytecodes where we
  1022. # want to prevent Dynamo from tracing into the init function. Check
  1023. # test_reconstruction in test_model_output.py.
  1024. cls_obj.__init__ = self(cls_obj.__init__) # type: ignore[misc]
  1025. cls_obj.__call__ = self(cls_obj.__call__)
  1026. if issubclass(cls_obj, torch.nn.Module):
  1027. # NN module variable tracker directly inlines the _call_impl. Disable it.
  1028. # pyrefly: ignore [missing-attribute]
  1029. cls_obj._call_impl = self(cls_obj._call_impl)
  1030. return cls_obj
  1031. assert callable(fn), (
  1032. f"A callable function is expected, but {type(fn)} is provided."
  1033. )
  1034. def _fn(*args: Any, **kwargs: Any) -> Any:
  1035. prior = set_eval_frame(None)
  1036. try:
  1037. _maybe_set_eval_frame(_callback_from_stance(self.callback))
  1038. try:
  1039. fn_name = getattr(fn, "__name__", type(fn).__name__)
  1040. # Skip annotation for __torch_dispatch__ to avoid polluting
  1041. # node metadata during export. The disable on __torch_dispatch__
  1042. # is an internal implementation detail, not user-facing.
  1043. # TODO: Ideally we shouldn't need this check because nested
  1044. # annotate() calls shouldn't override existing keys.
  1045. if (
  1046. torch.compiler.is_exporting()
  1047. and fn_name != "__torch_dispatch__"
  1048. ):
  1049. with fx_traceback.annotate(
  1050. {
  1051. "_torchdynamo_disable": True,
  1052. "_torchdynamo_disable_recursive": True,
  1053. "_torchdynamo_disable_method": fn_name,
  1054. }
  1055. ):
  1056. return fn(*args, **kwargs)
  1057. return fn(*args, **kwargs)
  1058. finally:
  1059. set_eval_frame(None)
  1060. finally:
  1061. _maybe_set_eval_frame(prior)
  1062. # Under some circumstances (e.g. precompile) we can end up calling @disable
  1063. # decorator in generated bytecode and trigger recompile. This is due to the
  1064. # fact that the old callback from torch.compile() is still active and under
  1065. # this circumstance we will trigger a failure with set_stance("fail_on_recompile").
  1066. # Therefore we want to skip calling into any frame in this case.
  1067. if self.wrapping:
  1068. _fn = functools.wraps(fn)(_fn)
  1069. _fn._torchdynamo_disable = True # type: ignore[attr-defined]
  1070. _fn._torchdynamo_disable_msg = self.msg # type: ignore[attr-defined]
  1071. # Save the function pointer to find the original callable while nesting
  1072. # of decorators.
  1073. _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
  1074. _fn._torchdynamo_wrapper_id = id(_fn) # type: ignore[attr-defined]
  1075. _fn._torchdynamo_disable_recursive = True # type: ignore[attr-defined]
  1076. return _fn
  1077. def __reduce__(self) -> tuple[type[DisableContext], tuple[Any, ...]]:
  1078. return (self.__class__, ())
  1079. def _optimize_catch_errors(
  1080. compile_fn: convert_frame.ConvertFrameProtocol,
  1081. hooks: Hooks,
  1082. backend_ctx_ctor: Callable[
  1083. [], contextlib.AbstractContextManager[Any]
  1084. ] = null_context,
  1085. fullgraph: bool = False,
  1086. error_on_graph_break: Optional[bool] = None,
  1087. export: bool = False,
  1088. dynamic: Optional[bool] = None,
  1089. compiler_config: Optional[Any] = None,
  1090. rebuild_ctx: Optional[Callable[[], Union[OptimizeContext, _NullDecorator]]] = None,
  1091. package: Optional[CompilePackage] = None,
  1092. ) -> OptimizeContext:
  1093. return OptimizeContext(
  1094. convert_frame.catch_errors_wrapper(compile_fn, hooks),
  1095. backend_ctx_ctor=backend_ctx_ctor,
  1096. first_ctx=True,
  1097. fullgraph=fullgraph,
  1098. error_on_graph_break=error_on_graph_break,
  1099. export=export,
  1100. dynamic=dynamic,
  1101. compiler_config=compiler_config,
  1102. rebuild_ctx=rebuild_ctx,
  1103. package=package,
  1104. hooks=hooks,
  1105. )
  1106. def get_compiler_fn(
  1107. compiler_fn: Union[str, Callable[..., Any], None],
  1108. ) -> WrapBackendDebug:
  1109. from .repro.after_dynamo import wrap_backend_debug
  1110. if compiler_fn is None:
  1111. # Special case None to avoid crashing in hasattr
  1112. compiler_str = None
  1113. elif hasattr(compiler_fn, "compiler_name"):
  1114. compiler_str = compiler_fn.compiler_name # type: ignore[union-attr]
  1115. assert isinstance(compiler_str, str)
  1116. elif isinstance(compiler_fn, str):
  1117. compiler_str = compiler_fn
  1118. else:
  1119. compiler_str = None
  1120. compiler_fn = lookup_backend(compiler_fn) # type: ignore[arg-type]
  1121. return wrap_backend_debug(compiler_fn, compiler_str)
  1122. class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
  1123. def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]:
  1124. assert callable(fn), (
  1125. f"A callable function is expected, but {type(fn)} is provided."
  1126. )
  1127. return fn
  1128. # Make dynamo graph to have same input/output spec as user code
  1129. def argument_names(
  1130. f_sig: inspect.Signature,
  1131. args: Union[list[Any], tuple[Any, ...]],
  1132. kwargs: dict[str, Any],
  1133. ) -> list[str]:
  1134. def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec:
  1135. # Get a list of Parameter objects from the Signature object
  1136. params = list(sig.parameters.values())
  1137. # Separate positional arguments, keyword-only arguments and varargs/varkw
  1138. args = [
  1139. p.name for p in params if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
  1140. ]
  1141. kwonlyargs = [
  1142. p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
  1143. ]
  1144. varargs = next(
  1145. (p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL),
  1146. None,
  1147. )
  1148. varkw = next(
  1149. (p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD),
  1150. None,
  1151. )
  1152. # Get default values for positional arguments and keyword-only arguments
  1153. defaults = tuple(
  1154. p.default
  1155. for p in params
  1156. if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
  1157. and p.default is not inspect.Parameter.empty
  1158. )
  1159. kwonlydefaults = {
  1160. p.name: p.default
  1161. for p in params
  1162. if p.kind == inspect.Parameter.KEYWORD_ONLY
  1163. and p.default is not inspect.Parameter.empty
  1164. }
  1165. # Get annotations for parameters and return value
  1166. # pyrefly: ignore [implicit-any]
  1167. annotations = {}
  1168. if sig.return_annotation:
  1169. annotations = {"return": sig.return_annotation}
  1170. for parameter in params:
  1171. annotations[parameter.name] = parameter.annotation
  1172. # Return a FullArgSpec object with the extracted attributes
  1173. return inspect.FullArgSpec(
  1174. args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
  1175. )
  1176. fullargspec = signature_to_fullargspec(f_sig)
  1177. # 1. Map `args` 1-to-1 to positional arguments in original signature.
  1178. input_strs = fullargspec.args[: len(args)]
  1179. if len(args) > len(fullargspec.args):
  1180. # 2. If there are more arguments left in `args`, they map to varargs in original
  1181. # signature. Assign names as {varargs}_0, {varargs}_1, ...
  1182. assert fullargspec.varargs is not None, "More arguments than expected"
  1183. input_strs += [
  1184. f"{fullargspec.varargs}_{i}" for i in range(len(args) - len(input_strs))
  1185. ]
  1186. elif len(args) < len(fullargspec.args):
  1187. # 3. If there are fewer arguments in `args` than `fullargspec.args`,
  1188. # it implies these are arguments either with default values, or provided in
  1189. # `kwargs`. The former can be safely ignored. Because Dynamo.export does not
  1190. # export them as part of the function signature. The latter will be handled
  1191. # in the next step.
  1192. for unprovided_arg in fullargspec.args[
  1193. len(args) : -len(fullargspec.defaults or [])
  1194. ]:
  1195. assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}"
  1196. # 4. Keyword arguments provided in `kwargs`.
  1197. input_strs += list(kwargs.keys())
  1198. # 5. Keyword-only arguments with default values if not provided are not exported
  1199. # as part of the function signature.
  1200. for kwonly_arg in fullargspec.kwonlyargs:
  1201. kwonlydefaults = fullargspec.kwonlydefaults or {}
  1202. assert kwonly_arg in kwargs or kwonly_arg in kwonlydefaults, (
  1203. f"Missing keyword only argument {kwonly_arg}"
  1204. )
  1205. return input_strs
  1206. def check_if_dynamo_supported() -> None:
  1207. if sys.version_info >= (3, 15):
  1208. raise RuntimeError("Python 3.15+ not yet supported for torch.compile")
  1209. elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1 and sys.version_info < (
  1210. 3,
  1211. 13,
  1212. 3,
  1213. ):
  1214. raise RuntimeError(
  1215. "torch.compile is not supported on Python < 3.13.3 built with GIL disabled. "
  1216. "Please use Python 3.13.3+."
  1217. )
  1218. def is_dynamo_supported() -> bool:
  1219. try:
  1220. check_if_dynamo_supported()
  1221. return True
  1222. except Exception:
  1223. return False
  1224. def check_if_inductor_supported() -> None:
  1225. check_if_dynamo_supported()
  1226. def is_inductor_supported() -> bool:
  1227. try:
  1228. check_if_inductor_supported()
  1229. return True
  1230. except Exception:
  1231. return False
  1232. def check_for_incompatible_configs() -> None:
  1233. # Some of the configs should be mutually exclusive
  1234. assert not (config.suppress_errors and config.fail_on_recompile_limit_hit), (
  1235. "Dynamo configs suppress_error and fail_on_recompile_limit_hit can not both be active at the same time."
  1236. )
  1237. def optimize(*args: Any, **kwargs: Any) -> Union[OptimizeContext, _NullDecorator]:
  1238. def rebuild_ctx() -> Union[OptimizeContext, _NullDecorator]:
  1239. ca_kwargs_override = config.compiled_autograd_kwargs_override
  1240. if ca_kwargs_override:
  1241. # NOTE: The process of translating other `torch.compile` kwargs to `torch._dynamo.optimize` kwargs
  1242. # is more complicated, we will add it in the future when needed.
  1243. assert set(ca_kwargs_override.keys()) == {"fullgraph"}, (
  1244. f"Only `fullgraph` kwarg override is supported for now, but got {ca_kwargs_override.keys()}"
  1245. )
  1246. kwargs["nopython"] = ca_kwargs_override["fullgraph"]
  1247. return optimize(*args, **kwargs)
  1248. return _optimize(rebuild_ctx, *args, **kwargs)
def _optimize(
    rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]],
    backend: Union[str, Callable[..., Any]] = "inductor",
    *,
    nopython: bool = False,
    error_on_graph_break: Optional[bool] = None,
    guard_export_fn: Optional[Callable[[_guards.GuardsSet], None]] = None,
    guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
    guard_filter_fn: Callable[[Sequence[GuardFilterEntry]], Sequence[bool]]
    | None = None,
    disable: bool = False,
    dynamic: Optional[bool] = None,
    package: Optional[CompilePackage] = None,
) -> Union[OptimizeContext, _NullDecorator]:
    """
    The main entrypoint of TorchDynamo. Do graph capture and call
    backend() to optimize extracted graphs.

    Args:
        rebuild_ctx: Zero-argument callable that re-creates an equivalent
            context. It is forwarded unchanged to ``optimize_assert`` /
            ``_optimize_catch_errors``.
        backend: One of the two things:
            - Either, a function/callable taking a torch.fx.GraphModule and
            example_inputs and returning a python callable that runs the
            graph faster.
            One can also provide additional context for the backend, like
            torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
            See AOTAutogradMemoryEfficientFusionWithContext for the usage.
            - Or, a string backend name in `torch._dynamo.list_backends()`
        nopython: If True, graph breaks will be errors and there will
            be a single whole-program graph. Ignored (the non-fullgraph path
            is used) when ``config.debug_force_graph_break_on_leaf_return``
            is set.
        error_on_graph_break: If not None, the current `error_on_graph_break` setting is set to the given value.
            See `torch._dynamo.error_on_graph_break()` for more details on what `error_on_graph_break` means.
            Unlike `nopython=True` (i.e. `fullgraph=True`), there is no guarantee of a single whole-program graph.
            If `nopython` is True, `error_on_graph_break` does nothing.
        guard_export_fn: Optional callback taking a ``GuardsSet``; stored in
            the ``Hooks`` object handed to the frame converter.
        guard_fail_fn: Optional callback taking a ``GuardFail``; stored in
            the ``Hooks`` object handed to the frame converter.
        guard_filter_fn: Optional callable mapping a sequence of
            ``GuardFilterEntry`` to a sequence of bools (presumably one
            keep/drop decision per entry -- confirm); stored in ``Hooks``.
        disable: If True, turn this decorator into a no-op
        dynamic: If True, upfront compile as dynamic a kernel as possible. If False,
            disable all dynamic shapes support (always specialize). If None, automatically
            detect when sizes vary and generate dynamic kernels upon recompile.
        package: Optional ``CompilePackage`` forwarded to the frame converter
            and the returned context. When None and
            ``config.caching_precompile`` is enabled, a fresh uninitialized
            package is created here.

    Returns:
        A ``_NullDecorator`` when Dynamo is disabled (``disable=True``,
        ``TORCHDYNAMO_DISABLE=1``, or the ``pytorch/compiler:enable_dynamo``
        justknob is off); otherwise the result of ``optimize_assert`` (for
        ``nopython``) or ``_optimize_catch_errors``.

    Example Usage::

        @torch._dynamo.optimize()
        def toy_example(a, b): ...
    """
    # Both checks run before the `disable` handling below, so they can raise
    # even when this call would otherwise return a no-op decorator.
    check_if_dynamo_supported()
    check_for_incompatible_configs()
    # Note: The hooks object could be global instead of passed around, *however* that would make
    # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
    # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same
    # compiler, however, this feels onerous for callback and hooks, and it feels better to give our users an
    # easier to understand UX at the cost of a little more plumbing on our end.
    hooks = Hooks(
        guard_export_fn=guard_export_fn,
        guard_fail_fn=guard_fail_fn,
        guard_filter_fn=guard_filter_fn,
    )
    torch._C._log_api_usage_once("torch._dynamo.optimize")
    if (
        disable
        or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1"
        or (not justknobs_check("pytorch/compiler:enable_dynamo"))
    ):
        return _NullDecorator()

    # Fullgraph path. When debug_force_graph_break_on_leaf_return is set,
    # nopython is ignored and we fall through to the fullgraph=False path.
    if nopython and not config.debug_force_graph_break_on_leaf_return:
        return optimize_assert(
            backend,
            dynamic=dynamic,
            hooks=hooks,
            rebuild_ctx=rebuild_ctx,
            package=package,
        )

    # Resolve a string backend name (or callable) to the compiler function.
    backend = get_compiler_fn(backend)

    # Find if backend has any extra context manager
    backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)

    # The backend function is stashed in the callable returned by
    # _optimize_catch_errors in the field _torchdynamo_orig_backend. This can
    # be used by eval_frame.c to insert a guard on the backend.

    # With CachingPrecompile, instantiate an uninitialized CompilePackage
    # which gets initialized by _optimize_catch_errors.__call__ once we have a function
    if config.caching_precompile and package is None:
        from .package import CompilePackage

        package = CompilePackage(fn=None, dynamo=None, ignore_inlined_sources=False)

    return _optimize_catch_errors(
        convert_frame.convert_frame(
            # pyrefly: ignore [bad-argument-type]
            backend,
            hooks,
            package=package,
        ),
        hooks,
        backend_ctx_ctor,
        fullgraph=False,
        # None passes through unchanged; True is downgraded to False when
        # debug_force_graph_break_on_leaf_return is set.
        error_on_graph_break=error_on_graph_break
        and not config.debug_force_graph_break_on_leaf_return,
        dynamic=dynamic,
        compiler_config=(
            backend.get_compiler_config()
            if hasattr(backend, "get_compiler_config")
            else None
        ),
        rebuild_ctx=rebuild_ctx,
        package=package,
    )
  1348. # TODO(voz): Consider making "explain" output alongside a run / part of a run
  1349. @patch("torch._dynamo.symbolic_convert.explain", True)
  1350. def explain(f: Callable[..., Any], *extra_args: Any, **extra_kwargs: Any) -> Any:
  1351. from .backends.debugging import ExplainOutput
  1352. def inner(*args: Any, **kwargs: Any) -> ExplainOutput:
  1353. # TODO(voz): Do we want a decorator for this?
  1354. from . import reset # type: ignore[attr-defined]
  1355. reset()
  1356. graphs: list[torch.fx.GraphModule] = []
  1357. break_reasons: list[Any] = []
  1358. op_count: int = 0
  1359. ops_per_graph: list[list[Target]] = []
  1360. out_guards: list[_guards.Guard] = []
  1361. def dynamo_graph_accumulating_compiler(
  1362. gm: torch.fx.GraphModule, example_inputs: Any
  1363. ) -> Callable[..., Any]:
  1364. from .backends.debugging import _explain_graph_detail
  1365. nonlocal graphs
  1366. nonlocal op_count
  1367. nonlocal ops_per_graph
  1368. nonlocal break_reasons
  1369. gm, graphs, op_count, ops_per_graph, break_reasons = _explain_graph_detail(
  1370. gm, graphs, op_count, ops_per_graph, break_reasons
  1371. )
  1372. return gm.forward
  1373. def guard_export_print(guards: Iterable[_guards.Guard]) -> None:
  1374. nonlocal out_guards
  1375. out_guards.extend(guards)
  1376. opt_f = optimize(
  1377. dynamo_graph_accumulating_compiler,
  1378. nopython=False,
  1379. guard_export_fn=guard_export_print,
  1380. )(f)
  1381. # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject.
  1382. opt_f(*args, **kwargs)
  1383. graph_count = len(graphs)
  1384. graph_break_count = graph_count - 1
  1385. compile_time = compile_times(repr="str")
  1386. # TODO(voz): Do we want a decorator for this?
  1387. reset()
  1388. return ExplainOutput(
  1389. graphs,
  1390. graph_count,
  1391. graph_break_count,
  1392. break_reasons,
  1393. op_count,
  1394. ops_per_graph,
  1395. out_guards,
  1396. compile_time,
  1397. )
  1398. if extra_args or extra_kwargs:
  1399. warnings.warn(
  1400. "explain(f, *args, **kwargs) is deprecated, use explain(f)(*args, **kwargs) instead. "
  1401. "If you don't migrate, we may break your explain call in the future if your user defined kwargs "
  1402. "conflict with future kwargs added to explain(f).",
  1403. FutureWarning,
  1404. stacklevel=2,
  1405. )
  1406. return inner(*extra_args, **extra_kwargs)
  1407. else:
  1408. return inner
  1409. class FlattenInputOutputSignature(torch.fx.Transformer):
  1410. def __init__(
  1411. self,
  1412. m: torch.fx.GraphModule,
  1413. flat_args: list[Any],
  1414. matched_input_elements_positions: list[int],
  1415. flat_results: Sequence[Any],
  1416. matched_output_elements_positions: list[int],
  1417. example_fake_inputs: list[torch.Tensor],
  1418. flat_args_dynamic_dims: list[set[int]],
  1419. fake_mode: Optional[fake_tensor.FakeTensorMode] = None,
  1420. ) -> None:
  1421. super().__init__(m)
  1422. assert len(flat_args_dynamic_dims) == len(flat_args)
  1423. matched_input_elements_to_fake = {
  1424. val: example_fake_inputs[ix]
  1425. for ix, val in enumerate(matched_input_elements_positions)
  1426. }
  1427. self.new_args = []
  1428. for i in range(len(flat_args)):
  1429. arg = super().placeholder(f"arg{i}", (), {})
  1430. if i in matched_input_elements_to_fake:
  1431. arg.node.meta["val"] = matched_input_elements_to_fake[i]
  1432. else:
  1433. # Fill node.meta["val"] with faketensor from the input,
  1434. # if it's not found in matched_input_elements_positions
  1435. if fake_mode is not None and isinstance(flat_args[i], torch.Tensor):
  1436. # TODO(zhxchen17) Also preserve all the user constraints here.
  1437. arg.node.meta["val"] = fake_mode.from_tensor(
  1438. flat_args[i],
  1439. symbolic_context=StatelessSymbolicContext(
  1440. dynamic_sizes=[
  1441. (
  1442. DimDynamic.DYNAMIC
  1443. if d in flat_args_dynamic_dims[i]
  1444. else DimDynamic.STATIC
  1445. )
  1446. for d in range(len(flat_args[i].shape))
  1447. ],
  1448. constraint_sizes=[None] * len(flat_args[i].shape),
  1449. ),
  1450. )
  1451. elif isinstance(flat_args[i], _IntWrapper):
  1452. arg.node.meta["val"] = flat_args[i].val
  1453. else:
  1454. arg.node.meta["val"] = flat_args[i]
  1455. self.new_args.append(arg)
  1456. self.old_args_gen = (self.new_args[i] for i in matched_input_elements_positions)
  1457. self.matched_output_elements_positions = matched_output_elements_positions
  1458. self.flat_results = flat_results
  1459. def placeholder(
  1460. self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
  1461. ) -> Any:
  1462. arg = next(self.old_args_gen)
  1463. if "val" in self.current_node.meta:
  1464. arg.node.meta["val"] = self.current_node.meta["val"]
  1465. if "tensor_dict" in self.current_node.meta:
  1466. arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
  1467. if "example_value" in self.current_node.meta:
  1468. # NB: intentionally do not use set_example_value
  1469. arg.node.meta["example_value"] = self.current_node.meta["example_value"]
  1470. if "unbacked_bindings" in self.current_node.meta:
  1471. arg.node.meta["unbacked_bindings"] = self.current_node.meta[
  1472. "unbacked_bindings"
  1473. ]
  1474. return arg
  1475. def output(
  1476. self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
  1477. ) -> Any:
  1478. dynamo_result_flat = args[0]
  1479. lookup = [*dynamo_result_flat, *self.new_args] # type: ignore[misc]
  1480. # pyrefly: ignore [implicit-any]
  1481. new_results_flat = []
  1482. for i in range(len(self.flat_results)):
  1483. if self.matched_output_elements_positions[i] is not None:
  1484. new_results_flat.append(
  1485. lookup[self.matched_output_elements_positions[i]]
  1486. )
  1487. else:
  1488. const_val = self.flat_results[i]
  1489. assert isinstance(const_val, tuple(common_constant_types))
  1490. new_results_flat.append(const_val)
  1491. return super().output(target, (new_results_flat,), {})
  1492. def run_node(self, n: Node) -> Any:
  1493. self.current_node = n
  1494. result_proxy = super().run_node(n)
  1495. if "val" in self.current_node.meta:
  1496. result_proxy.node.meta["val"] = self.current_node.meta["val"]
  1497. if "example_value" in self.current_node.meta:
  1498. # NB: intentionally do not use set_example_value
  1499. result_proxy.node.meta["example_value"] = self.current_node.meta[
  1500. "example_value"
  1501. ]
  1502. if "unbacked_bindings" in self.current_node.meta:
  1503. result_proxy.node.meta["unbacked_bindings"] = self.current_node.meta[
  1504. "unbacked_bindings"
  1505. ]
  1506. if self.current_node.op != "output":
  1507. result_proxy.node._rename(
  1508. getattr(self.current_node, "name", result_proxy.node.name)
  1509. )
  1510. return result_proxy
  1511. def transform(self) -> torch.fx.GraphModule:
  1512. result_gm = super().transform()
  1513. if "dynamo_flat_name_to_original_fqn" in self.module.meta: # type: ignore[operator]
  1514. result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[ # type: ignore[index]
  1515. "dynamo_flat_name_to_original_fqn" # type: ignore[index]
  1516. ]
  1517. if "dynamo_compile_id" in self.module.meta: # type: ignore[operator]
  1518. result_gm.meta["dynamo_compile_id"] = self.module.meta["dynamo_compile_id"] # type: ignore[index]
  1519. return result_gm
  1520. class ExportResult(NamedTuple):
  1521. graph_module: torch.fx.GraphModule
  1522. guards: _guards.GuardsSet
  1523. # NB: Do not add new fields without overriding __iter__; people are
  1524. # destructuring so it is BC-breaking
  1525. # NOTE: this function only supports graphs created by Dynamo's OutputGraph module
  1526. def check_signature_rewritable(graph: torch.fx.GraphModule) -> None:
  1527. # pyrefly: ignore [implicit-any]
  1528. input_errors = []
  1529. for node in graph.graph.find_nodes(op="placeholder"):
  1530. # set in OutputGraph._call_user_compiler
  1531. assert hasattr(node, "_dynamo_source")
  1532. assert hasattr(graph, "_source_to_user_stacks")
  1533. # NOTE: We can safely ignore these type warnings if and only if
  1534. # the function is made from OutputGraph (checked in the assertions)
  1535. source = node._dynamo_source # type: ignore[attr-defined]
  1536. user_stacks = graph._source_to_user_stacks.get(source) # type: ignore[operator, union-attr]
  1537. if user_stacks is None:
  1538. continue
  1539. assert len(user_stacks) > 0
  1540. # In some cases we may not have a useful stack. Look for a
  1541. # useful stack
  1542. stack = None
  1543. for s in user_stacks:
  1544. if len(s) == 0:
  1545. continue
  1546. stack = s
  1547. break
  1548. if stack is None:
  1549. msg = f"{source.name}, a closed over free variable"
  1550. else:
  1551. tb = "".join(traceback.format_list(stack))
  1552. extra = ""
  1553. if len(user_stacks) > 1:
  1554. extra = f"(elided {len(user_stacks) - 1} more accesses)"
  1555. msg = f"{source.name}, accessed at:\n{tb}{extra}"
  1556. # TODO: option to print ALL of the stack traces at once
  1557. input_errors.append(msg)
  1558. if input_errors:
  1559. raise UserError(
  1560. UserErrorType.INVALID_INPUT,
  1561. "Cannot export model which references tensors that are neither "
  1562. "buffers/parameters/constants nor are direct inputs. For each tensor, if you'd "
  1563. "like this tensor to be an explicit input, add it as a dummy argument "
  1564. "to the top-level model definition you are exporting; if you would "
  1565. "like its value to be embedded as an exported constant, wrap its access "
  1566. "in a function marked with @assume_constant_result.\n\n"
  1567. + "\n\n".join(input_errors),
  1568. )
  1569. def check_user_input_output(flat_values: list[Any], error_type: UserErrorType) -> None:
  1570. supported_types = [
  1571. torch.Tensor,
  1572. torch.SymInt,
  1573. torch.SymFloat,
  1574. torch.SymBool,
  1575. torch._C.ScriptObject,
  1576. _IntWrapper,
  1577. ] + list(common_constant_types)
  1578. def is_supported_type(val: Any) -> bool:
  1579. return isinstance(val, tuple(supported_types)) or is_opaque_type(type(val))
  1580. value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output"
  1581. # We only check that the outputs are not None. Inputs can be None.
  1582. for v in flat_values:
  1583. if not is_supported_type(v):
  1584. if error_type == UserErrorType.INVALID_INPUT and v is None:
  1585. continue
  1586. raise UserError(
  1587. error_type,
  1588. f"It looks like one of the {value_type}s with type `{type(v)}` "
  1589. "is not supported or pytree-flattenable. \n"
  1590. f"Exported graphs {value_type}s can only contain the "
  1591. f"following supported types: {supported_types}. \n"
  1592. "If you are using a custom class object, "
  1593. "please register a pytree_flatten/unflatten function "
  1594. "using `torch.utils._pytree.register_pytree_node` or "
  1595. "`torch.export.register_dataclass`.",
  1596. )
  1597. def rewrite_signature(
  1598. f_sig: inspect.Signature,
  1599. graph: torch.fx.GraphModule,
  1600. fake_mode: Optional[fake_tensor.FakeTensorMode],
  1601. flat_args: list[Any],
  1602. in_spec: pytree.TreeSpec,
  1603. example_fake_inputs: list[Any],
  1604. graph_captured_input: Iterable[Any],
  1605. graph_captured_output: Optional[Iterable[Any]],
  1606. dynamo_traced_result: Any,
  1607. flat_args_dynamic_dims: list[set[int]],
  1608. ) -> torch.fx.GraphModule:
  1609. orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec)
  1610. check_user_input_output(flat_args, UserErrorType.INVALID_INPUT)
  1611. flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result)
  1612. check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT)
  1613. def check_optional_input_and_error(f_sig: inspect.Signature) -> None:
  1614. # Check if function has optional input.
  1615. for name, param in f_sig.parameters.items():
  1616. if param.default is not inspect.Parameter.empty:
  1617. import torch._dynamo.graph_break_hints as graph_break_hints
  1618. from torch._dynamo.exc import unimplemented
  1619. log.error(
  1620. "Parameter %s is optional with a default value of %s",
  1621. name,
  1622. param.default,
  1623. )
  1624. unimplemented(
  1625. gb_type="rewrite_signature: cannot trace optional function input",
  1626. context="",
  1627. explanation=f"Parameter {name} is optional with a default value of {param.default}. This is not supported yet.",
  1628. hints=[
  1629. *graph_break_hints.SUPPORTABLE,
  1630. ],
  1631. )
  1632. def produce_matching(
  1633. debug_type: str, sources: Iterable[Any], candidates: Iterable[Any]
  1634. ) -> list[Optional[int]]:
  1635. matched_elements_positions: list[Optional[int]] = []
  1636. dict_of_source_vals = {}
  1637. for i, val in enumerate(sources):
  1638. dict_of_source_vals[id(val)] = i
  1639. for val in candidates:
  1640. if isinstance(val, tuple(common_constant_types)):
  1641. matched_elements_positions.append(None)
  1642. elif id(val) not in dict_of_source_vals:
  1643. if debug_type == "inputs":
  1644. check_optional_input_and_error(f_sig)
  1645. raise AssertionError(
  1646. f"Unexpectedly found a {type(val)} in the {debug_type}.\n"
  1647. 'Please file an issue along with a paste of the logs from TORCH_LOGS="+export"',
  1648. )
  1649. else:
  1650. matched_elements_positions.append(dict_of_source_vals[id(val)])
  1651. return matched_elements_positions
  1652. matched_input_elements_positions = produce_matching(
  1653. "inputs", flat_args, graph_captured_input
  1654. )
  1655. assert graph_captured_output is not None
  1656. matched_output_elements_positions = produce_matching(
  1657. "outputs", list(graph_captured_output) + flat_args, flat_results_traced
  1658. )
  1659. new_graph = FlattenInputOutputSignature(
  1660. graph,
  1661. flat_args,
  1662. matched_input_elements_positions, # type: ignore[arg-type]
  1663. flat_results_traced,
  1664. matched_output_elements_positions, # type: ignore[arg-type]
  1665. example_fake_inputs,
  1666. flat_args_dynamic_dims,
  1667. fake_mode,
  1668. ).transform()
  1669. new_graph.graph._codegen = _PyTreeCodeGen(
  1670. _PyTreeInfo(
  1671. argument_names(f_sig, orig_args, orig_kwargs),
  1672. in_spec,
  1673. out_spec_traced,
  1674. )
  1675. )
  1676. new_graph.recompile()
  1677. return new_graph
  1678. def export(
  1679. f: Callable[..., Any],
  1680. *extra_args: Any,
  1681. aten_graph: bool = False,
  1682. pre_dispatch: bool = False,
  1683. decomposition_table: Optional[
  1684. dict[torch._ops.OpOverload, Callable[..., Any]]
  1685. ] = None,
  1686. tracing_mode: str = "symbolic",
  1687. dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
  1688. specialize_float: bool = True,
  1689. assume_static_by_default: bool = False,
  1690. same_signature: bool = True,
  1691. disable_constraint_solver: bool = False,
  1692. prefer_deferred_runtime_asserts_over_guards: bool = False,
  1693. _log_export_usage: bool = True,
  1694. constraints: Optional[list[Constraint]] = None,
  1695. **extra_kwargs: Any,
  1696. ) -> Callable[..., ExportResult]:
  1697. """
  1698. Export an input function f to a format that can be executed outside of PyTorch using the FX graph.
  1699. Args:
  1700. f (callable): A PyTorch function to be exported.
  1701. aten_graph (bool): If True, exports a graph with ATen operators.
  1702. If False, exports a graph with Python operators. Default is False.
  1703. pre_dispatch (bool): If True, exports a graph with ATen operators,
  1704. but before any logic in the PyTorch dispatcher has run.
  1705. This can be useful if you want to apply further transformations on a graph before running it
  1706. through autograd, autocast, or any other functionalities that are integrated into the dispatcher.
  1707. This flag is only valid if aten_graph=True is set.
  1708. Default is False.
  1709. decomposition_table (dict): A dictionary that maps operators to their decomposition functions.
  1710. Required if aten_graph or tracing_mode is specified. Default is None.
  1711. tracing_mode (str): If "symbolic", turn on dynamic shapes support. Default is "symbolic".
  1712. dynamic_shapes:
  1713. An optional argument where the type should either be:
  1714. 1) a dict from argument names of ``f`` to their dynamic shape specifications,
  1715. 2) a tuple that specifies dynamic shape specifications for each input in original order.
  1716. If you are specifying dynamism on keyword args, you will need to pass them in the order that
  1717. is defined in the original function signature.
  1718. The dynamic shape of a tensor argument can be specified as either
  1719. (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
  1720. not required to include static dimension indices in this dict, but when they are,
  1721. they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
  1722. where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
  1723. are denoted by None. Arguments that are dicts or tuples / lists of tensors are
  1724. recursively specified by using mappings or sequences of contained specifications.
  1725. same_signature (bool): If True, rewrite the returned graph's signature to be the same as f.
  1726. disable_constraint_solver (bool): Whether the dim constraint solver must be disabled.
  1727. Returns:
  1728. A function that given args and kwargs, returns a tuple of (graph, guards)
  1729. Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options.
  1730. Guards: The guards we accumulated during tracing f above
  1731. Raises:
  1732. AssertionError: If decomposition_table is specified without setting aten_graph=True,
  1733. or if graph breaks during tracing in export.
  1734. AssertionError: If Dynamo input and output is not consistent with traced input/output.
  1735. Note - this headerdoc was authored by ChatGPT, with slight modifications by the author.
  1736. """
  1737. if config.debug_force_graph_break_on_leaf_return:
  1738. raise unittest.SkipTest("Cannot force graph break on export")
  1739. if _log_export_usage:
  1740. log_export_usage(event="export.private_api", flags={"_dynamo"})
  1741. # Deal with "local variable referenced before assignment"
  1742. _f = f
  1743. _specialize_float = specialize_float
  1744. _assume_static_by_default = assume_static_by_default
  1745. _constraints = constraints
  1746. def inner(*args: Any, **kwargs: Any) -> ExportResult:
  1747. if not _constraints:
  1748. combined_args = _combine_args(_f, args, kwargs)
  1749. constraints = _process_dynamic_shapes(combined_args, dynamic_shapes)
  1750. else:
  1751. constraints = _constraints
  1752. f = _f
  1753. specialize_float = _specialize_float
  1754. assume_static_by_default = _assume_static_by_default
  1755. check_if_dynamo_supported()
  1756. torch._C._log_api_usage_once("torch._dynamo.export")
  1757. if decomposition_table is not None:
  1758. assert aten_graph, (
  1759. "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
  1760. )
  1761. if pre_dispatch:
  1762. assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
  1763. f = innermost_fn(f)
  1764. call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
  1765. original_signature = inspect.signature(call_to_inspect) # type: ignore[arg-type]
  1766. graph = None
  1767. out_guards = None
  1768. graph_captured_input = None
  1769. graph_captured_result: Optional[tuple[torch.Tensor, ...]] = None
  1770. fake_mode = None
  1771. result_traced = None
  1772. def guard_export_print(guards: _guards.GuardsSet) -> None:
  1773. nonlocal out_guards
  1774. assert out_guards is None, (
  1775. "whole graph export entails exactly one guard export"
  1776. )
  1777. out_guards = guards
  1778. example_inputs: list[Any] = []
  1779. def dynamo_normalization_capturing_compiler(
  1780. gm: torch.fx.GraphModule, inner_example_inputs: list[Any]
  1781. ) -> Callable[..., Any]:
  1782. nonlocal graph
  1783. assert graph is None, (
  1784. "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
  1785. )
  1786. graph = gm
  1787. nonlocal fake_mode, example_inputs
  1788. # NB: do NOT pass inner_example_inputs here, we are detecting the
  1789. # Dynamo allocated fake mode, which should be DISTINCT from a
  1790. # potential outer ambient fake mode which the user provided.
  1791. # example_inputs is always the user specified inputs, so they
  1792. # would have the wrong fake mode attached to them
  1793. fake_mode = _guards.detect_fake_mode()
  1794. example_inputs = inner_example_inputs
  1795. def result_capturing_wrapper(*graph_inputs: Any) -> Any:
  1796. nonlocal graph_captured_result
  1797. nonlocal graph_captured_input
  1798. graph_captured_input = graph_inputs
  1799. assert graph is not None
  1800. named_parameters = dict(graph.named_parameters(remove_duplicate=False))
  1801. named_buffers = dict(graph.named_buffers(remove_duplicate=False))
  1802. ambient_fake_mode = (
  1803. _guards.detect_fake_mode(graph_inputs)
  1804. if _guards.detect_fake_mode(graph_inputs) is not None
  1805. else fake_mode
  1806. )
  1807. # We reran fake tensor propagation, but we didn't do
  1808. # anything with the resulting unbacked SymInts. Drop them
  1809. # from the pending list.
  1810. # NB: this is wrong if graph_captured_result has
  1811. # data-dependent output size!
  1812. ignore_fresh_unbacked = null_context()
  1813. assert ambient_fake_mode is not None
  1814. if shape_env := ambient_fake_mode.shape_env:
  1815. ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols() # type: ignore[assignment]
  1816. with (
  1817. ambient_fake_mode,
  1818. enable_python_dispatcher(),
  1819. ignore_fresh_unbacked,
  1820. ):
  1821. params_and_buffers = {
  1822. **named_parameters,
  1823. **named_buffers,
  1824. }
  1825. fake_params_buffers = {}
  1826. for name, value in params_and_buffers.items():
  1827. fake_params_buffers[name] = ambient_fake_mode.from_tensor(
  1828. value, static_shapes=True
  1829. )
  1830. from torch._export.non_strict_utils import (
  1831. key_path_to_source,
  1832. KeyPath,
  1833. )
  1834. def fakify_with_ambient(
  1835. path: KeyPath, t: Union[torch.Tensor, _IntWrapper, Any]
  1836. ) -> Any:
  1837. if isinstance(t, torch.Tensor):
  1838. # pyrefly: ignore [missing-attribute]
  1839. return ambient_fake_mode.from_tensor(t, static_shapes=True)
  1840. elif isinstance(t, _IntWrapper):
  1841. if (
  1842. t.dynamism is not None
  1843. and isinstance(t.dynamism, _DimHint)
  1844. and t.dynamism.type
  1845. in (
  1846. _DimHintType.DYNAMIC,
  1847. _DimHintType.AUTO,
  1848. )
  1849. ): # type: ignore[union-attr]
  1850. source = key_path_to_source(path)
  1851. symint = ambient_fake_mode.shape_env.create_unspecified_symint_and_symbol( # type: ignore[union-attr]
  1852. t.val, source, DimDynamic.DYNAMIC
  1853. )
  1854. return symint
  1855. else:
  1856. return t.val
  1857. else:
  1858. return t
  1859. fake_graph_inputs = pytree.tree_map_with_path(
  1860. fakify_with_ambient, graph_inputs
  1861. )
  1862. graph_captured_result = torch.func.functional_call(
  1863. graph,
  1864. fake_params_buffers, # type: ignore[arg-type]
  1865. fake_graph_inputs, # type: ignore[arg-type]
  1866. )
  1867. return graph_captured_result
  1868. return result_capturing_wrapper
  1869. # Note: This is needed by rewrite_signature. We need to put it before
  1870. # optimize_assert since user program may mutate the inputs.
  1871. flat_args, in_spec = pytree.tree_flatten((args, kwargs))
  1872. remove_from_cache(f)
  1873. constraint_violation_error = None
  1874. if tracing_mode != "symbolic":
  1875. assume_static_by_default = True
  1876. with (
  1877. config.patch(
  1878. specialize_int=True,
  1879. specialize_float=specialize_float,
  1880. assume_static_by_default=assume_static_by_default,
  1881. automatic_dynamic_shapes=False,
  1882. capture_dynamic_output_shape_ops=True,
  1883. capture_scalar_outputs=True,
  1884. constant_fold_autograd_profiler_enabled=True,
  1885. prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
  1886. # install_free_tensors ensures that params and buffers are still
  1887. # added as graph attributes, and makes Dynamo emits graphs that
  1888. # follow export pytree-able input requirements
  1889. install_free_tensors=config.install_free_tensors_for_export,
  1890. ),
  1891. _compiling_state_context(),
  1892. ):
  1893. opt_f = optimize_assert(
  1894. dynamo_normalization_capturing_compiler,
  1895. hooks=Hooks(
  1896. guard_export_fn=guard_export_print,
  1897. guard_fail_fn=None,
  1898. ),
  1899. export=True,
  1900. export_constraints=constraints,
  1901. )(f)
  1902. # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject.
  1903. try:
  1904. result_traced = opt_f(*args, **kwargs)
  1905. except ConstraintViolationError as e:
  1906. constraint_violation_error = e
  1907. remove_from_cache(f)
  1908. if (
  1909. not disable_constraint_solver
  1910. and (shape_env := getattr(fake_mode, "shape_env", None)) is not None
  1911. and (dim_constraints := shape_env.dim_constraints) is not None
  1912. and not isinstance(
  1913. call_to_inspect, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
  1914. )
  1915. and not trace_rules.check(call_to_inspect)
  1916. ):
  1917. dim_constraints.solve()
  1918. forced_specializations = dim_constraints.forced_specializations()
  1919. msg = dim_constraints.prettify_results(
  1920. original_signature,
  1921. dynamic_shapes,
  1922. constraint_violation_error,
  1923. forced_specializations,
  1924. )
  1925. if constraint_violation_error:
  1926. constraint_violation_error.args = (
  1927. constraint_violation_error.args[0] + msg,
  1928. )
  1929. else:
  1930. if forced_specializations:
  1931. constraint_violation_error = ConstraintViolationError(msg)
  1932. else:
  1933. log.info(
  1934. "Summary of dimension constraints:%s",
  1935. msg,
  1936. )
  1937. # Error if we have any constraints on static values
  1938. for k in shape_env.var_to_range:
  1939. if isinstance(k, sympy.Integer):
  1940. constraint_violation_error = ConstraintViolationError(
  1941. f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
  1942. "It appears that you're trying to set a constraint on a "
  1943. f"value which we evaluated to have a static value of {k}. "
  1944. 'Set TORCH_LOGS="+export" for more information.'
  1945. )
  1946. if constraint_violation_error:
  1947. raise constraint_violation_error
  1948. if graph is None:
  1949. assert same_signature, (
  1950. "Failed to produce a graph during tracing as no tensor operations were found and same_signature is False."
  1951. )
  1952. # If the module does not contain any tensor computation, we would create a graph with inputs and outputs.
  1953. # To be consistent with the graph traced by dynano, `graph` will have only tensor inputs as placeholders
  1954. # and tensor outputs as output nodes. non-tensor inputs and outputs will be added when rewriting signature.
  1955. # We will also construct the `example_inputs`, `graph_captured_input`, and `graph_captured_result` corresponding
  1956. # to `graph`.
  1957. example_inputs = []
  1958. graph_captured_input = ()
  1959. graph_captured_result = ()
  1960. fake_mode = torch._subclasses.FakeTensorMode(
  1961. shape_env=ShapeEnv(), export=True
  1962. )
  1963. if out_guards is None:
  1964. out_guards = _guards.GuardsSet()
  1965. assert out_guards is not None # suppress mypy error
  1966. parameter_names = list(original_signature.parameters.keys())
  1967. fx_graph = torch.fx.Graph()
  1968. for i, name in enumerate(parameter_names):
  1969. if torch.is_tensor(flat_args[i]):
  1970. node = fx_graph.placeholder(name)
  1971. node.meta["val"] = fake_mode.from_tensor(
  1972. flat_args[i], static_shapes=True
  1973. )
  1974. graph_captured_input = graph_captured_input + (flat_args[i],)
  1975. example_inputs.append(flat_args[i])
  1976. fx_graph.output(graph_captured_result)
  1977. module = torch.nn.Module()
  1978. graph = torch.fx.GraphModule(module, fx_graph)
  1979. log.info(
  1980. "Failed to capture a graph during tracing as no tensor operations were found.:\n\n%s",
  1981. graph.print_readable(print_output=False, colored=True),
  1982. )
  1983. else:
  1984. assert out_guards is not None, "Failed to produce guards during tracing"
  1985. assert fake_mode is not None
  1986. log.info(
  1987. "Dynamo captured graph:\n\n%s",
  1988. graph.print_readable(print_output=False, colored=True),
  1989. )
  1990. # This check need to happened before aten_graph
  1991. # because placeholder's _source_node attribute is not preserved by make_fx
  1992. if same_signature:
  1993. check_signature_rewritable(graph)
  1994. # NB: This is mostly hitting the cache; Dynamo already converted these
  1995. example_fake_inputs = [
  1996. fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
  1997. for t in example_inputs
  1998. ]
  1999. if aten_graph:
  2000. # Running graph with interpreter is needed for propagating the stack_trace
  2001. def graph_with_interpreter(*args: Any) -> Any:
  2002. with torch.fx.traceback.preserve_node_meta():
  2003. return torch.fx.Interpreter(graph).run(*args) # type: ignore[arg-type]
  2004. with unset_fake_temporarily(), enable_python_dispatcher(), fake_mode:
  2005. try:
  2006. graph = make_fx(
  2007. graph_with_interpreter,
  2008. decomposition_table=decomposition_table,
  2009. tracing_mode="real",
  2010. _allow_non_fake_inputs=True,
  2011. pre_dispatch=pre_dispatch,
  2012. _allow_fake_constant=False,
  2013. )(*example_fake_inputs)
  2014. except CondOpArgsMismatchError as e:
  2015. # Wrap the internal error to the user-facing error
  2016. raise UserError( # noqa: B904
  2017. UserErrorType.DYNAMIC_CONTROL_FLOW,
  2018. str(e),
  2019. case_name="cond_operands",
  2020. )
  2021. assert graph is not None
  2022. for node in graph.graph.find_nodes(op="get_attr"):
  2023. if isinstance(getattr(graph, node.target), torch.Tensor): # type: ignore[arg-type]
  2024. node.meta["val"] = fake_mode.from_tensor(
  2025. getattr(graph, node.target), # type: ignore[arg-type]
  2026. static_shapes=True,
  2027. )
  2028. if same_signature:
  2029. flat_args_dynamic_dims = [
  2030. {
  2031. c.dim
  2032. for c in (constraints or ())
  2033. if (
  2034. c.t_id == id(x)
  2035. and not isinstance(c, _RelaxedConstraint)
  2036. and c.constraint_range.vr.lower != c.constraint_range.vr.upper
  2037. )
  2038. }
  2039. for x in flat_args
  2040. ]
  2041. graph = rewrite_signature(
  2042. original_signature,
  2043. graph,
  2044. fake_mode,
  2045. flat_args,
  2046. in_spec,
  2047. example_fake_inputs,
  2048. graph_captured_input, # type: ignore[arg-type]
  2049. graph_captured_result,
  2050. result_traced, # type: ignore[possibly-undefined]
  2051. flat_args_dynamic_dims,
  2052. )
  2053. return ExportResult(graph, out_guards)
  2054. if extra_args or extra_kwargs:
  2055. warnings.warn(
  2056. "export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. "
  2057. "If you don't migrate, we may break your export call in the future if your user defined kwargs "
  2058. "conflict with future kwargs added to export(f).",
  2059. FutureWarning,
  2060. stacklevel=2,
  2061. )
  2062. return inner(*extra_args, **extra_kwargs) # type: ignore[return-value]
  2063. else:
  2064. return inner
  2065. def optimize_assert(*args: Any, **kwargs: Any) -> OptimizeContext:
  2066. if "rebuild_ctx" in kwargs and kwargs["rebuild_ctx"] is not None:
  2067. # called from optimize
  2068. rebuild_ctx = kwargs["rebuild_ctx"]
  2069. del kwargs["rebuild_ctx"]
  2070. else:
  2071. def rebuild_ctx() -> OptimizeContext:
  2072. return optimize_assert(*args, **kwargs)
  2073. return _optimize_assert(rebuild_ctx, *args, **kwargs)
  2074. def _optimize_assert(
  2075. rebuild_ctx: Callable[[], OptimizeContext],
  2076. backend: Union[str, Callable[..., Any], None],
  2077. *,
  2078. hooks: Hooks = Hooks(None, None, None),
  2079. export: bool = False,
  2080. export_constraints: Optional[Any] = None,
  2081. dynamic: Optional[bool] = None,
  2082. package: Optional[CompilePackage] = None,
  2083. ) -> OptimizeContext:
  2084. """
  2085. Guarantees single-graph capture.
  2086. The same as `torch._dynamo.optimize(backend)` but ignores
  2087. symbolic_convert.error_on_graph_break setting.
  2088. Used for fullgraph=True and export, since we must always error on graph breaks and ignore
  2089. symbolic_convert.error_on_graph_break. Can also be used for testing.
  2090. """
  2091. backend = get_compiler_fn(backend)
  2092. # Find if backend has any extra context manager
  2093. backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)
  2094. if config.caching_precompile and package is None:
  2095. # Create an uninitialized package that will be set/filled by
  2096. # _OptimizeContext.__call__
  2097. # We need to instantiate the object here because the same CompilePackage
  2098. # needs to be shared between convert_frame_assert
  2099. # and OptimizeContext.
  2100. from .package import CompilePackage
  2101. package = CompilePackage(fn=None, dynamo=None, ignore_inlined_sources=False)
  2102. return _optimize_catch_errors(
  2103. convert_frame.convert_frame_assert(
  2104. # pyrefly: ignore [bad-argument-type]
  2105. backend,
  2106. export=export,
  2107. export_constraints=export_constraints,
  2108. package=package,
  2109. ),
  2110. hooks,
  2111. backend_ctx_ctor,
  2112. fullgraph=True,
  2113. export=export,
  2114. dynamic=dynamic,
  2115. rebuild_ctx=rebuild_ctx,
  2116. package=package,
  2117. )
  2118. class TorchPatcher:
  2119. @staticmethod
  2120. @functools.cache
  2121. def patch() -> None:
  2122. # A better way to disable the following would be decorate the source
  2123. # functions with @torch._disable_dynamo. However, this causes issues
  2124. # with torch.deploy internally.
  2125. from .decorators import disable
  2126. torch.jit.trace = disable(
  2127. torch.jit.trace, reason="tracing into TorchScript not fully supported"
  2128. )
  2129. torch.jit.trace_module = disable(
  2130. torch.jit.trace_module,
  2131. reason="tracing into TorchScript not fully supported",
  2132. )
  2133. torch.jit._get_trace_graph = disable(
  2134. torch.jit._get_trace_graph,
  2135. reason="tracing into TorchScript not fully supported",
  2136. )
  2137. torch.fx._symbolic_trace.Tracer.trace = disable(
  2138. torch.fx._symbolic_trace.Tracer.trace,
  2139. reason="tracing into FX not fully supported",
  2140. )
  2141. torch.distributions.Distribution.set_default_validate_args(False)
  2142. from torch.optim import (
  2143. adadelta,
  2144. adagrad,
  2145. adam,
  2146. adamax,
  2147. adamw,
  2148. asgd,
  2149. lbfgs,
  2150. nadam,
  2151. radam,
  2152. rmsprop,
  2153. rprop,
  2154. sgd,
  2155. sparse_adam,
  2156. )
  2157. optimizer_modules = {
  2158. adadelta,
  2159. adagrad,
  2160. adam,
  2161. adamax,
  2162. adamw,
  2163. asgd,
  2164. lbfgs,
  2165. nadam,
  2166. radam,
  2167. rmsprop,
  2168. rprop,
  2169. sgd,
  2170. sparse_adam,
  2171. }
  2172. for opt_mod in optimizer_modules:
  2173. opt_name = opt_mod.__name__.split(".")[-1]
  2174. fused_fn_name = f"_fused_{opt_name}"
  2175. if hasattr(opt_mod, fused_fn_name):
  2176. setattr(
  2177. opt_mod,
  2178. fused_fn_name,
  2179. disable(
  2180. getattr(opt_mod, fused_fn_name),
  2181. reason="don't trace into fused optimizer",
  2182. ),
  2183. )
  2184. optimizer_classes = [
  2185. opt
  2186. for opt in torch.optim.__dict__.values()
  2187. if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
  2188. ]
  2189. # Note: we don't support sparsity or tracing through backwards
  2190. excluded_optimizer_classes = {
  2191. torch.optim.SparseAdam,
  2192. torch.optim.LBFGS,
  2193. }
  2194. for opt in optimizer_classes:
  2195. if opt in excluded_optimizer_classes:
  2196. opt.step = disable(
  2197. opt.step, reason=f"optimizer {opt} step not supported"
  2198. )
  2199. if hasattr(opt, "_init_group"):
  2200. opt._init_group = disable(
  2201. opt._init_group, reason=f"optimizer {opt} _init_group not supported"
  2202. )
  2203. @staticmethod
  2204. def suppress_torch_distributed_warnings(
  2205. fn: Callable[..., Any],
  2206. ) -> Callable[..., Any]:
  2207. def inner_fn(*args: Any, **kwargs: Any) -> Any:
  2208. with torch._logging.hide_warnings(
  2209. torch._logging._internal.user_warning_filter
  2210. ):
  2211. return fn(*args, **kwargs)
  2212. return inner_fn
  2213. def skip_code(code: types.CodeType) -> None:
  2214. set_code_exec_strategy(
  2215. code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT)
  2216. )