convert_frame.py 86 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325
  1. """
  2. This module implements TorchDynamo's core frame conversion functionality, transforming Python
  3. frames into FX graphs. It handles:
  4. - Frame analysis and bytecode transformation
  5. - Guard creation and management for dynamic behaviors
  6. - Cache management for recompilation
  7. - Error handling and fallback mechanisms
  8. Key classes:
  9. - ConvertFrame: Main entry point for frame conversion with error handling
  10. - ConvertFrameAssert: Implements core frame to graph conversion logic
  11. - Tracker: Tracks input/output code objects during conversion
  12. - CatchErrorsWrapper: Provides error handling and suppression logic
  13. The conversion process preserves program semantics while enabling optimizations
  14. through torch.compile() and related systems.
  15. NOTE: _torchdynamo_orig_backend is used for convert frame wrappers to identify the inner wrapped function.
  16. By going down the _torchdynamo_orig_backend chain, one can recover the original unwrapped backend,
  17. which is checked for during the Dynamo cache lookup.
  18. """
  19. from __future__ import annotations
  20. import collections
  21. import contextlib
  22. import cProfile
  23. import dataclasses
  24. import dis
  25. import functools
  26. import gc
  27. import importlib
  28. import inspect
  29. import itertools
  30. import logging
  31. import os
  32. import pstats
  33. import random
  34. import subprocess
  35. import sys
  36. import tempfile
  37. import threading
  38. import time
  39. import traceback
  40. import types
  41. import typing
  42. import unittest.mock as mock
  43. import weakref
  44. from dataclasses import dataclass
  45. from pathlib import Path
  46. from types import CellType, CodeType, FunctionType, ModuleType
  47. from typing import Any, NoReturn, Optional, TypeVar, Union
  48. from typing_extensions import ParamSpec
  49. from weakref import ReferenceType
  50. import torch
  51. import torch._logging
  52. from torch._C._dynamo.guards import GlobalStateGuard
  53. from torch._dynamo.callback import CallbackTrigger
  54. from torch._dynamo.distributed import get_compile_pg
  55. from torch._dynamo.symbolic_convert import TensorifyState
  56. from torch._guards import compile_context, CompileContext, CompileId, tracing
  57. from torch._logging import structured
  58. from torch._utils_internal import (
  59. compile_time_strobelight_meta,
  60. maybe_upload_prof_stats_to_manifold,
  61. signpost_event,
  62. )
  63. from torch.fx._lazy_graph_module import _use_lazy_graph_module
  64. from torch.fx.experimental.symbolic_shapes import (
  65. ConstraintViolationError,
  66. GuardOnDataDependentSymNode,
  67. )
  68. from torch.fx.graph_module import _forward_from_src as original_forward_from_src
  69. from torch.monitor import _WaitCounter
  70. from torch.nn.parallel.distributed import DistributedDataParallel
  71. from torch.utils._ordered_set import OrderedSet
  72. from torch.utils._python_dispatch import (
  73. _disable_current_modes,
  74. any_torch_dispatch_mode_on_stack,
  75. is_in_any_mode_without_ignore_compile_internals,
  76. )
  77. from torch.utils._traceback import CapturedTraceback, format_traceback_short
  78. from . import config, decorators, exc, graph_break_hints, trace_rules
  79. from .backends.registry import _is_registered_backend
  80. from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
  81. from .bytecode_transformation import (
  82. check_inst_exn_tab_entries_valid,
  83. Instruction,
  84. is_generator,
  85. propagate_inst_exn_table_entries,
  86. transform_code_object,
  87. )
  88. from .cache_size import (
  89. CacheSizeRelevantForFrame,
  90. compute_cache_size,
  91. exceeds_recompile_limit,
  92. is_recompilation,
  93. )
  94. from .eval_frame import (
  95. always_optimize_code_objects,
  96. Constraint,
  97. dynamo_tls,
  98. innermost_backend,
  99. innermost_fn,
  100. skip_code,
  101. TorchPatcher,
  102. )
  103. from .exc import (
  104. augment_exc_message,
  105. BackendCompilerFailed,
  106. FailOnRecompileLimitHit,
  107. format_error_msg,
  108. InternalTorchDynamoError,
  109. PackageError,
  110. ResumePrologueTracingError,
  111. ShortenTraceback,
  112. TorchRuntimeError,
  113. UncapturedHigherOrderOpError,
  114. unimplemented,
  115. Unsupported,
  116. )
  117. from .graph_bytecode_inputs import reset_user_object_tracking
  118. from .guards import (
  119. CheckFunctionManager,
  120. get_and_maybe_log_recompilation_reasons,
  121. GuardedCode,
  122. )
  123. from .hooks import Hooks
  124. from .output_graph import DynamoTracerOutput, OutputGraphCommon
  125. from .pgo import (
  126. _log_size_mismatch_recompile,
  127. log_frame_dynamic_whitelist,
  128. put_code_state,
  129. )
  130. from .replay_record import ExecutionRecord
  131. from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
  132. from .symbolic_convert import (
  133. DistributedState,
  134. ExceptionStack,
  135. InstructionTranslator,
  136. LocalState,
  137. SpeculationLog,
  138. )
  139. from .trace_rules import is_numpy
  140. from .types import ConvertFrameReturn, FrameAction, FrameExecStrategy, wrap_guarded_code
  141. from .utils import (
  142. _get_error_on_graph_break,
  143. chromium_event_timed,
  144. CleanupManager,
  145. CompileTimeInstructionCounter,
  146. counters,
  147. dynamo_timed,
  148. format_bytecode,
  149. gen_record_file_name,
  150. get_hook_for_recompile_user_context,
  151. get_metrics_context,
  152. increment_frame,
  153. is_namedtuple,
  154. istype,
  155. LazyString,
  156. maybe_disable_inference_mode,
  157. maybe_disable_inference_mode_for_fake_prop,
  158. orig_code_map,
  159. reset_graph_break_dup_checker,
  160. setup_compile_debug,
  161. to_int_us,
  162. troubleshooting_url,
  163. write_record_to_file,
  164. )
  165. from .variables.torch_function import torch_function_mode_stack_state_mgr
  166. np: Optional[ModuleType]
  167. try:
  168. import numpy as np
  169. except ModuleNotFoundError:
  170. np = None
  171. if typing.TYPE_CHECKING:
  172. from collections.abc import Callable
  173. from torch.utils.weak import WeakIdKeyDictionary
  174. from .backends.registry import CompilerFn
  175. from .package import CompilePackage
  176. from .repro.after_dynamo import WrapBackendDebug
  177. from .types import BytecodeHook, CacheEntry, DynamoFrameType
  178. from .variables.builder import FrameStateSizeEntry
# Module logger plus torch._logging artifact loggers named "bytecode" and
# "graph_breaks" for this module.
log = logging.getLogger(__name__)
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
# Reentrant module-level lock. NOTE(review): presumably serializes compilation
# across threads; its users are outside this chunk — confirm.
compile_lock = threading.RLock()
# Type variables used to type the decorators below so that wrapped callables
# keep their parameter and return types.
_T = TypeVar("_T")
_P = ParamSpec("_P")
class TODO_UNKNOWN:
    # Empty placeholder class with no members; nothing in this chunk refers to
    # it. NOTE(review): confirm whether it is still used elsewhere.
    pass
  187. def _clear_fake_mode_weakrefs(
  188. fake_mode: Optional[torch._subclasses.fake_tensor.FakeTensorMode],
  189. ) -> None:
  190. """Clear WeakIdRef entries from a FakeTensorMode's describer."""
  191. if fake_mode is None:
  192. return
  193. describer = fake_mode.fake_tensor_converter.meta_converter.describer
  194. describer.lookup_tensor.clear()
  195. describer.lookup_storage.clear()
  196. class Tracker:
  197. def __init__(self) -> None:
  198. self.seen: list[ReferenceType[CodeType]] = []
  199. self.seen_ids: set[int] = set()
  200. def add(self, strong_obj: CodeType) -> None:
  201. idx = id(strong_obj)
  202. if idx not in self.seen_ids:
  203. obj = weakref.ref(strong_obj, lambda _: self.seen_ids.remove(idx))
  204. self.seen.append(obj)
  205. self.seen_ids.add(idx)
  206. def __contains__(self, item: CodeType) -> bool:
  207. return id(item) in self.seen_ids
  208. def clear(self) -> None:
  209. self.seen.clear()
  210. self.seen_ids.clear()
# Code objects handed to ConvertFrameAssert.__call__ (added there on every
# call). Frames whose code is in ``output_codes`` are skipped by that method.
# NOTE(review): ``output_codes`` is presumably populated with Dynamo-generated
# code elsewhere in the file — confirm.
input_codes = Tracker()
output_codes = Tracker()
# GlobalStateGuard snapshot for the compilation currently in progress;
# ConvertFrameAssert.__call__ saves and restores it around ``_compile`` so
# nested compilations do not clobber the outer value.
initial_global_state: Optional[GlobalStateGuard] = None
  214. @functools.wraps(original_forward_from_src)
  215. def fx_forward_from_src_skip_result(
  216. src: str, globals: dict[str, Any], co_fields: Optional[dict[str, str]] = None
  217. ) -> FunctionType:
  218. # we monkey patch FX to prevent infinite loop of trying to convert
  219. # our generated code
  220. result = original_forward_from_src(src, globals, co_fields)
  221. skip_code(result.__code__)
  222. return result
  223. def log_dynamo_start(code: CodeType, skip: int = 0) -> list[str]:
  224. convert_frame_intern = structured.intern_string(__file__)
  225. captured_tb = CapturedTraceback.extract(skip=4 + skip).summary()
  226. frames_interned = structured.from_traceback(captured_tb)
  227. # Extract and filter the stack
  228. stack = list(
  229. itertools.takewhile(
  230. lambda f: f["filename"] != convert_frame_intern,
  231. frames_interned,
  232. )
  233. ) + [
  234. {
  235. "line": code.co_firstlineno,
  236. "name": code.co_name,
  237. "filename": structured.intern_string(code.co_filename),
  238. }
  239. ]
  240. # Initialize the ChromiumEventLogger on start
  241. torch._logging.trace_structured(
  242. "dynamo_start",
  243. lambda: {"stack": stack},
  244. )
  245. # Capture stack separately without using from_traceback to get the actual filenames
  246. stack_strings = [
  247. f"Line: {frame.lineno}, Name: {frame.name}, Filename: {frame.filename}"
  248. for frame in captured_tb
  249. if frame.filename != convert_frame_intern
  250. ] + [
  251. f"Line: {code.co_firstlineno}, Name: {code.co_name}, Filename: {code.co_filename}"
  252. ]
  253. return stack_strings
def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
    """
    Decorator (not a context manager) that runs ``fn`` and then restores
    global state it may have changed:
    1) torch.is_grad_enabled() and inference-mode state
    2) python random state
    3) torch CPU random state, plus CUDA random state when CUDA is available
    4) deterministic-algorithms mode (incl. warn_only), the default dtype and
       the CUDA matmul fp32 precision setting; the default mobile CPU
       allocator is unset if its "is set" state changed
    5) torch.fx.graph_module._forward_from_src, which is monkey patched to
       fx_forward_from_src_skip_result for the duration of the call
    The call also runs inside torch.fx._symbolic_trace._maybe_revert_all_patches()
    and torch_function_mode_stack_state_mgr. Afterwards it asserts that the
    torch function mode stack is empty and that the GlobalStateGuard taken on
    entry still matches.
    """

    @functools.wraps(fn)
    def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        # Snapshot of global state, re-checked after fn returns.
        guards = GlobalStateGuard()
        prior_grad_mode = torch.is_grad_enabled()
        # Just in case we get left in a bad dispatch state we want to restore
        # it. This can happen because the dispatch bits aren't a true
        # stack/counter - so we can't just increment/decrement them as we enter
        # and leave.
        with (
            torch._C._PreserveDispatchKeyGuard(),
            maybe_disable_inference_mode(),
            maybe_disable_inference_mode_for_fake_prop(),
        ):
            # Record every piece of state restored in the finally block below.
            prior_inference_mode = torch.is_inference_mode_enabled()
            prior_deterministic = torch.are_deterministic_algorithms_enabled()
            prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
            prior_mobile_allocator_state = (
                torch._C._is_default_mobile_cpu_allocator_set()
            )
            py_rng_state = random.getstate()
            prior_dtype = torch.get_default_dtype()
            torch_rng_state = torch.random.get_rng_state()
            cuda_rng_state = None
            if torch.cuda.is_available():
                with torch._C.DisableTorchFunction():
                    cuda_rng_state = torch.cuda.get_rng_state()
            cuda_matmul_fp32_prec = torch._C._get_fp32_precision_getter(
                "cuda", "matmul"
            )
            # Route FX codegen through the skip_code-marking wrapper; the
            # original function is put back in the finally block.
            prior_fwd_from_src = torch.fx.graph_module._forward_from_src
            torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
            cleanup = setup_compile_debug()
            exit_stack = contextlib.ExitStack()
            exit_stack.enter_context(
                torch.fx._symbolic_trace._maybe_revert_all_patches()
            )
            exit_stack.enter_context(torch_function_mode_stack_state_mgr)
            reset_user_object_tracking()
            try:
                return fn(*args, **kwargs)
            finally:
                cleanup.close()
                # NOTE(review): if this assert fires, the restoration below
                # (including exit_stack.close()) is skipped.
                assert torch._C._len_torch_function_stack() == 0, (
                    "Torch function mode stack state changed while dynamo tracing, please report a bug"
                )
                exit_stack.close()
                torch._C._set_grad_enabled(prior_grad_mode)
                torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
                torch.use_deterministic_algorithms(
                    prior_deterministic, warn_only=prior_warn_only
                )
                random.setstate(py_rng_state)
                torch.random.set_rng_state(torch_rng_state)
                torch.set_default_dtype(prior_dtype)
                curr_mobile_allocator_state = (
                    torch._C._is_default_mobile_cpu_allocator_set()
                )
                # Only unsets on change. NOTE(review): assumes fn can only have
                # *set* the default mobile allocator, never unset it.
                if prior_mobile_allocator_state != curr_mobile_allocator_state:
                    torch._C._unset_default_mobile_cpu_allocator()
                if cuda_rng_state is not None:
                    with torch._C.DisableTorchFunction():
                        torch.cuda.set_rng_state(cuda_rng_state)
                torch._C._set_fp32_precision_setter(
                    "cuda", "matmul", cuda_matmul_fp32_prec
                )
                torch.fx.graph_module._forward_from_src = prior_fwd_from_src
                # Anything GlobalStateGuard tracks that was not restored above
                # is reported as a Dynamo bug.
                assert guards.check(), (
                    f"Global {guards.reason()}state changed while dynamo tracing, please report a bug"
                )

    # Expose the wrapped callable (see the module NOTE on _torchdynamo_orig_backend).
    _fn._torchdynamo_orig_backend = fn  # type: ignore[attr-defined]
    return _fn
  333. @TorchPatcher.suppress_torch_distributed_warnings
  334. def has_tensor_in_frame(frame: DynamoFrameType) -> bool:
  335. """Check if the frame has torch.* related bits"""
  336. # Check if the function was decorated using torch._dynamo.optimize
  337. if frame.f_code in always_optimize_code_objects:
  338. return True
  339. # Check if there is global import of torch.*
  340. for co_name in frame.f_code.co_names:
  341. if co_name in frame.f_globals:
  342. obj = frame.f_globals[co_name]
  343. if isinstance(obj, ModuleType) and (
  344. obj.__name__.startswith("torch.") or obj is torch
  345. ):
  346. return True
  347. # ... or a global import of numpy.*
  348. if np and config.trace_numpy and (obj is np or is_numpy(obj)):
  349. return True
  350. seen_ids: dict[int, bool] = {}
  351. def has_tensor(obj: object) -> bool:
  352. """Recursively check if the obj has a tensor"""
  353. obj_id = id(obj)
  354. if obj_id in seen_ids:
  355. return seen_ids[obj_id]
  356. seen_ids[obj_id] = False
  357. if isinstance(obj, (torch.Tensor, torch.nn.Module)) or (
  358. istype(obj, type) and issubclass(obj, torch.nn.Module)
  359. ):
  360. seen_ids[obj_id] = True
  361. return seen_ids[obj_id]
  362. elif (
  363. config.trace_numpy
  364. and np
  365. and (istype(obj, np.ndarray) or isinstance(obj, np.generic))
  366. ):
  367. seen_ids[obj_id] = True
  368. return seen_ids[obj_id]
  369. elif istype(obj, (list, tuple)):
  370. seen_ids[obj_id] = any(has_tensor(v) for v in obj)
  371. return seen_ids[obj_id]
  372. elif istype(obj, dict):
  373. # Some packages like pytest can be updated during runtime. So, make a
  374. # copy of values to avoid issues like "RuntimeError: dictionary
  375. # changed size during iteration"
  376. values = list(obj.values())
  377. seen_ids[obj_id] = any(has_tensor(v) for v in values)
  378. return seen_ids[obj_id]
  379. elif istype(obj, (str, int, float, type(None), bool)):
  380. seen_ids[obj_id] = False
  381. return seen_ids[obj_id]
  382. elif is_namedtuple(obj) and hasattr(obj, "_fields"):
  383. seen_ids[obj_id] = any(has_tensor(getattr(obj, v)) for v in obj._fields)
  384. return seen_ids[obj_id]
  385. else:
  386. # if config.debug:
  387. # print(
  388. # f"Assuming that object of type {type(obj)} does not have a tensor"
  389. # )
  390. return False
  391. # Check if the passed arguments are of type Tensor
  392. for value in frame.f_locals.values():
  393. if has_tensor(value):
  394. return True
  395. log.debug(
  396. "skipping because no torch.* %s \
  397. %s %s",
  398. frame.f_code.co_name,
  399. frame.f_code.co_filename,
  400. frame.f_code.co_firstlineno,
  401. )
  402. return False
  403. def exception_handler(
  404. e: Exception,
  405. code: CodeType,
  406. frame: Optional[DynamoFrameType] = None,
  407. export: bool = False,
  408. ) -> None:
  409. record_filename = None
  410. if hasattr(e, "exec_record"):
  411. record_filename = gen_record_file_name(e, code)
  412. write_record_to_file(record_filename, e.exec_record)
  413. e.record_filename = record_filename # type: ignore[attr-defined]
  414. augment_exc_message(e, export=export)
# Next frame id to hand out; get_compile_id stores it in a frame_state under
# "_id" the first time that frame is seen.
FRAME_COUNTER = 0
# Per frame id: how many compile ids have been issued so far, i.e. the next
# frame_compile_id for that frame (used by get_compile_id).
FRAME_COMPILE_COUNTER: typing.Counter[Union[int, FrameStateSizeEntry]] = (
    collections.Counter()
)
  419. def maybe_cprofile(func: Callable[_P, _T]) -> Callable[_P, _T]:
  420. if config.cprofile:
  421. return cprofile_wrapper(func)
  422. return func
  423. def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
  424. @functools.wraps(func)
  425. def profile_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  426. trace_id = CompileContext.current_trace_id()
  427. assert trace_id, "Trace id is None"
  428. profile_path = Path(
  429. os.path.join(
  430. tempfile.gettempdir(),
  431. f"{func.__name__}_{str(trace_id).replace('/', '_')}.profile",
  432. )
  433. )
  434. prof = cProfile.Profile()
  435. try:
  436. start_ts = time.time()
  437. # runcall calls prof.enable() and prof.disable(), so do NOT call
  438. # enable outside. This leads to issues like
  439. # ValueError: Another profiling tool is already active
  440. # pyrefly: ignore [bad-argument-type]
  441. retval = prof.runcall(func, *args, **kwargs)
  442. profile_latency = time.time() - start_ts
  443. except ValueError:
  444. log.exception("failed to enable cProfile")
  445. profile_latency = 0
  446. retval = func(*args, **kwargs)
  447. log.warning(
  448. "### Cprofile for %s trace id [%s] took %.3f seconds ###",
  449. func.__name__,
  450. trace_id,
  451. profile_latency,
  452. )
  453. ps = pstats.Stats(prof)
  454. try:
  455. prof.dump_stats(profile_path)
  456. except OSError:
  457. log.exception("Cannot write to %s", profile_path)
  458. log.warning("Raw profile at %s", profile_path)
  459. svg_path = profile_path.with_suffix(".svg")
  460. try:
  461. with subprocess.Popen(
  462. [
  463. "gprof2dot",
  464. "-f",
  465. "pstats",
  466. "--node-label=total-time-percentage",
  467. "--node-label=self-time-percentage",
  468. "--node-label=total-time",
  469. str(profile_path),
  470. ],
  471. stdout=subprocess.PIPE,
  472. ) as gprof2dot_process:
  473. subprocess.check_call(
  474. ["dot", "-Tsvg", "-o", str(svg_path)],
  475. stdin=gprof2dot_process.stdout,
  476. )
  477. log.warning("Generated SVG from profile at %s", svg_path)
  478. except FileNotFoundError:
  479. log.warning(
  480. "Failed to generate SVG from profile -- dumping stats instead."
  481. "Try installing gprof2dot and dot for a better visualization"
  482. )
  483. ps.sort_stats(pstats.SortKey.TIME).print_stats(20)
  484. ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20)
  485. if manifold_link := maybe_upload_prof_stats_to_manifold(
  486. str(profile_path)
  487. ): # fb-only
  488. torch._logging.trace_structured(
  489. "link",
  490. lambda: {"name": "cprofile_manifold_url", "url": manifold_link},
  491. )
  492. return retval
  493. return profile_wrapper
  494. @dataclass
  495. class ConvertFrameBox:
  496. error_on_graph_break: Optional[bool] = None
  497. def get_compile_id(
  498. frame_state: dict[str, Union[int, FrameStateSizeEntry]],
  499. ) -> CompileId:
  500. global FRAME_COUNTER
  501. if "_id" not in frame_state:
  502. frame_state["_id"] = FRAME_COUNTER
  503. FRAME_COUNTER += 1
  504. frame_id = frame_state["_id"]
  505. assert isinstance(frame_id, int)
  506. frame_compile_id = FRAME_COMPILE_COUNTER[frame_id]
  507. FRAME_COMPILE_COUNTER[frame_id] += 1
  508. compiled_autograd_id = None
  509. if prior := CompileContext.current_compile_id():
  510. compiled_autograd_id = prior.compiled_autograd_id
  511. return CompileId(
  512. compiled_autograd_id=compiled_autograd_id,
  513. frame_id=frame_id,
  514. frame_compile_id=frame_compile_id,
  515. )
  516. class ConvertFrameAssert:
  517. def __init__(
  518. self,
  519. compiler_fn: CompilerFn,
  520. one_graph: bool = True,
  521. export: bool = False,
  522. export_constraints: Any | None = None,
  523. package: CompilePackage | None = None,
  524. ) -> None:
  525. # assert export_constraints is None
  526. reset_graph_break_dup_checker()
  527. self._torchdynamo_orig_backend = compiler_fn
  528. self._one_graph = one_graph
  529. self._export = export
  530. self._export_constraints = export_constraints
  531. self._package = package
  532. self._box = ConvertFrameBox()
  533. @property
  534. def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]:
  535. return lambda backend: convert_frame_assert(
  536. backend,
  537. self._one_graph,
  538. self._export,
  539. self._export_constraints,
  540. )
  541. def __call__(
  542. self,
  543. frame: DynamoFrameType,
  544. cache_entry: Optional[CacheEntry],
  545. hooks: Hooks,
  546. frame_state: dict[str, Union[int, FrameStateSizeEntry]],
  547. *,
  548. skip: int = 0,
  549. ) -> ConvertFrameReturn:
  550. increment_frame()
  551. code = frame.f_code
  552. cache_size = compute_cache_size(frame, cache_entry)
  553. input_codes.add(code)
  554. if code in output_codes:
  555. return ConvertFrameReturn()
  556. if (
  557. os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION")
  558. and os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") != code.co_name
  559. ):
  560. return ConvertFrameReturn()
  561. if code.co_name == "<genexpr>" and code.co_filename.endswith(
  562. (
  563. "transformers/file_utils.py",
  564. "transformers/utils/generic.py",
  565. "diffusers/utils/outputs.py",
  566. )
  567. ):
  568. # not needed, but cleans up torchbench error stats
  569. return ConvertFrameReturn()
  570. if code.co_name == "__setattr__":
  571. # setattr could be tricky to handle generally,
  572. # but also not likely useful to compile- skip the whole frame
  573. return ConvertFrameReturn()
  574. if code.co_name == "__init__" and code.co_filename.startswith(
  575. os.path.dirname(torch.optim.__file__)
  576. ):
  577. # optimizer support is still incomplete see
  578. # test_state_dict in test/dynamo/test_optimizers.py
  579. return ConvertFrameReturn()
  580. # Check if the frame is generated by an exec builtin call
  581. # TODO - Running exec generated frame seems propagates f_globals to the
  582. # next frames.
  583. if code.co_name == "<module>" and code.co_filename == "<string>":
  584. return ConvertFrameReturn()
  585. if (
  586. code.co_name == "<lambda>"
  587. and code.co_filename == "<string>"
  588. and not bool(frame.f_builtins)
  589. ):
  590. # namedtuple subclass constructor. Empty builtins cause issue with
  591. # len keyword in LIST_LEN guard.
  592. return ConvertFrameReturn()
  593. if is_generator(code):
  594. unimplemented(
  595. gb_type="Attempt to trace generator",
  596. context="",
  597. explanation="Generators cannot be compiled directly with `torch.compile`.",
  598. hints=[
  599. "Call a generator from inside of a non-generator Python function and "
  600. "compile that function instead.",
  601. *graph_break_hints.FUNDAMENTAL,
  602. ],
  603. )
  604. if not has_tensor_in_frame(frame):
  605. return ConvertFrameReturn()
  606. # skip tracing non-recursive disabled functions
  607. # detect if the previous frame (non-convert_frame) is a non-recursive disable wrapper
  608. prev_frame = sys._getframe()
  609. # pyrefly: ignore [bad-assignment]
  610. while (
  611. prev_frame
  612. and "torch/_dynamo/convert_frame.py" in prev_frame.f_code.co_filename
  613. ):
  614. prev_frame = prev_frame.f_back # type: ignore[assignment]
  615. if (
  616. prev_frame
  617. and prev_frame.f_code is decorators._nonrecursive_disable_wrapper_code
  618. ):
  619. return ConvertFrameReturn(apply_to_code=False)
  620. global initial_global_state
  621. # Save the previous initial_global_state to handle nested compilations
  622. # (e.g., compiled autograd running during graph execution can trigger
  623. # nested compilations that would otherwise overwrite the outer state)
  624. prev_initial_global_state = initial_global_state
  625. initial_global_state = GlobalStateGuard()
  626. compile_id = get_compile_id(frame_state)
  627. frame_id = compile_id.frame_id
  628. signpost_event(
  629. "dynamo",
  630. "_convert_frame_assert._compile",
  631. {
  632. "co_name": code.co_name,
  633. "frame_id": frame_id,
  634. "compile_id": str(compile_id),
  635. "co_filename": code.co_filename,
  636. "co_firstlineno": code.co_firstlineno,
  637. "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs,
  638. "accumulated_cache_size": cache_size.num_cache_entries,
  639. },
  640. )
  641. # Record traced frames, skipping Dynamo generated ones.
  642. if not code.co_name.startswith(TORCH_DYNAMO_RESUME_IN_PREFIX):
  643. info = f"{code.co_name} {code.co_filename}:{code.co_firstlineno}"
  644. dynamo_tls.traced_frame_infos.append(info)
  645. try:
  646. with compile_context(CompileContext(compile_id)):
  647. result = _compile(
  648. frame.f_code,
  649. frame.f_globals,
  650. frame.f_locals,
  651. frame.f_builtins,
  652. frame.closure,
  653. self._torchdynamo_orig_backend,
  654. self._one_graph,
  655. self._export,
  656. self._export_constraints,
  657. hooks,
  658. cache_entry,
  659. cache_size,
  660. frame,
  661. frame_state=frame_state,
  662. compile_id=compile_id,
  663. skip=skip + 1,
  664. package=self._package,
  665. convert_frame_box=self._box,
  666. )
  667. finally:
  668. # Restore the previous initial_global_state for nested compilation handling
  669. initial_global_state = prev_initial_global_state
  670. if config.caching_precompile and self._package is not None:
  671. from .package import DynamoCache
  672. # Record that the dynamo package has changed
  673. DynamoCache.record_package(self._package)
  674. return result
  def convert_frame_assert(
      compiler_fn: CompilerFn,
      one_graph: bool = True,
      export: bool = False,
      export_constraints: Any | None = None,
      package: Optional[CompilePackage] = None,
  ) -> ConvertFrameAssert:
      """Fully convert a frame into an FX graph, raising an exception if we fail.

      Thin factory around ``ConvertFrameAssert``: every argument is forwarded
      positionally, in declaration order.
      """
      ctor_args = (compiler_fn, one_graph, export, export_constraints, package)
      return ConvertFrameAssert(*ctor_args)
  686. from collections import OrderedDict
  687. from torch.utils.hooks import RemovableHandle
  # An ``OrderedDict`` (not a plain ``dict``) is required for ``RemovableHandle``
  # to work -- presumably because the handle keeps a weak reference to this
  # registry, and plain dicts cannot be weakly referenced.
  _bytecode_hooks: dict[int, BytecodeHook] = OrderedDict()


  def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:
      """Register ``hook`` to observe bytecode generated by Dynamo.

      A hook may just do some logging, or it may return a new code object to be
      used instead of the generated one. See ``BytecodeHook`` for the hook
      signature.

      Returns:
          The ``RemovableHandle`` whose ``id`` keys ``hook`` in the registry.
      """
      registration = RemovableHandle(_bytecode_hooks)
      _bytecode_hooks[registration.id] = hook
      return registration
  # TODO - We want to run preserve_node_meta context manager here, but the CI
  # fails (it's unclear whether the failures were flaky).
  # @torch.fx.traceback.preserve_node_meta()
  @preserve_global_state
  def trace_frame(
      code: types.CodeType,
      globals: dict[str, object],
      locals: dict[str, object],
      builtins: dict[str, object],
      closure: tuple[CellType],
      compiler_fn: CompilerFn,
      tf_mode_stack: list[torch.overrides.TorchFunctionMode],
      one_graph: bool,
      speculation_log: SpeculationLog,
      instructions: list[Instruction],
      code_options: dict[str, object],
      *,
      export: bool = False,
      export_constraints: Any | None = None,
      frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
      distributed_state: Optional[DistributedState] = None,
      package: Optional[CompilePackage] = None,
  ) -> DynamoTracerOutput:
      """Symbolically trace one frame and rewrite its instruction list in place.

      An ``InstructionTranslator`` is run over ``instructions``. On success,
      ``instructions`` is overwritten with the output graph's instructions
      (after exception-table propagation/validation and dead-code and
      pointless-jump removal), and ``code_options`` is updated with the output
      graph's code options.

      Returns:
          A ``DynamoTracerOutput`` wrapping the finished tracer.

      Raises:
          Any exception raised while tracing or post-processing. Before it
          propagates, it is tagged with ``_torch_dynamo_tracer_output`` (a
          ``DynamoTracerOutput`` built with ``error=True``).
      """
      from torch.fx.experimental.validator import bisect, translation_validation_enabled

      speculation_log.restart()  # type: ignore[has-type]
      translator = InstructionTranslator(
          instructions,
          code,
          locals,
          globals,
          builtins,
          closure,
          tf_mode_stack,
          code_options,
          compiler_fn,
          one_graph,
          export,
          export_constraints,
          frame_state=frame_state,
          speculation_log=speculation_log,  # type: ignore[has-type]
          exn_vt_stack=ExceptionStack(),
          distributed_state=distributed_state,  # type: ignore[has-type]
          package=package,
      )

      def _run() -> None:
          # Restart/skip signals propagate untouched. Any other failure first
          # triggers translation-validation bisection when that is enabled.
          try:
              translator.output.mark_bytecode_tracing_start()
              with tracing(translator.output.tracing_context), translator.set_current_tx():
                  translator.run()
          except exc.UnspecializeRestartAnalysis:
              speculation_log.clear()  # type: ignore[has-type]
              raise
          except (
              exc.SpeculationRestartAnalysis,
              exc.TensorifyScalarRestartAnalysis,
              exc.SkipFrame,
          ):
              raise
          except Exception:
              if translation_validation_enabled():
                  bisect(translator.output.shape_env)
              raise
          finally:
              # Always run cleanup hooks and drop references to the frame locals.
              translator.output.call_cleanup_hooks()
              translator.f_locals = {}

      try:
          _run()
          result = DynamoTracerOutput(translator)
          graph = result.output_graph
          assert graph is not None
          assert graph.output_instructions
          instructions[:] = graph.output_instructions
          code_options.update(graph.code_options)
          propagate_inst_exn_table_entries(instructions)
          check_inst_exn_tab_entries_valid(instructions)
          instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))
      except Exception as err:
          err._torch_dynamo_tracer_output = DynamoTracerOutput(translator, error=True)  # type: ignore[attr-defined]
          raise
      return result
  @dataclass
  class DynamoOutput:
      """
      Core result of a single Dynamo run over one frame.

      Fields:
        - ``tracer_output``: the tracer result; guards live on its output graph
          (``tracer_output.output_graph.guards``).
        - ``bytecode``: the generated code object.
        - ``last_attempt_start_time``: when the final compile attempt started, or
          ``None`` if no restart happened.

      This is meant to hold all the "interesting" information Dynamo produces
      on the frontend side before it enters the user backend.
      """

      tracer_output: DynamoTracerOutput
      bytecode: types.CodeType
      last_attempt_start_time: Optional[float]

      def build_guards(
          self,
          code: types.CodeType,
          hooks: Optional[Hooks] = None,
          save: bool = False,
          cache_entry: Optional[CacheEntry] = None,
          strict_error: bool = False,
      ) -> CheckFunctionManager:
          """Build the ``CheckFunctionManager`` (guards) for ``code``.

          Guard fail/filter callbacks come from ``hooks`` when it is given.
          """
          graph = self.tracer_output.output_graph
          assert graph is not None
          if hooks:
              fail_fn, filter_fn = hooks.guard_fail_fn, hooks.guard_filter_fn
          else:
              fail_fn = filter_fn = None
          return CheckFunctionManager(
              code,
              graph,
              cache_entry,
              fail_fn,
              filter_fn,
              save_guards=save,
              strict_error=strict_error,
          )

      def graph_capture_output(
          self,
          argdefs: Optional[tuple[Any, ...]] = None,
          kwdefaults: Optional[dict[str, Any]] = None,
      ) -> GraphCaptureOutput:
          """Condense this output into a tracer-independent ``GraphCaptureOutput``."""
          graph = self.tracer_output.output_graph
          assert graph is not None
          common = OutputGraphCommon(
              graph.dump_guards_state(),
              graph.import_sources,
              graph.shape_env,
              graph.export_metadata,
              graph.tracked_fakes_id_to_source,
          )
          return GraphCaptureOutput(
              common,
              graph.import_sources,
              graph.traced_code,
              self.bytecode,
              self.tracer_output.closure,
              argdefs,
              kwdefaults,
              self.tracer_output.f_globals,
          )
  @dataclass
  class BackendInput:
      """
      Core data structure Dynamo passes to a backend for one captured graph:
      the graph module, its example inputs, and the FakeTensorMode used while
      compiling it.

      This data structure should capture all the information Dynamo produces
      for the user backend.
      """

      # Global name under which the compiled graph is exposed to the generated
      # bytecode.
      backend_id: str
      # The captured FX graph.
      graph_module: torch.fx.GraphModule
      # Example inputs the graph was compiled with.
      example_inputs: Any
      # FakeTensorMode taken from the tracing context during compilation.
      fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
      # ``tensor_to_context`` taken from the same tracing context.
      tensor_to_context: WeakIdKeyDictionary
  849. @dataclass(frozen=True)
  850. class GraphRuntimeEnv:
  851. bytecode: types.CodeType
  852. import_sources: dict[str, str]
  853. used_globals: dict[str, Any]
  854. closure: Optional[tuple[Any, ...]]
  855. argdefs: Optional[tuple[Any, ...]]
  856. kwdefaults: Optional[dict[str, Any]] = None
  857. external_refs: set[str] = dataclasses.field(default_factory=set)
  858. def forward_callable(
  859. self,
  860. backend_id: str,
  861. compiled_fn: Callable[..., Any],
  862. *,
  863. extra_globals: Optional[dict[str, Any]] = None,
  864. ) -> Callable[..., Any]:
  865. import_sources = {
  866. alias: importlib.import_module(module_name)
  867. for alias, module_name in self.import_sources.items()
  868. }
  869. f_globals = {
  870. **import_sources,
  871. **self.used_globals,
  872. **(extra_globals or {}),
  873. backend_id: compiled_fn,
  874. }
  875. # check that all external references are available
  876. self._check_external_refs(f_globals)
  877. fn = types.FunctionType(
  878. self.bytecode,
  879. f_globals,
  880. closure=self.closure,
  881. argdefs=self.argdefs,
  882. )
  883. if self.kwdefaults:
  884. fn.__kwdefaults__ = self.kwdefaults
  885. return fn
  886. def _check_external_refs(self, f_globals: dict[str, Any]) -> None:
  887. # pyrefly: ignore [implicit-any]
  888. missing_refs = []
  889. for ref in self.external_refs:
  890. if ref not in f_globals:
  891. missing_refs.append(ref)
  892. if missing_refs:
  893. raise RuntimeError(
  894. f"Missing required external references: {missing_refs}. "
  895. "Please load AOT compiled function with `f_globals=<enclosing global scope>`"
  896. )
  @dataclass
  class GraphCaptureOutput:
      """
      Minimal version of DynamoOutput: what is needed to build guards and a
      runnable function for one captured graph, without the live tracer.
      """

      output_graph: OutputGraphCommon
      import_sources: dict[str, str]
      traced_code: list[CodeType]
      bytecode: CodeType
      closure: Optional[tuple[Any, ...]]
      argdefs: Optional[tuple[Any, ...]]
      kwdefaults: Optional[dict[str, Any]]
      f_globals: dict[str, Any]

      def build_guards(
          self,
          code: types.CodeType,
          hooks: Optional[Hooks] = None,
          save: bool = False,
          cache_entry: Optional[CacheEntry] = None,
          strict_error: bool = False,
      ) -> CheckFunctionManager:
          """Build the guard manager for ``code`` from the captured output graph."""
          guard_fail_fn = hooks.guard_fail_fn if hooks else None
          guard_filter_fn = hooks.guard_filter_fn if hooks else None
          return CheckFunctionManager(
              code,
              self.output_graph,
              cache_entry,
              guard_fail_fn,
              guard_filter_fn,
              save_guards=save,
              strict_error=strict_error,
          )

      def get_runtime_env(self) -> GraphRuntimeEnv:
          """Collect the globals, closure and defaults needed to run ``self.bytecode``."""
          from torch._dynamo.source import get_global_source_name

          input_sources = (
              self.output_graph.export_metadata.graph_input_idx_to_local_source.values()
          )
          # Keep only graph inputs sourced from globals that exist in f_globals.
          used_globals = {}
          for src in input_sources:
              name = get_global_source_name(src)
              if name is not None and name in self.f_globals:
                  used_globals[name] = self.f_globals[name]

          return GraphRuntimeEnv(
              bytecode=self.bytecode,
              import_sources=self.import_sources,
              used_globals=used_globals,
              closure=self.closure,
              argdefs=self.argdefs,
              kwdefaults=self.kwdefaults,
              # Scan bytecode for all external references
              external_refs=self._get_external_refs(self.bytecode),
          )

      @staticmethod
      def _get_external_refs(bytecode: types.CodeType) -> set[str]:
          """Return the names ``bytecode`` loads via LOAD_GLOBAL or LOAD_NAME.

          LOAD_GLOBAL loads a global or a builtin; LOAD_NAME appears mostly in
          module-level code. Empty/falsy ``argval`` entries are ignored, and only
          the given code object's own instructions are scanned.
          """
          import dis

          name_loading_ops = ("LOAD_GLOBAL", "LOAD_NAME")
          return {
              instr.argval
              for instr in dis.get_instructions(bytecode)
              if instr.opname in name_loading_ops and instr.argval
          }
  @dataclass
  class CaptureOutput:
      """
      CaptureOutput should represent all the information produced by the torch
      compiler for a single graph capture. It is intended to be consumed by
      various compiler frontends so that we can share as much compiler internals
      as possible and avoid great divergence between different stacks.

      This data structure should eventually contain all the information the
      compiler produces as more refactors happen to converge different compiler
      frontends.
      """

      graph_capture_output: GraphCaptureOutput
      # BackendInput can be None when dynamo didn't compile any graph (no tensor op)
      backend_input: Optional[BackendInput]

      def forward_callable(
          self,
          *,
          compiled_fn: Optional[Callable[..., Any]] = None,
          extra_globals: Optional[dict[str, Any]] = None,
      ) -> Callable[..., Any]:
          """Build a Python callable that runs the captured bytecode.

          Args:
              compiled_fn: Callable installed under the backend id. When ``None``
                  (the default), the captured ``graph_module`` is used.
              extra_globals: Extra globals forwarded to the runtime environment.

          Raises:
              AssertionError: if no graph was captured (``backend_input`` is None).
          """
          # Fail before doing any work if there is nothing to forward to.
          assert self.backend_input is not None
          runtime_env = self.graph_capture_output.get_runtime_env()
          backend_id = self.backend_input.backend_id
          # Explicit None check: a valid callable may still be falsy (e.g. an
          # object whose ``__len__`` returns 0) and must not be replaced.
          if compiled_fn is None:
              # pyrefly: ignore [bad-assignment]
              compiled_fn = self.backend_input.graph_module
          return runtime_env.forward_callable(
              backend_id,
              compiled_fn,  # pyrefly: ignore [bad-argument-type]
              extra_globals=extra_globals,
          )
  def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
      """
      Resolve the function to trace for ``mod`` plus the object bound as
      ``self`` (if any). ``mod`` may be an nn.Module, a method, or a function.

      For an ``nn.Module``:
        - no module-level or global forward/backward hooks, a custom ``forward``
          and the default ``__call__``: trace ``forward``;
        - otherwise, an ``fx.GraphModule``: trace ``_call_impl``;
        - otherwise: trace ``__call__``.
      A resolved object with ``__self__`` yields ``(__func__, __self__)``; a plain
      function yields ``(fn, None)``.

      Raises:
          RuntimeError: if the resolved object is neither of those.
      """
      import inspect

      if isinstance(mod, torch.nn.Module):
          forward_impl = mod.forward
          if hasattr(forward_impl, "__self__"):
              # pyrefly: ignore [missing-attribute]
              forward_impl = forward_impl.__func__
          call_impl = mod.__call__
          if hasattr(call_impl, "__self__"):
              # pyrefly: ignore [missing-attribute]
              call_impl = call_impl.__func__

          module_impl = torch.nn.modules.module
          # Mirrored from NNModuleVariable.call_function:
          # https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/nn_module.py#L1035
          can_trace_forward_directly = (
              len(mod._forward_pre_hooks) == 0
              and len(mod._forward_hooks) == 0
              and len(module_impl._global_forward_pre_hooks) == 0
              and len(module_impl._global_forward_hooks) == 0
              and len(mod._backward_pre_hooks) == 0
              and len(mod._backward_hooks) == 0
              and len(module_impl._global_backward_pre_hooks) == 0
              and len(module_impl._global_backward_hooks) == 0
              and forward_impl != torch.nn.Module.forward  # has forward impl
              and call_impl == torch.nn.Module.__call__  # no custom __call__ impl
          )
          if can_trace_forward_directly:
              # We cannot trace __call__ by default because it will break
              # the legacy dynamo export. If we want to revisit this,
              # feel free to remove this path and try unittests in
              # test_strict_export_v2.py
              mod = mod.forward
          elif isinstance(mod, torch.fx.GraphModule):
              mod = mod._call_impl
          else:
              mod = mod.__call__

      if hasattr(mod, "__self__"):
          return mod.__func__, mod.__self__
      if inspect.isfunction(mod):
          return mod, None
      raise RuntimeError(f"Unsupported model code type {mod}")
  def _get_signature(fn: Any) -> inspect.Signature:
      """Return the signature of ``fn`` itself.

      ``follow_wrapped=False`` makes ``inspect`` ignore any ``__wrapped__``
      attribute, so a wrapper reports its own parameters rather than those of
      the function it wraps.
      """
      signature = inspect.signature(fn, follow_wrapped=False)
      return signature
  def _get_frame(
      mod: Any,
      args: tuple[Any, ...],
      kwargs: Optional[dict[str, Any]] = None,
  ) -> FrameInfo:
      """
      Create a frame to trace, describing a call of ``mod(*args, **kwargs)``.

      The traced function and optional bound ``self`` come from
      ``get_traced_fn``; a bound ``self`` is prepended to ``args``. Arguments are
      bound against the function's own signature with defaults applied. The
      current contents of any closure cells are then added to the locals under
      their free-variable names.
      """
      import builtins

      fn, bound_self = get_traced_fn(mod)
      call_args = args if bound_self is None else (bound_self,) + args
      bound = _get_signature(fn).bind(*call_args, **(kwargs or {}))
      bound.apply_defaults()
      f_locals = bound.arguments

      cells = fn.__closure__ or ()
      freevars = fn.__code__.co_freevars
      if freevars or cells:
          assert len(cells) == len(freevars)
          for name, cell in zip(freevars, cells):
              f_locals[name] = cell.cell_contents

      return FrameInfo(
          fn.__code__,
          fn.__globals__,
          f_locals,
          builtins.__dict__,
          closure=cells,  # type: ignore[arg-type]
          argdefs=fn.__defaults__,
          kwdefaults=fn.__kwdefaults__,
      )
  def fullgraph_capture(
      mod: Any,
      args: tuple[Any, ...],
      kwargs: Optional[dict[str, Any]] = None,
      *,
      constraints: Optional[list[Constraint]] = None,
      _is_export_deprecated_do_not_use: bool = False,
  ) -> CaptureOutput:
      """
      Capture a full graph for a model, given example inputs to trace with.

      Takes a callable (nn.Module, method, or function), positional args and
      optional kwargs. Returns the Dynamo-captured graph together with other
      important compile-time information. This is the common graph-capture
      mechanism for the torch compiler AOT frontends (e.g. AOT precompile,
      export).

      This API doesn't apply context managers such as the metrics context; the
      caller is expected to apply them as its use case requires.

      The returned CaptureOutput has two parts:
        1. Frontend specific information:
           - guards
           - generated bytecode
           - other information tracked by OutputGraphCommon.
        2. Backend specific information (indexed by unique backend id):
           - fx graph
           - example inputs
      """
      frame_info = _get_frame(mod, args, kwargs)
      ctx = CompileContext(get_compile_id({}))
      with compile_context(ctx):
          return _fullgraph_capture_frame(
              frame_info,
              constraints=constraints,
              _is_export_deprecated_do_not_use=_is_export_deprecated_do_not_use,
          )
  @dataclass
  class FrameInfo:
      """
      A synthetic description of a Python frame: the state needed to trace
      ``code`` as if it were being called, without a live interpreter frame.
      (Field notes below describe how ``_get_frame`` fills it in.)
      """

      code: types.CodeType
      # The traced function's ``__globals__``.
      globals: dict[str, object]
      # Bound arguments (defaults applied) plus closure cell contents.
      locals: dict[str, object]
      builtins: dict[str, object]
      # The function's closure cells; an empty tuple when there are none.
      closure: tuple[CellType]
      # Positional defaults (the function's ``__defaults__``).
      argdefs: Optional[tuple[Any, ...]]
      # Keyword-only defaults (the function's ``__kwdefaults__``).
      kwdefaults: Optional[dict[str, Any]]
  def _fullgraph_capture_frame(
      frame: FrameInfo,
      *,
      constraints: Optional[list[Constraint]] = None,
      _is_export_deprecated_do_not_use: bool = False,
  ) -> CaptureOutput:
      """Trace ``frame`` as a single graph and package the result.

      A recording backend stores the FX graph, example inputs, fake mode and
      ``tensor_to_context`` as a ``BackendInput`` and returns the graph
      unchanged. ``Unsupported`` / ``UncapturedHigherOrderOpError`` are
      re-raised with an augmented message. Unless ``config.verbose`` is set,
      tracebacks are stripped from the exception and its whole ``__cause__``
      chain.
      """
      from torch._guards import TracingContext

      captured: Optional[BackendInput] = None

      def fullgraph_compiler(
          gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
      ) -> torch.fx.GraphModule:
          # Record what the backend receives, then hand the graph back untouched.
          nonlocal captured
          ctx = TracingContext.get()
          fake_mode = ctx.fake_mode
          tensor_to_context = ctx.tensor_to_context
          assert fake_mode is not None
          assert isinstance(gm.meta["backend_id"], str)
          captured = BackendInput(
              gm.meta["backend_id"], gm, example_inputs, fake_mode, tensor_to_context
          )
          return gm

      try:
          dynamo_output = compile_frame(
              frame.code,
              frame.globals,
              frame.locals,
              frame.builtins,
              frame.closure,
              # pyrefly: ignore [bad-argument-type]
              compiler_fn=fullgraph_compiler,
              export=_is_export_deprecated_do_not_use,
              export_constraints=constraints,  # type: ignore[arg-type]
              one_graph=True,
              restart_reasons=set(),
          )
      # https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/eval_frame.py#L831
      except (Unsupported, UncapturedHigherOrderOpError) as e:
          augment_exc_message(e)
          if config.verbose:
              raise
          # strip internal tracebacks from causes
          cause = e.__cause__
          while cause is not None:
              cause.with_traceback(None)
              cause = cause.__cause__
          raise e.with_traceback(None) from e.__cause__  # User compiler error

      return CaptureOutput(
          dynamo_output.graph_capture_output(frame.argdefs, frame.kwdefaults),
          captured,
      )
  # Called by eval_frame_cpp.cpp in order to raise an error if Dynamo attempts
  # to compile_frame while running fullgraph=True compiled code.
  def get_fail_callback(callback: ConvertFrameProtocol) -> ConvertFrameProtocol:
      """
      Return a wrapper around ``callback`` during which this module's
      ``compile_frame`` is patched to raise ``RuntimeError``.

      The wrapper is built once per callback and cached on it as
      ``_dynamo_fail_callback``; later calls return the cached wrapper.
      """
      cached = getattr(callback, "_dynamo_fail_callback", None)
      if cached is not None:
          return cached

      def compile_frame_error(*args: Any, **kwargs: Any) -> NoReturn:
          raise RuntimeError(
              "Dynamo: expected not to compile nested code - this happens because "
              "a Dynamo callback was triggered and succeeded in compiling "
              "when running fullgraph=True compiled code."
          )

      def fail_callback(*args: Any, **kwargs: Any) -> ConvertFrameReturn:
          # compile_frame is swapped out only for the duration of this call.
          with mock.patch(__name__ + ".compile_frame", compile_frame_error):
              return callback(*args, **kwargs)

      # pyrefly: ignore [missing-attribute]
      callback._dynamo_fail_callback = fail_callback
      return fail_callback
  def _cleanup_failed_tracer_output(err: BaseException) -> None:
      """Clean up the output graph of a failed tracing attempt, if any.

      ``trace_frame`` attaches a ``DynamoTracerOutput`` to failing exceptions as
      ``_torch_dynamo_tracer_output``. Cleaning up its output graph breaks
      reference cycles. Exceptions without that attribute are left alone.
      """
      failed_tracer_output = getattr(err, "_torch_dynamo_tracer_output", None)
      if failed_tracer_output:
          failed_tracer_output._cleanup_output_graph()


  def compile_frame(  # type: ignore[return]
      code: types.CodeType,
      globals: dict[str, object],
      locals: dict[str, object],
      builtins: dict[str, object],
      closure: tuple[CellType],
      compiler_fn: CompilerFn,
      one_graph: bool,
      restart_reasons: set[str],
      *,
      export: bool = False,
      export_constraints: Any | None = None,
      frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
      distributed_state: Optional[DistributedState] = None,
      package: Optional[CompilePackage] = None,
      # pyrefly: ignore [bad-return]
  ) -> DynamoOutput:
      """
      A helper function taking a frame and backend, then returning the generated
      bytecode and guards as a common data structure.

      This is a shared interface for multiple compiler frontends (e.g.
      torch.compile, torch.export) that need to capture a graph out of python
      code.

      Tracing is retried whenever it raises ``exc.RestartAnalysis``, and each
      restart reason is added to ``restart_reasons``. If a restart happens on an
      attempt index above 100, compilation is abandoned via ``unimplemented``.

      Raises:
          exc.SkipFrame: re-raised (after cleanup and a debug log) when tracing
              signals that the frame should be skipped.
      """
      # This is shared across restarts
      speculation_log = SpeculationLog()

      def transform(
          instructions: list[Instruction], code_options: dict[str, object]
      ) -> DynamoTracerOutput:
          tf_mode_stack: list[torch.overrides.TorchFunctionMode] = (
              torch.overrides._get_current_function_mode_stack()
          )
          tracer_output = trace_frame(
              code,
              globals,
              locals,
              builtins,
              closure,
              compiler_fn,
              tf_mode_stack,
              one_graph,
              speculation_log,
              instructions,
              code_options,
              export=export,
              export_constraints=export_constraints,
              frame_state=frame_state,
              distributed_state=distributed_state,
              package=package,
          )
          assert tracer_output is not None
          return tracer_output

      last_attempt_start_time = None
      for attempt in itertools.count():
          CompileContext.get().attempt = attempt
          try:
              with dynamo_timed(f"compile_attempt_{attempt}", log_pt2_compile_event=True):
                  bytecode, tracer_output = transform_code_object(code, transform)
                  assert tracer_output is not None
                  return DynamoOutput(
                      tracer_output=tracer_output,
                      bytecode=bytecode,
                      last_attempt_start_time=last_attempt_start_time,
                  )
          except exc.RestartAnalysis as e:
              if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
                  TensorifyState.clear()
              log.info(
                  "Restarting analysis due to %s",
                  LazyString(format_traceback_short, e.__traceback__),
              )
              # Clean up the failed tracer output's graph to break reference cycles
              _cleanup_failed_tracer_output(e)
              # If restart reason is None just log the type of the exception
              restart_reasons.add(e.restart_reason or str(type(e)))
              # We now have a new "last attempt", reset the clock
              last_attempt_start_time = time.time()
              if attempt > 100:
                  unimplemented(
                      gb_type="Excessive RestartAnalysis() calls",
                      context="",
                      explanation="Dynamo attempted to trace the same frame 100+ times. "
                      "Giving up on compiling as the compile time tradeoff is likely not "
                      "worth the performance gain.",
                      hints=[],
                  )
          except exc.SkipFrame as e:
              if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
                  TensorifyState.clear()
              # Clean up the failed tracer output's graph to break reference cycles
              _cleanup_failed_tracer_output(e)
              # A single literal: the previous backslash-continued string leaked
              # source indentation whitespace into the log message.
              log.debug(  # noqa: G200
                  "Received signal to skip frame (without graph break): %s %s %s %s",
                  e,
                  code.co_name,
                  code.co_filename,
                  code.co_firstlineno,
              )
              raise
  1286. def _compile(
  1287. code: CodeType,
  1288. globals: dict[str, object],
  1289. locals: dict[str, object],
  1290. builtins: dict[str, object],
  1291. closure: tuple[CellType],
  1292. compiler_fn: CompilerFn,
  1293. one_graph: bool,
  1294. export: bool,
  1295. export_constraints: Any | None,
  1296. hooks: Hooks,
  1297. cache_entry: Optional[CacheEntry],
  1298. cache_size: CacheSizeRelevantForFrame,
  1299. frame: Optional[DynamoFrameType] = None,
  1300. frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
  1301. *,
  1302. compile_id: CompileId,
  1303. skip: int = 0,
  1304. package: Optional[CompilePackage] = None,
  1305. # Can be used to record things for the caller, both
  1306. # in the case of normal and exception code paths
  1307. convert_frame_box: Optional[ConvertFrameBox] = None,
  1308. ) -> ConvertFrameReturn:
  1309. from torch.fx.experimental.validator import (
  1310. BisectValidationException,
  1311. ValidationException,
  1312. )
  1313. # Only nonlocal defs here please!
  1314. # Time spent compiling this frame before restarting or failing analysis
  1315. dynamo_time_before_restart: float = 0.0
  1316. @compile_time_strobelight_meta(phase_name="compile_inner")
  1317. def compile_inner(
  1318. code: CodeType, one_graph: bool, hooks: Hooks
  1319. ) -> tuple[ConvertFrameReturn, Optional[DynamoTracerOutput]]:
  1320. with contextlib.ExitStack() as stack:
  1321. stack.enter_context(
  1322. torch._dynamo.callback_handler.install_callbacks(
  1323. CallbackTrigger.DYNAMO, str(CompileContext.current_compile_id())
  1324. )
  1325. )
  1326. stack.enter_context(CompileTimeInstructionCounter.record())
  1327. return _compile_inner(code, one_graph, hooks)
  1328. return (
  1329. ConvertFrameReturn(),
  1330. None,
  1331. ) # dead, but see https://github.com/python/mypy/issues/7577
  1332. @maybe_cprofile
  1333. def _compile_inner(
  1334. code: CodeType,
  1335. one_graph: bool,
  1336. hooks: Hooks,
  1337. ) -> tuple[ConvertFrameReturn, DynamoTracerOutput]:
  1338. nonlocal dynamo_time_before_restart
  1339. last_attempt_start_time = start_time = time.time()
  1340. def log_bytecode(
  1341. prefix: str, name: str, filename: str, line_no: int, code: CodeType
  1342. ) -> None:
  1343. if bytecode_log.isEnabledFor(logging.DEBUG):
  1344. bytecode_log.debug(
  1345. format_bytecode(prefix, name, filename, line_no, code)
  1346. )
  1347. log_bytecode(
  1348. "ORIGINAL BYTECODE",
  1349. code.co_name,
  1350. code.co_filename,
  1351. code.co_firstlineno,
  1352. code,
  1353. )
  1354. out_code = None
  1355. try:
  1356. dynamo_output = compile_frame(
  1357. code,
  1358. globals,
  1359. locals,
  1360. builtins,
  1361. closure,
  1362. compiler_fn,
  1363. one_graph,
  1364. restart_reasons,
  1365. export=export,
  1366. export_constraints=export_constraints,
  1367. frame_state=frame_state,
  1368. distributed_state=distributed_state,
  1369. package=package,
  1370. )
  1371. except exc.SkipFrame as e:
  1372. if one_graph:
  1373. log.debug("No graph captured with export/fullgraph=True")
  1374. assert e._torch_dynamo_tracer_output is not None
  1375. return ConvertFrameReturn(), e._torch_dynamo_tracer_output
  1376. assert distributed_state is None or distributed_state.all_states is not None, ( # type: ignore[has-type]
  1377. "compiler collective wasn't run before compilation completed"
  1378. )
  1379. out_code = dynamo_output.bytecode
  1380. tracer_output = dynamo_output.tracer_output
  1381. if dynamo_output.last_attempt_start_time is not None:
  1382. last_attempt_start_time = dynamo_output.last_attempt_start_time
  1383. assert out_code is not None
  1384. log_bytecode(
  1385. "MODIFIED BYTECODE",
  1386. code.co_name,
  1387. code.co_filename,
  1388. code.co_firstlineno,
  1389. out_code,
  1390. )
  1391. for idx, hook in enumerate(_bytecode_hooks.values()):
  1392. with dynamo_timed(f"bytecode_hooks_{idx}", log_pt2_compile_event=True):
  1393. hook_output = hook(code, out_code)
  1394. if hook_output is not None:
  1395. out_code = hook_output
  1396. orig_code_map[out_code] = code
  1397. output_codes.add(out_code)
  1398. dynamo_time_before_restart = last_attempt_start_time - start_time
  1399. assert tracer_output.output_graph is not None
  1400. output = tracer_output.output_graph
  1401. # Tests for new code objects.
  1402. # The rationale for these tests can be found in torch/csrc/dynamo/eval_frame.c
  1403. # Only test once the code object is created.
  1404. # They are not tested during runtime.
  1405. def count_args(code: CodeType) -> int:
  1406. import inspect
  1407. return (
  1408. code.co_argcount
  1409. + code.co_kwonlyargcount
  1410. + bool(code.co_flags & inspect.CO_VARARGS)
  1411. + bool(code.co_flags & inspect.CO_VARKEYWORDS)
  1412. )
  1413. assert out_code is not None
  1414. total_argcount_old = count_args(code)
  1415. total_argcount_new = count_args(out_code)
  1416. msg = "arg mismatch: "
  1417. msg += f"old code object has args {code.co_varnames[:total_argcount_old]}, "
  1418. msg += f"new code object has args {out_code.co_varnames[:total_argcount_new]}"
  1419. assert (
  1420. code.co_varnames[:total_argcount_old]
  1421. == out_code.co_varnames[:total_argcount_new]
  1422. ), msg
  1423. msg = "free var mismatch: "
  1424. msg += f"old code object has free var {code.co_freevars}, "
  1425. msg += f"new code object has free var {out_code.co_freevars}"
  1426. assert code.co_freevars == out_code.co_freevars, msg
  1427. msg = "cell var mismatch: "
  1428. msg += f"old code object has cell var {code.co_cellvars}, "
  1429. msg += f"new code object has cell var {out_code.co_cellvars}"
  1430. assert code.co_cellvars == out_code.co_cellvars, msg
  1431. # Skipping Dynamo on a frame without any extracted graph.
  1432. # This does not affect eager functionality. But this is necessary
  1433. # for export for cases where Dynamo-reconstructed bytecode can create
  1434. # new function frames, confusing export in thinking that there
  1435. # are extra graphs now.
  1436. if output.export and output.is_empty_graph():
  1437. return ConvertFrameReturn(), tracer_output
  1438. assert output.guards is not None
  1439. CleanupManager.instance[out_code] = output.cleanups
  1440. nonlocal cache_entry
  1441. with dynamo_timed("build_guards", log_pt2_compile_event=True):
  1442. check_fn = dynamo_output.build_guards(
  1443. code,
  1444. hooks=hooks,
  1445. save=package is not None,
  1446. cache_entry=cache_entry,
  1447. )
  1448. if package is not None:
  1449. assert check_fn.guards_state is not None
  1450. package.add_guarded_code(check_fn.guards_state, out_code)
  1451. package.add_inlined_source(output.tracing_context.traced_code)
  1452. package.update_device_type(output.current_tracer.graph)
  1453. compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
  1454. annotation_str = "Torch-Compiled Region: " + compile_id_str
  1455. guarded_code = GuardedCode(
  1456. out_code,
  1457. check_fn.guard_manager, # type: ignore[arg-type]
  1458. compile_id,
  1459. annotation_str,
  1460. )
  1461. if not output.is_empty_graph() and hooks.guard_export_fn is not None:
  1462. # We should not run the guard_export_fn when Dynamo does not
  1463. # generate any graph. This can happen in export when TorchDynamo
  1464. # generated bytecode has some reconstruction logic for mutated
  1465. # variables which can trigger TorchDynamo on the children frames but
  1466. # they are benign and do not generate any new graphs.
  1467. hooks.guard_export_fn(output.guards)
  1468. return wrap_guarded_code(guarded_code), tracer_output
  1469. metrics_context = get_metrics_context()
  1470. code_context = (
  1471. package.code_context(code) if package is not None else contextlib.nullcontext()
  1472. )
  1473. with (
  1474. _use_lazy_graph_module(config.use_lazy_graph_module),
  1475. compile_context(CompileContext(compile_id)),
  1476. chromium_event_timed(
  1477. "dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True
  1478. ),
  1479. _WaitCounter("pytorch.wait_counter.entire_forward_compile").guard(),
  1480. metrics_context,
  1481. dynamo_timed(
  1482. "_compile.compile_inner",
  1483. phase_name="entire_frame_compile",
  1484. dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
  1485. ),
  1486. code_context,
  1487. ):
  1488. restart_reasons: set[str] = set()
  1489. if compile_pg := get_compile_pg():
  1490. distributed_state = DistributedState(compile_pg, LocalState())
  1491. else:
  1492. distributed_state = None
  1493. # Check recompilations
  1494. recompile_reason: Optional[str] = None
  1495. if is_recompilation(cache_size) and frame:
  1496. reasons = get_and_maybe_log_recompilation_reasons(
  1497. cache_entry, frame, innermost_fn(compiler_fn)
  1498. )
  1499. recompile_reason = (
  1500. "Unable to find recompilation reasons" if not reasons else reasons[0]
  1501. )
  1502. # Recheck for recompilation, for when inline_inbuilt_nn_modules is set to False
  1503. inline_inbuilt_nn_modules_candidate = False
  1504. if not config.inline_inbuilt_nn_modules and frame:
  1505. inbuilt_nn_reasons = get_and_maybe_log_recompilation_reasons(
  1506. cache_entry, frame, innermost_fn(compiler_fn), skip_logging=True
  1507. )
  1508. inbuilt_nn_recompile_reason = (
  1509. None if not inbuilt_nn_reasons else inbuilt_nn_reasons[0]
  1510. )
  1511. if (
  1512. inbuilt_nn_recompile_reason is not None
  1513. and "[inline-inbuilt-nn-modules-candidate]"
  1514. in inbuilt_nn_recompile_reason
  1515. ):
  1516. inline_inbuilt_nn_modules_candidate = True
  1517. # Set if the recompile is a candidate for inline_inbuilt_nn_modules
  1518. # regardless of whether inline_inbuilt_nn_modules is set or not
  1519. metrics_context.update_outer(
  1520. {
  1521. "recompile_reason": recompile_reason,
  1522. "inline_inbuilt_nn_modules_candidate": inline_inbuilt_nn_modules_candidate,
  1523. }
  1524. )
  1525. recompile_user_contexts = get_hook_for_recompile_user_context()
  1526. if recompile_user_contexts:
  1527. # cap each user context to N chars for data retention purposes. N=256
  1528. # is chosen to be large enough to capture the most important info.
  1529. user_contexts_msg = {
  1530. user_context()[:256] for user_context in recompile_user_contexts
  1531. }
  1532. metrics_context.set("recompile_user_contexts", user_contexts_msg)
  1533. exceeded, limit_type = exceeds_recompile_limit(cache_size, compile_id)
  1534. if exceeded:
  1535. def format_func_info(code: CodeType) -> str:
  1536. return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})"
  1537. # NS: Don't add period at the end of string, as it'll be added to URL
  1538. # rendering it incorrect
  1539. log.warning(
  1540. "torch._dynamo hit config.%s (%s)\n"
  1541. " function: %s\n"
  1542. " last reason: %s\n"
  1543. 'To log all recompilation reasons, use TORCH_LOGS="recompiles".\n'
  1544. "To diagnose recompilation issues, see %s",
  1545. limit_type,
  1546. getattr(config, limit_type),
  1547. format_func_info(code),
  1548. recompile_reason,
  1549. troubleshooting_url,
  1550. )
  1551. def raise_unimplemented_cache_limit_exceeded() -> NoReturn:
  1552. unimplemented(
  1553. gb_type="Dynamo recompile limit exceeded",
  1554. context=f"Limit type: {limit_type}",
  1555. explanation="Dynamo attempted to recompile the code object too many times, "
  1556. f"exceeding the {limit_type} cache size limit (currently set to {getattr(config, limit_type)}). "
  1557. "Excessive recompilations can degrade "
  1558. "performance due to the compilation overhead of each recompilation.",
  1559. hints=[
  1560. "To monitor recompilations, enable TORCH_LOGS=recompiles. "
  1561. "If recompilations are expected, consider "
  1562. f"increasing torch._dynamo.config.{limit_type} to an appropriate value.",
  1563. f"See {troubleshooting_url} for tips on dealing with recompilations.",
  1564. ],
  1565. )
  1566. try:
  1567. raise_unimplemented_cache_limit_exceeded()
  1568. except Unsupported as e:
  1569. if config.fail_on_recompile_limit_hit:
  1570. raise FailOnRecompileLimitHit(
  1571. "Hard failure due to fail_on_recompile_limit_hit"
  1572. ) from e
  1573. elif one_graph:
  1574. raise FailOnRecompileLimitHit(
  1575. "Hard failure due to fullgraph=True"
  1576. ) from e
  1577. else:
  1578. # Set frame execution strategy to RUN_ONLY for this recompile limit case
  1579. e.frame_exec_strategy = FrameExecStrategy(
  1580. FrameAction.RUN_ONLY, FrameAction.RUN_ONLY
  1581. )
  1582. raise
  1583. log.debug(
  1584. "torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s",
  1585. code.co_name,
  1586. code.co_filename,
  1587. code.co_firstlineno,
  1588. skip + 2,
  1589. # -2: omit current frame, omit contextlib decorator
  1590. "".join(CapturedTraceback.extract(skip=2 + skip).format()),
  1591. )
  1592. # -4: -2 as above, plus trace_structured frames
  1593. #
  1594. # NB: the frame looks like this:
  1595. #
  1596. # # handled by skip argument
  1597. # torch/_dynamo/convert_frame.py:1069 in catch_errors
  1598. # torch/_dynamo/convert_frame.py:910 in _convert_frame
  1599. # torch/_dynamo/convert_frame.py:464 in _convert_frame_assert
  1600. # torch/_utils_internal.py:70 in wrapper_function
  1601. #
  1602. # # 2 current frame and context lib
  1603. # env/lib/python3.10/contextlib.py:79 in inner
  1604. # torch/_dynamo/convert_frame.py:776 in _compile
  1605. #
  1606. # # 2 extra here
  1607. # torch/_logging/_internal.py:1064 in trace_structured
  1608. # torch/_dynamo/convert_frame.py:780 in <lambda>
  1609. stack_trace = log_dynamo_start(code, skip)
  1610. start_time_ns = time.time_ns()
  1611. fail_type: Optional[str] = None
  1612. fail_reason: Optional[str] = None
  1613. exception_stack_trace: Optional[list[str]] = None
  1614. fail_user_frame_filename: Optional[str] = None
  1615. fail_user_frame_lineno: Optional[int] = None
  1616. torch._dynamo.utils.ReinplaceCounters.clear()
  1617. guarded_code = None
  1618. tracer_output = None
  1619. try:
  1620. guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
  1621. # NB: We only put_code_state in success case. Success case here
  1622. # does include graph breaks; specifically, if a graph break still
  1623. # resulted in a partially compiled graph, we WILL return here. An
  1624. # Unsupported exception will only bubble to the top level if we
  1625. # are unable to compile the frame at all. In this case, there's
  1626. # no point in uploading the code state, because we will always
  1627. # fail exactly the same way even without the update. (It's useful
  1628. # to upload for graph break though, because this can prevent
  1629. # extra graph break compilations.)
  1630. put_code_state()
  1631. if (
  1632. tracer_output
  1633. and (output_graph := tracer_output.output_graph)
  1634. and output_graph.has_outputs()
  1635. ):
  1636. log_frame_dynamic_whitelist(code)
  1637. if recompile_reason and "size mismatch at index" in recompile_reason:
  1638. _log_size_mismatch_recompile()
  1639. return guarded_code
  1640. except Exception as e:
  1641. # NB: e's msg is mutated here to add user stack, but we DON'T want
  1642. # that stack in the Scuba logged fail_reason. So we grab the fail
  1643. # info here and add it to the metrics context below.
  1644. fail_type = type(e).__qualname__
  1645. fail_reason = str(e)
  1646. exception_stack_trace = [traceback.format_exc()]
  1647. exception_handler(e, code, frame, export=export)
  1648. # NB: this is the post-mutation exception
  1649. torch._logging.trace_structured(
  1650. "artifact",
  1651. metadata_fn=lambda: {
  1652. "name": "dynamo_error",
  1653. "encoding": "string",
  1654. },
  1655. payload_fn=lambda: traceback.format_exc(),
  1656. )
  1657. fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message(
  1658. e, compile_id
  1659. )
  1660. tracer_output = getattr(e, "_torch_dynamo_tracer_output", None)
  1661. if isinstance(
  1662. e,
  1663. (
  1664. Unsupported,
  1665. TorchRuntimeError,
  1666. BackendCompilerFailed,
  1667. AssertionError,
  1668. ConstraintViolationError,
  1669. GuardOnDataDependentSymNode,
  1670. ValidationException,
  1671. UncapturedHigherOrderOpError,
  1672. BisectValidationException,
  1673. ShortenTraceback,
  1674. PackageError,
  1675. ResumePrologueTracingError,
  1676. ),
  1677. ):
  1678. raise
  1679. else:
  1680. # Rewrap for clarity
  1681. raise InternalTorchDynamoError(
  1682. f"{type(e).__qualname__}: {str(e)}"
  1683. ).with_traceback(e.__traceback__) from None
  1684. finally:
  1685. # === WARNING WARNING WARNING ===
  1686. # If you commit a bug here, it will suppress writing to
  1687. # dynamo_compile table, and we will not have telemetry.
  1688. # Be extra careful when making changes here!
  1689. if torch._dynamo.config.run_gc_after_compile:
  1690. with dynamo_timed("gc", dynamo_compile_column_us="gc_time_us"):
  1691. log.info("run_gc_after_compile: running gc")
  1692. gc.collect(1)
  1693. output = None
  1694. if tracer_output:
  1695. output = tracer_output.output_graph
  1696. if output:
  1697. # pyrefly: ignore [implicit-any]
  1698. output.local_scope = {}
  1699. # tracer should already be None, keep an extra check here just in case.
  1700. if tracer := output.root_tx:
  1701. # pyrefly: ignore [implicit-any]
  1702. tracer.f_locals = {}
  1703. from .utils import curr_frame
  1704. frame_key = str(curr_frame)
  1705. if fail_reason is None and output is not None:
  1706. guard_count = len(output.guards)
  1707. shape_env_guard_count = len(output.shape_env.guards)
  1708. graph_op_count = output.count_calls()
  1709. graph_node_count = len(output.graph.nodes)
  1710. graph_node_shapes = output.get_graph_sizes_structured()
  1711. graph_input_count = len(output.placeholders)
  1712. non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops}
  1713. compliant_custom_ops = {
  1714. op.__qualname__ for op in output.compliant_custom_ops
  1715. }
  1716. torch._dynamo.utils.ReinplaceCounters.log()
  1717. else:
  1718. guard_count = None
  1719. shape_env_guard_count = None
  1720. graph_op_count = None
  1721. graph_node_count = None
  1722. # pyrefly: ignore [implicit-any]
  1723. graph_node_shapes = {}
  1724. graph_input_count = None
  1725. non_compliant_ops = set({})
  1726. compliant_custom_ops = set({})
  1727. restart_reasons = set()
  1728. # If compilation failed, the entire time is wasted
  1729. dynamo_time_before_restart = (time.time_ns() - start_time_ns) / 1e9
  1730. metrics = {
  1731. "frame_key": frame_key,
  1732. "co_name": code.co_name,
  1733. "co_filename": code.co_filename,
  1734. "co_firstlineno": code.co_firstlineno,
  1735. "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs,
  1736. "accumulated_cache_size": cache_size.num_cache_entries,
  1737. "guard_count": guard_count,
  1738. "shape_env_guard_count": shape_env_guard_count,
  1739. "graph_op_count": graph_op_count,
  1740. "graph_node_count": graph_node_count,
  1741. "graph_input_count": graph_input_count,
  1742. "fail_type": fail_type,
  1743. "fail_reason": fail_reason,
  1744. "fail_user_frame_filename": fail_user_frame_filename,
  1745. "fail_user_frame_lineno": fail_user_frame_lineno,
  1746. "non_compliant_ops": non_compliant_ops,
  1747. "compliant_custom_ops": compliant_custom_ops,
  1748. "restart_reasons": restart_reasons,
  1749. "dynamo_time_before_restart_s": dynamo_time_before_restart,
  1750. "has_guarded_code": guarded_code is not None,
  1751. "specialize_float": config.specialize_float,
  1752. "is_forward": True,
  1753. "dynamo_compile_time_before_restart_us": to_int_us(
  1754. dynamo_time_before_restart
  1755. ),
  1756. "stack_trace": stack_trace,
  1757. "graph_node_shapes": str(graph_node_shapes),
  1758. "exception_stack_trace": exception_stack_trace,
  1759. }
  1760. # TODO: replace with CompileEventLogger.compilation_metrics
  1761. # There are some columns here not in PT2 Compile Events
  1762. # so we need to slightly change it
  1763. metrics_context.update_outer(metrics)
  1764. # === END WARNING WARNING WARNING ===
  1765. # If tracer is available, then tracer.error_on_graph_break reflects value of
  1766. # global symbolic_convert.error_on_graph_break at the time of the graph break -
  1767. # symbolic_convert.error_on_graph_break may have been (correctly) changed during cleanup.
  1768. # If tracer is unavailable, then fallback to symbolic_convert.error_on_graph_break.
  1769. if convert_frame_box:
  1770. convert_frame_box.error_on_graph_break = (
  1771. tracer_output.error_on_graph_break
  1772. if tracer_output
  1773. else _get_error_on_graph_break()
  1774. )
  1775. # Cleanup guards unless if in export, which will return guards
  1776. # Make sure to to do this after collecting metrics
  1777. if (
  1778. tracer_output is not None
  1779. and tracer_output.output_graph is not None
  1780. and not tracer_output.output_graph.export
  1781. ):
  1782. tracer_output.output_graph.tracing_context.guards_context.dynamo_guards.inner = OrderedSet()
  1783. # Clear WeakIdRef entries that can block swap_tensors after compile.
  1784. # Determine whether to clear based on config and backend type.
  1785. should_clear = config.invalidate_compile_context_weakrefs
  1786. if should_clear is None:
  1787. # Default: clear for registered backends, don't clear for custom
  1788. # Unwrap the compiler_fn to get the actual backend function
  1789. should_clear = _is_registered_backend(innermost_backend(compiler_fn))
  1790. if should_clear:
  1791. if tracer_output and tracer_output.output_graph:
  1792. tc = tracer_output.output_graph.tracing_context
  1793. tc.tensor_to_context.clear()
  1794. # Clear both the current fake_mode and the old_fake_mode
  1795. # (the original is stored before backend_fake_mode replaces it)
  1796. _clear_fake_mode_weakrefs(tc.fake_mode)
  1797. if hasattr(tracer_output.output_graph, "_old_fake_mode"):
  1798. _clear_fake_mode_weakrefs(
  1799. tracer_output.output_graph._old_fake_mode
  1800. )
  1801. class ConvertFrame:
  1802. def __init__(
  1803. self,
  1804. compiler_fn: CompilerFn,
  1805. hooks: Hooks,
  1806. package: Optional[CompilePackage] = None,
  1807. ) -> None:
  1808. self._torchdynamo_orig_backend = compiler_fn
  1809. self._inner_convert = convert_frame_assert(
  1810. compiler_fn, one_graph=False, package=package
  1811. )
  1812. self._hooks = hooks
  1813. @property
  1814. def _clone_with_backend(self) -> Callable[[WrapBackendDebug], ConvertFrame]:
  1815. return lambda backend: convert_frame(
  1816. # pyrefly: ignore [bad-argument-type]
  1817. backend,
  1818. self._hooks,
  1819. )
  1820. def __call__(
  1821. self,
  1822. frame: DynamoFrameType,
  1823. cache_entry: Optional[CacheEntry],
  1824. hooks: Hooks,
  1825. frame_state: dict[str, Union[int, FrameStateSizeEntry]],
  1826. skip: int = 0,
  1827. ) -> ConvertFrameReturn:
  1828. input_codes.add(frame.f_code)
  1829. counters["frames"]["total"] += 1
  1830. try:
  1831. result = self._inner_convert(
  1832. frame, cache_entry, hooks, frame_state, skip=skip + 1
  1833. )
  1834. counters["frames"]["ok"] += 1
  1835. return result
  1836. except Exception as e:
  1837. # Do not allow errors to be suppressed if we're tracing a resume function prologue
  1838. if isinstance(e, ResumePrologueTracingError):
  1839. raise
  1840. error_on_graph_break = (
  1841. self._inner_convert._box.error_on_graph_break is not None
  1842. )
  1843. assert error_on_graph_break is not None
  1844. if self._inner_convert._box.error_on_graph_break:
  1845. # NOTE we _might_ have to wrap the current in a custom exception
  1846. # in order to correctly bubble up to the top-level compile wrapper in
  1847. # eval_frame.py. But re-raising seems to work for now because exceptions from tracing
  1848. # a nested call that results in a top-level frame compile will be handled by the caller
  1849. # as an observed exception - we don't expect that exception to be suppressed.
  1850. raise
  1851. # These two exception types are "soft" failure, in the sense that
  1852. # we know this is due to something we didn't implement all the
  1853. # way, scare the user less about it. That being said, if you
  1854. # are trying to understand why a graph break happened, it's still
  1855. # important to have this information, so offer it.
  1856. #
  1857. # NB: NotImplementedError used to be on this list, but actually
  1858. # it is impossible for it to reach here, as it is converted into
  1859. # InternalTorchDynamoError. This behavior seemed reasonable
  1860. # to me (ezyang, Aug 2023) so I kept it, but maybe at some point
  1861. # someone wanted these to also get suppressed. If so, you'll
  1862. # need to make these exceptions not get wrapped
  1863. # We intentionally don't want to suppress error here.
  1864. if isinstance(e, UncapturedHigherOrderOpError):
  1865. raise
  1866. soft_fail = isinstance(e, Unsupported)
  1867. code = frame.f_code
  1868. # Log soft failure that was not already logged by symbolic_convert.
  1869. # This happens e.g. for graph breaks that are raised in convert_frame.py
  1870. # TODO(williamwen42) Unsupported exn's from tracing are handled and logged by symbolic_convert.py
  1871. # Unsupported exn's caught here should be from convert_frame.py - figure out a better way
  1872. # to log these.
  1873. if (
  1874. soft_fail
  1875. and not getattr(e, "logged", False)
  1876. and graph_break_log.isEnabledFor(logging.DEBUG)
  1877. ):
  1878. # Log this message in the graph break. Also use the string
  1879. # "skip: " to tell that the whole frame is falling back to
  1880. # eager.
  1881. if hasattr(e, "compile_id") and hasattr(e, "real_stack"):
  1882. with compile_context(CompileContext(e.compile_id)): # type: ignore[attr-defined]
  1883. user_stack = e.real_stack
  1884. user_stack_formatted = "".join(
  1885. traceback.format_list(user_stack)
  1886. )
  1887. frame_info = exc.format_frame_info(code)
  1888. user_stack_trace = (
  1889. "Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.\n"
  1890. f"torch.compile will skip tracing the frame {frame_info} and fall back to eager.\n"
  1891. "The graph break occurred in the following user code:\n"
  1892. f"{user_stack_formatted}"
  1893. )
  1894. torch._logging.trace_structured(
  1895. "artifact",
  1896. metadata_fn=lambda: {
  1897. "name": "dynamo_graph_break_reason",
  1898. "encoding": "string",
  1899. },
  1900. payload_fn=lambda: f"{user_stack_trace}\n{traceback.format_exc()}",
  1901. )
  1902. graph_break_log.debug(
  1903. user_stack_trace,
  1904. exc_info=True,
  1905. stack_info=config.verbose,
  1906. )
  1907. if not config.suppress_errors and not soft_fail:
  1908. raise
  1909. # Suppress the error. NB: It's very important to do the
  1910. # suppression logging HERE, where the actual suppression
  1911. # happens. Previously it was somewhere else and so it was
  1912. # possible to accidentally not log at all.
  1913. record_filename = getattr(e, "record_filename", None)
  1914. code = frame.f_code
  1915. error_msg = format_error_msg(e, code, record_filename, frame)
  1916. if soft_fail:
  1917. log.info(error_msg, exc_info=True)
  1918. else:
  1919. log.warning(error_msg, exc_info=True)
  1920. # Check if the exception has a specific frame execution strategy
  1921. if (
  1922. isinstance(e, exc.TorchDynamoException)
  1923. and e.frame_exec_strategy is not None
  1924. ):
  1925. return ConvertFrameReturn(frame_exec_strategy=e.frame_exec_strategy)
  1926. return ConvertFrameReturn()
  1927. def convert_frame(
  1928. compiler_fn: CompilerFn,
  1929. hooks: Hooks,
  1930. package: Optional[CompilePackage] = None,
  1931. ) -> ConvertFrame:
  1932. """Try to convert a frame into an FX graph, if error leave frame unmodified"""
  1933. return ConvertFrame(compiler_fn, hooks, package=package)
  1934. # TODO mlazos: add support for same args, or record them
  1935. def replay(filename: str) -> None:
  1936. from .backends.debugging import eager
  1937. original_replay_val = config.replay_record_enabled
  1938. config.replay_record_enabled = False
  1939. with open(filename, "rb") as in_file:
  1940. record = ExecutionRecord.load(in_file)
  1941. record.globals = dict(itertools.chain(record.globals.items(), globals().items()))
  1942. with decorators.error_on_graph_break(False):
  1943. try:
  1944. _compile(
  1945. record.code,
  1946. record.globals,
  1947. record.locals,
  1948. record.builtins,
  1949. record.closure,
  1950. compiler_fn=eager,
  1951. one_graph=False,
  1952. export=False,
  1953. export_constraints=None,
  1954. hooks=Hooks(),
  1955. cache_size=CacheSizeRelevantForFrame(0, 0),
  1956. cache_entry=None,
  1957. frame=None,
  1958. frame_state={},
  1959. compile_id=CompileId(frame_id=42, frame_compile_id=999),
  1960. )
  1961. finally:
  1962. config.replay_record_enabled = original_replay_val
  1963. def first_real_inst_idx(code: CodeType) -> int:
  1964. if sys.version_info < (3, 11):
  1965. return 0
  1966. for inst in dis.get_instructions(code):
  1967. if inst.opname == "RESUME":
  1968. return inst.offset // 2
  1969. raise RuntimeError("RESUME instruction not found in code")
  1970. class ConvertFrameProtocol(typing.Protocol):
  1971. def __call__(
  1972. self,
  1973. frame: DynamoFrameType,
  1974. cache_entry: Optional[CacheEntry],
  1975. hooks: Hooks,
  1976. frame_state: dict[str, Union[int, FrameStateSizeEntry]],
  1977. *,
  1978. skip: int = 0,
  1979. ) -> ConvertFrameReturn: ...
  1980. class CatchErrorsWrapper:
  1981. def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None:
  1982. functools.wraps(callback)(self)
  1983. self._torchdynamo_orig_backend = callback
  1984. self.hooks = hooks
  1985. def __call__(
  1986. self,
  1987. frame: DynamoFrameType,
  1988. cache_entry: Optional[CacheEntry],
  1989. frame_state: dict[str, Union[int, FrameStateSizeEntry]],
  1990. ) -> ConvertFrameReturn:
  1991. assert frame_state is not None
  1992. input_codes.add(frame.f_code)
  1993. is_skipfile = trace_rules.check(frame.f_code)
  1994. if sys.version_info >= (3, 13):
  1995. has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code)
  1996. else:
  1997. has_started_execution = frame.f_lasti >= first_real_inst_idx(frame.f_code)
  1998. # Check if we should skip due to torch dispatch mode.
  1999. # When inline_torch_dispatch_torch_compile is True (new behavior), we walk
  2000. # the stack to check for active modes. When False (old behavior), we use
  2001. # the global flag that tracks if we're inside any mode.
  2002. if config.inline_torch_dispatch_torch_compile:
  2003. should_skip_for_dispatch_mode = any_torch_dispatch_mode_on_stack()
  2004. else:
  2005. should_skip_for_dispatch_mode = (
  2006. is_in_any_mode_without_ignore_compile_internals()
  2007. )
  2008. if (
  2009. # TODO: the first condition is not covered by any test
  2010. has_started_execution
  2011. or is_skipfile
  2012. or config.disable
  2013. or (
  2014. should_skip_for_dispatch_mode
  2015. and not getattr(self._torchdynamo_orig_backend, "_export", False)
  2016. )
  2017. ):
  2018. if log.isEnabledFor(logging.DEBUG):
  2019. if has_started_execution:
  2020. skip_reason = "traced frame already"
  2021. elif trace_rules.check(frame.f_code):
  2022. skip_reason = "in skipfiles"
  2023. elif should_skip_for_dispatch_mode:
  2024. skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile"
  2025. else:
  2026. skip_reason = "dynamo tracing is disabled"
  2027. log.debug(
  2028. "skipping: %s (reason: %s, file: %s)",
  2029. frame.f_code.co_name,
  2030. skip_reason,
  2031. frame.f_code.co_filename,
  2032. )
  2033. return ConvertFrameReturn()
  2034. if (
  2035. frame.f_code.co_filename == "<string>" and frame.f_code.co_name == "__new__"
  2036. ) or (
  2037. frame.f_code.co_filename.endswith("collections/__init__.py")
  2038. and frame.f_code.co_name == "_make"
  2039. ):
  2040. # nametuple constructor/_make
  2041. return ConvertFrameReturn()
  2042. if torch._dynamo.utils.get_optimize_ddp_mode() == "ddp_optimizer":
  2043. ddp_module = DistributedDataParallel._get_active_ddp_module()
  2044. if ddp_module:
  2045. with compile_lock:
  2046. from torch._dynamo.backends.distributed import DDPOptimizer
  2047. ddp_optimizer = DDPOptimizer(
  2048. bucket_bytes_cap=ddp_module.bucket_bytes_cap,
  2049. backend_compile_fn=self._torchdynamo_orig_backend._torchdynamo_orig_backend, # type: ignore[attr-defined]
  2050. )
  2051. assert hasattr(
  2052. self._torchdynamo_orig_backend, "_clone_with_backend"
  2053. ), (
  2054. "DDPOptimizer only supports callback fns that know how to clone themselves."
  2055. )
  2056. hijacked_callback = (
  2057. self._torchdynamo_orig_backend._clone_with_backend(
  2058. ddp_optimizer.compile_fn,
  2059. )
  2060. )
  2061. return hijacked_callback(
  2062. frame, cache_entry, self.hooks, frame_state
  2063. )
  2064. with compile_lock, _disable_current_modes():
  2065. # skip=1: skip this frame
  2066. result = self._torchdynamo_orig_backend(
  2067. frame, cache_entry, self.hooks, frame_state, skip=1
  2068. )
  2069. return result
  2070. def catch_errors_wrapper(
  2071. callback: ConvertFrameProtocol, hooks: Hooks
  2072. ) -> CatchErrorsWrapper:
  2073. return CatchErrorsWrapper(callback, hooks)