_trace.py 101 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. import dataclasses
  4. import functools
  5. import inspect
  6. import logging
  7. import re
  8. import sys
  9. import time
  10. import warnings
  11. from collections.abc import Callable
  12. from contextlib import contextmanager, ExitStack, nullcontext
  13. from itertools import chain
  14. from typing import Any, TYPE_CHECKING, TypeAlias
  15. from unittest import mock
  16. if TYPE_CHECKING:
  17. import weakref
  18. import torch
  19. import torch._dynamo
  20. import torch.fx
  21. import torch.utils._pytree as pytree
  22. from torch._dispatch.python import enable_python_dispatcher
  23. from torch._dynamo.exc import UserError, UserErrorType
  24. from torch._export.db.logging import (
  25. exportdb_error_message,
  26. get_class_if_classified_error,
  27. )
  28. from torch._export.non_strict_utils import (
  29. _fakify_module_inputs,
  30. _fakify_script_objects,
  31. _gather_constant_attrs,
  32. _NonStrictTorchFunctionHandler,
  33. _override_builtin_ops,
  34. make_constraints,
  35. make_fake_inputs,
  36. produce_guards_and_solve_constraints,
  37. )
  38. from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
  39. from torch._export.passes.lift_constants_pass import (
  40. _materialize_and_lift_constants,
  41. ConstantAttrMap,
  42. )
  43. from torch._export.utils import (
  44. _collect_param_buffer_metadata,
  45. _compiling_state_context,
  46. _fakify_params_buffers,
  47. _populate_param_buffer_metadata_to_new_gm,
  48. _update_gm_meta_if_possible,
  49. apply_runtime_assertion_pass,
  50. placeholder_naming_pass,
  51. placeholder_prefixes,
  52. )
  53. from torch._export.verifier import SpecViolationError
  54. from torch._export.wrappers import _wrap_submodules
  55. from torch._functorch._aot_autograd.graph_capture_wrappers import create_functional_call
  56. from torch._functorch._aot_autograd.input_output_analysis import (
  57. _graph_input_names,
  58. _graph_output_names,
  59. )
  60. from torch._functorch._aot_autograd.schemas import GraphSignature
  61. from torch._functorch._aot_autograd.subclass_utils import get_subclass_typing_container
  62. from torch._functorch._aot_autograd.utils import (
  63. create_tree_flattened_fn,
  64. register_buffer_assignment_hook,
  65. )
  66. from torch._functorch.aot_autograd import (
  67. _detect_attribute_assignment,
  68. aot_export_joint_with_descriptors,
  69. )
  70. from torch._guards import detect_fake_mode, tracing, TracingContext
  71. from torch._library.fake_class_registry import FakeScriptObject
  72. from torch._logging import dtrace_structured
  73. from torch._subclasses.fake_tensor import FakeTensorMode
  74. from torch._utils_internal import compile_time_strobelight_meta, log_export_usage
  75. from torch.export._leakage_detection_utils import find_legit_leaks_from_referrers
  76. from torch.export._unlift import _check_input_constraints_pre_hook
  77. from torch.export.dynamic_shapes import (
  78. _check_dynamic_shapes,
  79. _combine_args,
  80. _DimHintType,
  81. _IntWrapper,
  82. _process_dynamic_shapes,
  83. )
  84. from torch.export.exported_program import OutputKind
  85. from torch.fx._symbolic_trace import _ConstantAttributeType
  86. from torch.fx.experimental.proxy_tensor import (
  87. get_proxy_slot,
  88. make_fx,
  89. PreDispatchTorchFunctionMode,
  90. track_tensor_tree,
  91. )
  92. from torch.fx.experimental.symbolic_shapes import (
  93. ConstraintViolationError,
  94. free_unbacked_symbols,
  95. GuardOnDataDependentSymNode,
  96. ShapeEnv,
  97. )
  98. from torch.fx.graph import _PyTreeInfo
  99. from torch.utils._pytree import TreeSpec
  100. from torch.utils._sympy.value_ranges import ValueRangeError
  101. from .exported_program import (
  102. _disable_prexisiting_fake_mode,
  103. ExportedProgram,
  104. InputKind,
  105. ModuleCallEntry,
  106. ModuleCallSignature,
  107. )
  108. from .graph_signature import _convert_to_export_graph_signature, ExportGraphSignature
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Type alias for dynamic shapes specification: a dict keyed by string, a tuple,
# or a list (element types are left unconstrained here).
_DynamicShapesSpec: TypeAlias = dict[str, Any] | tuple[Any, ...] | list[Any]
@dataclasses.dataclass
class ExportDynamoConfig:
    """
    Manage Export-specific configurations of Dynamo.

    A module-level default instance is created as DEFAULT_EXPORT_DYNAMO_CONFIG.
    """

    # NOTE(review): the field names presumably mirror Dynamo config options of
    # the same name -- confirm against the place where this config is applied.
    allow_rnn: bool = True
    # default_factory gives every instance its own set instead of a shared one.
    reorderable_logging_functions: set[Callable] = dataclasses.field(
        default_factory=set
    )
    # Emit runtime asserts after AOTAutograd instead.
    # This isn't really necessary, and isn't much more efficient since the runtime asserts pass does CSE,
    # but if we want to reason more about what guards/runtime asserts to emit,
    # this makes it a bit cleaner to do from the export side. Also no real point in running this twice.
    do_not_emit_runtime_asserts: bool = True
    specialize_int: bool = True
    specialize_float: bool = True
    assume_static_by_default: bool = False
    automatic_dynamic_shapes: bool = False
    capture_dynamic_output_shape_ops: bool = True
    capture_scalar_outputs: bool = True
    prefer_deferred_runtime_asserts_over_guards: bool = False
    replay_side_effects: bool = False
    side_effect_replay_policy: str = "warn"
@dataclasses.dataclass
class ATenExportArtifact:
    """
    Bundle describing an ATen-level export result: a graph module together
    with its export graph signature, its constants table and the inferred
    output tree spec.
    """

    gm: torch.fx.GraphModule
    sig: ExportGraphSignature
    # Constants keyed by name.
    constants: dict[str, _ConstantAttributeType]
    inferred_out_spec: TreeSpec
@dataclasses.dataclass(frozen=True)
class ExportArtifact:
    """
    Immutable (frozen) bundle pairing an ATenExportArtifact with the input and
    output tree specs, the fake tensor mode, and per-module call specs.
    """

    aten: ATenExportArtifact
    in_spec: TreeSpec
    out_spec: TreeSpec
    fake_mode: FakeTensorMode
    # NOTE(review): outer key is presumably a module path and the inner key a
    # spec name -- confirm against the code that populates this mapping.
    module_call_specs: dict[str, dict[str, pytree.TreeSpec]]
  148. DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
  149. DEFAULT_EXPORT_DYNAMO_CONFIG.reorderable_logging_functions = {
  150. logging.critical,
  151. logging.debug,
  152. logging.error,
  153. logging.exception,
  154. logging.info,
  155. logging.log,
  156. logging.warning,
  157. print,
  158. warnings.warn,
  159. }
  160. @contextmanager
  161. def _ignore_backend_decomps():
  162. orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False)
  163. orig_nnpack_flag = torch.backends.nnpack.set_flags(False)
  164. orig_cudnn_flag = torch.backends.cudnn.set_flags(False)
  165. try:
  166. yield
  167. finally:
  168. torch.backends.mkldnn.set_flags(*orig_mkldnn_flag)
  169. torch.backends.nnpack.set_flags(*orig_nnpack_flag)
  170. torch.backends.cudnn.set_flags(*orig_cudnn_flag)
  171. @contextmanager
  172. def _disable_custom_triton_op_functional_decomposition():
  173. old = torch._functorch.config.decompose_custom_triton_ops
  174. try:
  175. # pyrefly: ignore [bad-assignment]
  176. torch._functorch.config.decompose_custom_triton_ops = False
  177. yield torch._functorch.config.decompose_custom_triton_ops
  178. finally:
  179. torch._functorch.config.decompose_custom_triton_ops = old
  180. def custom_triton_ops_decomposition_disabled():
  181. return not torch._functorch.config.decompose_custom_triton_ops
  182. def _fixup_key(x):
  183. return "L__self__" + _strip_root(x)
  184. def _strip_root(x):
  185. if isinstance(x, str) and x.startswith("_export_root"):
  186. stripped = x[len("_export_root") :]
  187. return stripped.removeprefix(".")
  188. return x
  189. def _is_bogus_const_name(name: str):
  190. splitted_names = name.split(".")
  191. if len(splitted_names) < 1:
  192. return True
  193. return splitted_names[-1].startswith("lifted_tensor")
  194. def _rewrite_tracepoint_node(gm: torch.fx.GraphModule):
  195. """
  196. In-place modify input graph module by replacing the export tracepoint with a new node
  197. that has the same target and args, but with the _export_root stripped from path.
  198. """
  199. for node in gm.graph.nodes:
  200. if node.target is torch.ops.higher_order._export_tracepoint:
  201. if "path" in node.kwargs:
  202. path = _strip_root(node.kwargs["path"])
  203. with gm.graph.inserting_before(node):
  204. new_node = gm.graph.create_node(
  205. "call_function",
  206. torch.ops.higher_order._export_tracepoint,
  207. args=node.args,
  208. kwargs={
  209. "path": path,
  210. "kind": node.kwargs["kind"],
  211. },
  212. )
  213. new_node.meta = node.meta
  214. node.replace_all_uses_with(new_node)
  215. gm.graph.erase_node(node)
  216. def detect_shape_env(inputs: Any = None):
  217. shape_envs = []
  218. for i, flat_input in enumerate(inputs):
  219. if isinstance(flat_input, torch.SymInt):
  220. shape_envs.append((flat_input.node.shape_env, "symint input", i))
  221. if shape_envs:
  222. shape_env, desc1, i1 = shape_envs[0]
  223. for m, desc2, i2 in shape_envs[1:]:
  224. if shape_env is not m:
  225. raise AssertionError(
  226. f"shape env ({shape_env}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
  227. f"shape env from {desc1} {i1} allocated at:\n{shape_env.stack}\n"
  228. f"shape env from {desc2} {i2} allocated at:\n{m.stack}"
  229. )
  230. return shape_env
  231. else:
  232. return None
  233. def _extract_fake_inputs(gm, args, kwargs):
  234. """
  235. Given a graph module, extract fakified input tensors from the metadata of
  236. its placeholders, and map them to the structure of given args and kwargs.
  237. Also return the fake mode used to fakify those inputs.
  238. """
  239. fake_inps: list[Any] = []
  240. fake_vals: list[Any] = []
  241. for node in gm.graph.nodes:
  242. if node.op == "placeholder":
  243. fake_inps.append(node.meta.get("val"))
  244. else:
  245. fake_vals.append(node.meta.get("example_value"))
  246. if dynamo_bytecode_flatten := getattr(gm, "_dynamo_bytecode_flatten", None):
  247. # In _extract_fake_inputs, the goal is to make real inputs into
  248. # fake (and symbolic) inputs. The way currently it's implemented
  249. # is by looking at the node.meta["val"] of the placeholder nodes.
  250. # This doesn't work when the graph is Dynamo flattened, because now
  251. # plceholder nodes doesn't have the ordering like pytree inputs do.
  252. # Instead, we need to look at how the inputs are shuffled, and map
  253. # the inputs to their actual fake inputs and symbolic inputs.
  254. # Since inputs can also contain symints, we cannot simply use the
  255. # FakeTensorMode memo to look up tensors only there.
  256. fake_inps = []
  257. positions = {}
  258. idx = 0
  259. def mark_inputs(x):
  260. # x can be a tensor or symbolic integer or a normal constant.
  261. nonlocal idx
  262. fake_inps.append(x)
  263. if isinstance(x, torch.Tensor):
  264. ret = x
  265. else:
  266. ret = object()
  267. if id(ret) not in positions:
  268. positions[id(ret)] = idx
  269. idx += 1
  270. return ret
  271. dummy_args = pytree.tree_map(mark_inputs, args + tuple(kwargs.values()))
  272. shuffled_args = dynamo_bytecode_flatten(*dummy_args)
  273. for node, shuffled_arg in zip(
  274. gm.graph.find_nodes(op="placeholder"), shuffled_args
  275. ):
  276. if id(shuffled_arg) in positions:
  277. fake_inps[positions[id(shuffled_arg)]] = node.meta.get("val")
  278. # We get both because now we might have a combination of symint and tensor
  279. # inputs, and we want to check that the shape env is consistent between
  280. # both. Unfortunately we can't see what fake mode is attached to the shape
  281. # env, then we can just compare fake modes.
  282. detected_fake_mode = detect_fake_mode(fake_inps + fake_vals)
  283. detected_shape_env = detect_shape_env(fake_inps + fake_vals)
  284. if detected_fake_mode:
  285. if detected_shape_env:
  286. if detected_shape_env is not detected_fake_mode.shape_env:
  287. raise AssertionError(
  288. "Detected shape env does not match fake mode's shape env"
  289. )
  290. fake_mode = detected_fake_mode
  291. elif detected_shape_env:
  292. fake_mode = FakeTensorMode(shape_env=detected_shape_env, export=True)
  293. else:
  294. fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True)
  295. count = 0
  296. def lookup_fake(x):
  297. nonlocal count
  298. val = fake_inps[count] if isinstance(x, (int, torch.Tensor)) else x
  299. count += 1
  300. return val
  301. fake_args = pytree.tree_map(lookup_fake, args)
  302. fake_kwargs = pytree.tree_map(lookup_fake, kwargs)
  303. return fake_args, fake_kwargs, fake_mode
  304. def _replace_param_buffer_names(param_buffer_table, sig):
  305. for spec in sig.input_specs:
  306. if spec.kind in (
  307. InputKind.PARAMETER,
  308. InputKind.BUFFER,
  309. ):
  310. spec.target = param_buffer_table[spec.target]
  311. for spec in sig.output_specs:
  312. if spec.kind in (
  313. OutputKind.BUFFER_MUTATION,
  314. OutputKind.GRADIENT_TO_PARAMETER,
  315. ):
  316. spec.target = param_buffer_table[spec.target]
  317. def _convert_to_positional_args(orig_arg_names, args, kwargs):
  318. if len(orig_arg_names) != len(args) + len(kwargs):
  319. raise AssertionError(
  320. f"Total number of arg names is expected to be {len(orig_arg_names)} "
  321. f"but got {len(args)} positional args, {len(kwargs)} kwargs."
  322. )
  323. reordered_kwargs = [kwargs[kw_name] for kw_name in orig_arg_names[len(args) :]]
  324. return (
  325. *args,
  326. *reordered_kwargs,
  327. )
  328. def _normalize_nn_module_stack(gm_torch_level, root_cls):
  329. # Append a root module to every nn_module_stack.
  330. root = "L['self']"
  331. root_key = re.sub(r"[^a-zA-Z0-9]", "_", root)
  332. for gm in gm_torch_level.modules():
  333. if not isinstance(gm, torch.fx.GraphModule):
  334. continue
  335. for node in gm.graph.nodes:
  336. if node.op in ["placeholder", "output"]:
  337. continue
  338. add_root = True
  339. if nn_module_stack := node.meta.get("nn_module_stack", {}):
  340. path, ty = next(iter(nn_module_stack.values()))
  341. # After deserializing the class `ty` might not exist anymore so
  342. # it could be a string
  343. if inspect.isclass(ty) and issubclass(ty, torch.nn.Module):
  344. # TODO Figure out why sometimes we have root sometimes we don't.
  345. if path == root and ty is root_cls:
  346. add_root = False
  347. else:
  348. if not isinstance(ty, str):
  349. raise AssertionError(f"expected ty to be str, got {type(ty)}")
  350. if add_root:
  351. def normalize_path(path):
  352. if path == "L['self']":
  353. return ""
  354. if path.startswith("L['self']."):
  355. return path[len("L['self'].") :]
  356. return path
  357. nn_module_stack = {
  358. root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__),
  359. # pyrefly: ignore [unbound-name]
  360. **nn_module_stack,
  361. }
  362. node.meta["nn_module_stack"] = {
  363. key: (normalize_path(path), ty)
  364. for key, (path, ty) in nn_module_stack.items()
  365. }
  366. def _get_param_buffer_mapping(
  367. original_module: torch.nn.Module,
  368. traced_module: torch.nn.Module,
  369. ) -> dict[str, str]:
  370. """
  371. Returns a mapping of parameter/buffer names from the new module to the
  372. original model. This is to help with restoring the FQN for parameter/buffers
  373. of a traced module to what the original module contains.
  374. """
  375. param_lookup: dict[int, str] = {}
  376. buffer_lookup: dict[int, str] = {}
  377. for name, param in original_module.named_parameters(remove_duplicate=False):
  378. if param_lookup.get(id(param)) is None:
  379. # we only want to keep the first occurrence of a parameter to guarantee parity of original and traced module.
  380. param_lookup[id(param)] = name
  381. for name, buffer in original_module.named_buffers(remove_duplicate=False):
  382. buffer_lookup[id(buffer)] = name
  383. param_buffer_table: dict[str, str] = {}
  384. for dynamo_name, dynamo_param in traced_module.named_parameters(
  385. remove_duplicate=False
  386. ):
  387. if dynamo_name in param_buffer_table:
  388. raise AssertionError(
  389. f"dynamo_name {dynamo_name!r} already exists in param_buffer_table"
  390. )
  391. if id(dynamo_param) in param_lookup:
  392. param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)]
  393. for dynamo_name, dynamo_buffer in traced_module.named_buffers(
  394. remove_duplicate=False
  395. ):
  396. if dynamo_name in param_buffer_table:
  397. raise AssertionError(
  398. f"dynamo_name {dynamo_name!r} already exists in param_buffer_table for buffer"
  399. )
  400. if id(dynamo_buffer) in buffer_lookup:
  401. param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)]
  402. return param_buffer_table
  403. def _preserve_requires_grad_pass(
  404. gm: torch.fx.GraphModule,
  405. sig: ExportGraphSignature,
  406. fake_params_buffers: dict[str, torch.Tensor],
  407. constants: dict[str, _ConstantAttributeType],
  408. flat_fake_args: list[Any],
  409. ):
  410. placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
  411. if len(sig.input_specs) != len(placeholders):
  412. raise AssertionError(
  413. f"input_specs length {len(sig.input_specs)} does not match placeholders length {len(placeholders)}"
  414. )
  415. i = 0
  416. for node, spec in zip(placeholders, sig.input_specs):
  417. if spec.kind in (
  418. InputKind.PARAMETER,
  419. InputKind.BUFFER,
  420. ):
  421. if spec.target is None:
  422. raise AssertionError(
  423. f"spec.target must not be None for kind {spec.kind}"
  424. )
  425. node.meta["val"].requires_grad = fake_params_buffers[
  426. spec.target
  427. ].requires_grad
  428. elif spec.kind == InputKind.USER_INPUT:
  429. fake_arg = flat_fake_args[i]
  430. if isinstance(fake_arg, torch.Tensor):
  431. node.meta["val"].requires_grad = fake_arg.requires_grad
  432. i += 1
  433. elif spec.kind == InputKind.CONSTANT_TENSOR:
  434. if spec.target is None:
  435. raise AssertionError(
  436. "spec.target must not be None for CONSTANT_TENSOR kind"
  437. )
  438. constant = constants[spec.target]
  439. if isinstance(constant, torch.Tensor):
  440. # If the tensor is not leaf, it should already have a correct requires grad field
  441. if node.meta["val"].is_leaf:
  442. node.meta["val"].requires_grad = constant.requires_grad
  443. else:
  444. if node.meta["val"].requires_grad != constant.requires_grad:
  445. raise AssertionError(
  446. f"node requires_grad {node.meta['val'].requires_grad} does not match "
  447. f"constant requires_grad {constant.requires_grad}"
  448. )
  449. elif spec.kind in (InputKind.CUSTOM_OBJ, InputKind.TOKEN):
  450. continue
  451. else:
  452. raise AssertionError(spec.kind)
  453. def _remap_constants(
  454. orig_constant_attrs: ConstantAttrMap,
  455. graph_signature: ExportGraphSignature,
  456. constants: dict[str, _ConstantAttributeType],
  457. ) -> None:
  458. """Rewrite the graph signature and constants table to use the FQN from the original module."""
  459. remap_table: dict[str, list[str]] = {}
  460. for name, value in constants.items():
  461. if value in orig_constant_attrs:
  462. remap_table[name] = orig_constant_attrs[value]
  463. for spec in graph_signature.input_specs:
  464. if spec.kind in (
  465. InputKind.CONSTANT_TENSOR,
  466. InputKind.CUSTOM_OBJ,
  467. ):
  468. orig_target = spec.target
  469. if orig_target is None:
  470. raise AssertionError(
  471. f"spec.target must not be None for kind {spec.kind}"
  472. )
  473. targets = remap_table.get(orig_target, [orig_target])
  474. spec.target = targets[0]
  475. constant = constants[orig_target]
  476. del constants[orig_target]
  477. for target in targets:
  478. constants[target] = constant
  479. def _replace_unbacked_bindings(gm: torch.fx.GraphModule) -> None:
  480. """
  481. When we run an interpreter-based pass over a GraphModule, execution of data-dependent operators
  482. will produce example values with new unbacked symbols. To track that the new/old symbols are equivalent,
  483. we used to rely on the unbacked_renamings mapping. This led to problematic metadata where the unbacked_bindings
  484. keys mapped new symbols (u2) to paths containing old symbols (u0) in the example values, or worse, backed symbols
  485. or constants (e.g. if the original unbacked was replaced/specialized). Additionally this created problems with
  486. de/serialized programs, since we didn't comprehensively serialize ShapeEnv/unbacked renamings/node bindings.
  487. This pass attempts a simpler way of handling these for export, by throwing away the previously computed bindings, and re-running
  488. the pattern match used in compute_unbacked_bindings. This ensures we keep the original symbols contained in the example values,
  489. or delete bindings if they've been replaced/specialized.
  490. """
  491. from torch._export.utils import _get_shape_env_from_gm
  492. from torch.fx.experimental.symbolic_shapes import _free_unbacked_symbols_with_path
  493. from torch.utils._sympy.symbol import symbol_is_type, SymT
  494. if (shape_env := _get_shape_env_from_gm(gm)) is None:
  495. return
  496. base_unbacked_symbols = {
  497. symbol
  498. for symbol in shape_env.var_to_range
  499. if symbol_is_type(symbol, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))
  500. and symbol not in shape_env.unbacked_renamings
  501. }
  502. for node in gm.graph.nodes:
  503. node.meta.pop("unbacked_bindings", None)
  504. if (val := node.meta.get("val")) is not None and (
  505. unbacked_bindings := _free_unbacked_symbols_with_path(
  506. val,
  507. (),
  508. shape_env=shape_env,
  509. pending=base_unbacked_symbols,
  510. simplify=True,
  511. )
  512. ):
  513. node.meta["unbacked_bindings"] = unbacked_bindings
def _produce_aten_artifact(
    *,
    gm: torch.fx.GraphModule,
    mod,
    constant_attrs,
    graph_signature,
    pre_dispatch,
    fake_args,
    fake_kwargs,
    fake_params_buffers,
    _prettify_placeholder_names=True,
) -> ATenExportArtifact:
    """
    This is a helper function that is shared between export_to_aten_ir and export_to_aten_ir_make_fx
    to produce the aten artifact. (export compatible graph module + signature)

    It does:
    1. Applies runtime assertion pass
    2. Recompute unbacked_bindings pass
    3. Populate meta val when missing
    4. Lift constants as placeholders
    5. Replace raw autograd and autocast ops with HOPs
    6. Prettify names for placeholders
    7. Preserve requires_grad value on node meta val

    Args:
        gm: Graph module to post-process; it is rebound by several passes.
        mod: Module whose non-persistent buffers are looked up and which is
            forwarded to the placeholder naming pass.
        constant_attrs: Forwarded to ``_materialize_and_lift_constants``.
        graph_signature: Signature accompanying ``gm``; replaced by the result
            of the runtime assertion pass and then converted to an
            ``ExportGraphSignature``.
        pre_dispatch: When truthy, the set_grad and autocast HOP replacement
            passes are run (in that order).
        fake_args, fake_kwargs: Flattened together and used for missing meta
            vals and the requires_grad pass; also passed to the naming pass.
        fake_params_buffers: Forwarded to the naming and requires_grad passes.
        _prettify_placeholder_names: When falsy, the placeholder naming pass
            is skipped.

    Returns:
        An ``ATenExportArtifact`` holding the final graph module, the export
        graph signature, the lifted constants, and ``inferred_out_spec`` taken
        from the (post runtime-assertion) ``graph_signature.out_spec``.

    Raises:
        AssertionError: If the export graph signature is ``None`` after the
            passes have run.
    """
    # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature.
    # Overwrite output specs afterwards.
    flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs))
    gm, graph_signature = apply_runtime_assertion_pass(gm, graph_signature)

    # Simplify unbacked_bindings by recomputing them.
    # Useful for any pass that's interpreter-based and might call rebind_unbacked(),
    # e.g. AOTAutograd in this case.
    _replace_unbacked_bindings(gm)

    # Count of graph inputs that are not user inputs: parameters, buffers and
    # input tokens.
    total_non_user_inputs = (
        len(graph_signature.parameters)
        + len(graph_signature.buffers)
        + len(graph_signature.input_tokens)
    )
    set_missing_meta_vals(gm, flat_fake_args, total_non_user_inputs)

    export_graph_signature: ExportGraphSignature | None
    export_graph_signature = _convert_to_export_graph_signature(
        graph_signature, gm, _get_non_persistent_buffers(mod)
    )

    # script objects are always stored in constants no matter whether they're initial inputs or
    # they're lifted in aot before rewrite_script_object_meta
    constants = _materialize_and_lift_constants(
        gm, export_graph_signature, constant_attrs
    )

    if pre_dispatch:
        from torch._export.passes.replace_autocast_with_hop_pass import (
            replace_autocast_with_hop_pass,
        )
        from torch._export.passes.replace_set_grad_with_hop_pass import (
            replace_set_grad_with_hop_pass,
        )

        # Note: replace_set_grad_with_hop_pass need to be after lift_constant_pass because
        # a getattr of a constant tensor doesn't have meta["val"] until after lift_constant_pass.
        # If replace_set_grad_with_hop_pass is before lift_constant_pass,
        # and the constant_tensor is passed as input of the set grad hop, the placeholder's
        # meta["val"] will be None and fails our verifier for placeholder.
        gm, export_graph_signature = replace_set_grad_with_hop_pass(
            gm, export_graph_signature
        )

        gm, export_graph_signature = replace_autocast_with_hop_pass(
            gm, export_graph_signature
        )

    # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
    # (Output nodes are stripped as well, in every nested GraphModule.)
    for _mod in gm.modules():
        if not isinstance(_mod, torch.fx.GraphModule):
            continue
        for node in _mod.graph.nodes:
            if node.op in ["placeholder", "output"]:
                node.meta.pop("nn_module_stack", None)
                node.meta.pop("stack_trace", None)

    # Prettify names for placeholder nodes.
    if export_graph_signature is None:
        raise AssertionError("export_graph_signature must not be None")
    if _prettify_placeholder_names:
        placeholder_naming_pass(
            gm,
            export_graph_signature,
            mod,
            fake_args,
            fake_kwargs,
            fake_params_buffers,
            constants,
        )

    _preserve_requires_grad_pass(
        gm, export_graph_signature, fake_params_buffers, constants, flat_fake_args
    )

    return ATenExportArtifact(
        gm,
        export_graph_signature,
        constants,
        inferred_out_spec=graph_signature.out_spec,
    )
  609. def _rename_constants_nodes(
  610. gm: torch.fx.GraphModule,
  611. graph_signature: ExportGraphSignature,
  612. ) -> None:
  613. """
  614. For strict mode, rename constants nodes that were previously annotated as buffers.
  615. """
  616. # handle name collisions with existing constants
  617. node_names = {node.name for node in gm.graph.nodes}
  618. def rename_constant(name):
  619. if name in node_names:
  620. n = 1
  621. while (dup_name := f"{name}_{n}") in node_names:
  622. n += 1
  623. # pyrefly: ignore [unbound-name]
  624. name = dup_name
  625. node_names.add(name)
  626. return name
  627. # use input specs to map names from buffers to constants
  628. buffer_prefix = placeholder_prefixes[InputKind.BUFFER]
  629. const_prefix = placeholder_prefixes[InputKind.CONSTANT_TENSOR]
  630. buffer_to_constant = {}
  631. for spec in graph_signature.input_specs:
  632. if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith(
  633. const_prefix
  634. ):
  635. if spec.arg.name.startswith(buffer_prefix): # map from buffer to constants
  636. c_name = rename_constant(
  637. const_prefix + spec.arg.name[len(buffer_prefix) :]
  638. )
  639. else: # lifted constant
  640. c_name = rename_constant(const_prefix + spec.arg.name)
  641. buffer_to_constant[spec.arg.name] = c_name
  642. spec.arg.name = c_name
  643. for spec in graph_signature.output_specs:
  644. if spec.arg.name in buffer_to_constant:
  645. spec.arg.name = buffer_to_constant[spec.arg.name]
  646. # Rename constants nodes for all modules
  647. for mod in gm.modules():
  648. if not isinstance(mod, torch.fx.GraphModule):
  649. continue
  650. for node in mod.graph.nodes:
  651. if node.name in buffer_to_constant:
  652. node.name = node.target = buffer_to_constant[node.name]
  653. mod.recompile()
  654. def _restore_state_dict(
  655. original_module: torch.nn.Module, traced_module: torch.fx.GraphModule
  656. ) -> None:
  657. """
  658. Restores the state dict of the traced module to that of the original module.
  659. """
  660. param_buffer_table = _get_param_buffer_mapping(original_module, traced_module)
  661. # Don't want to change the convention of previous call.
  662. param_buffer_table_reverse = {v: k for k, v in param_buffer_table.items()}
  663. # Replace state dict attr names with the fqn
  664. for name, _ in list(
  665. chain(
  666. original_module.named_parameters(remove_duplicate=False),
  667. # pyrefly: ignore [bad-argument-type]
  668. original_module.named_buffers(remove_duplicate=False),
  669. )
  670. ):
  671. if name in param_buffer_table_reverse:
  672. dynamo_name = param_buffer_table_reverse[name]
  673. param = torch.fx.graph_module._get_attr(traced_module, dynamo_name)
  674. torch.fx.graph_module._assign_attr(param, traced_module, name)
  675. torch.fx.graph_module._del_attr(traced_module, dynamo_name)
  676. # Replace graph getattr nodes with the correct name
  677. for node in traced_module.graph.nodes:
  678. if node.op == "get_attr":
  679. attr_name = node.target
  680. if attr_name in param_buffer_table:
  681. node.target = param_buffer_table[attr_name]
  682. traced_module.recompile()
  683. def _get_module_hierarchy(mod: torch.nn.Module) -> dict[str, str]:
  684. return {
  685. name: type(m).__name__ for name, m in mod.named_modules(remove_duplicate=False)
  686. }
  687. def _make_module_call_graph(
  688. in_spec: TreeSpec,
  689. out_spec: TreeSpec,
  690. module_call_signatures: dict[str, ModuleCallSignature],
  691. forward_arg_names: list[str] | None = None,
  692. ) -> list[ModuleCallEntry]:
  693. original = [
  694. ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn))
  695. for fqn in _EXPORT_MODULE_HIERARCHY # type: ignore[union-attr]
  696. ]
  697. if original[0].fqn != "":
  698. raise AssertionError(
  699. f"expected first fqn to be empty string, got {original[0].fqn!r}"
  700. )
  701. original[0].signature = ModuleCallSignature(
  702. inputs=[],
  703. outputs=[],
  704. in_spec=in_spec,
  705. out_spec=out_spec,
  706. forward_arg_names=forward_arg_names,
  707. )
  708. additional = [
  709. ModuleCallEntry(fqn=fqn, signature=signature)
  710. for fqn, signature in module_call_signatures.items()
  711. if fqn not in _EXPORT_MODULE_HIERARCHY # type: ignore[operator]
  712. ]
  713. return [*original, *additional]
# Plain ``dict`` subclass that adds no behavior of its own.
# ``_export_to_torch_ir`` instantiates it to hold ``module_call_specs``.
# NOTE(review): presumably the distinct type lets other code recognize this
# particular dict — confirm against the consumers.
class _ExportModuleSpecTrackerDict(dict):
    pass
def _export_to_torch_ir(
    f: Callable,
    args: tuple[Any, ...],
    kwargs: dict[str, Any] | None = None,
    dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None = None,
    *,
    preserve_module_call_signature: tuple[str, ...] = (),
    disable_constraint_solver: bool = False,
    prefer_deferred_runtime_asserts_over_guards: bool = False,
    restore_fqn: bool = True,
    _log_export_usage: bool = True,
    same_signature: bool = True,
) -> torch.fx.GraphModule:
    """
    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside and produce a torch.fx.GraphModule in torch IR.

    Args:
        f: Callable (or module) to trace.
        args: Example positional inputs; must be a tuple.
        kwargs: Example keyword inputs; ``None`` is treated as ``{}``.
        dynamic_shapes: Dynamic shape specification, validated against the
            combined args and turned into constraints.
        preserve_module_call_signature: Forwarded to ``_wrap_submodules`` when
            ``f`` is not already a GraphModule.
        disable_constraint_solver, same_signature: Forwarded to
            ``torch._dynamo.export``.
        prefer_deferred_runtime_asserts_over_guards: Applied to the dynamo
            config and forwarded to ``torch._dynamo.export``.
        restore_fqn: When true and ``f`` is an ``nn.Module``, the traced
            module's state dict names are restored via ``_restore_state_dict``.
        _log_export_usage: When true, logs an ``export.private_api`` event and
            is forwarded to ``torch._dynamo.export``.

    Returns:
        The traced GraphModule, with ``meta["module_call_specs"]`` populated.

    Raises:
        UserError: ``INVALID_INPUT`` if ``args`` is not a tuple;
            ``CONSTRAINT_VIOLATION`` when tracing raises
            ``ConstraintViolationError`` or ``ValueRangeError``;
            ``ANTI_PATTERN`` when tracing raises ``GuardOnDataDependentSymNode``.
    """
    if _log_export_usage:
        log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"})

    if not isinstance(args, tuple):
        raise UserError(
            UserErrorType.INVALID_INPUT,
            f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
        )

    kwargs = kwargs or {}

    # Map ints to a wrapper structure to help us mark it as dynamic, if it is
    # dynamic. We will unwrap ints in fakify later.
    args, kwargs = pytree.tree_map_only(int, _IntWrapper, (args, kwargs))

    combined_args = _combine_args(f, args, kwargs)
    _check_dynamic_shapes(combined_args, dynamic_shapes)
    constraints = _process_dynamic_shapes(combined_args, dynamic_shapes)

    # Unwrap static ints -- in the case where we have an empty graph
    # containing just integer computation, dynamo will run its generated
    # bytecode with these args/kwargs, which will error because we cannot
    # directly apply int operations on IntWrapper. So we will just unwrap
    # them here.
    args, kwargs = pytree.tree_map_only(
        _IntWrapper,
        lambda a: a.val
        if a.dynamism is None or a.dynamism.type == _DimHintType.STATIC
        else a,
        (args, kwargs),
    )

    # Export dynamo config with the caller's runtime-assert preference applied.
    dynamo_cfg = dataclasses.replace(
        DEFAULT_EXPORT_DYNAMO_CONFIG,
        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
    )

    # Decides, within the new-tracer path, which of the two capture entry
    # points to use; any of the listed conditions selects the legacy one.
    def use_legacy_dynamo_graph_capture() -> bool:
        return bool(
            constraints  # dynamic shape
            or dynamic_shapes  # dynamic shape
            or isinstance(f, torch.fx.GraphModule)  # retracing
            or preserve_module_call_signature  # unflatten
            or torch._functorch.config.fake_tensor_propagate_real_tensors  # draft
            or torch._export.config.use_legacy_dynamo_graph_capture
        )

    with torch._dynamo.config.patch(dataclasses.asdict(dynamo_cfg)):
        try:
            module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = (
                _ExportModuleSpecTrackerDict()
            )
            # Submodules are only wrapped when tracing a non-GraphModule.
            ctx = nullcontext()
            if not isinstance(f, torch.fx.GraphModule):
                ctx = _wrap_submodules(  # type: ignore[assignment]
                    f, preserve_module_call_signature, module_call_specs
                )
            with ctx, _ignore_backend_decomps():
                if torch._export.config.use_new_tracer_experimental:
                    from torch._dynamo.functional_export import (
                        _dynamo_graph_capture_for_export,
                        dynamo_graph_capture_for_export,
                    )

                    if use_legacy_dynamo_graph_capture():
                        dynamo_graph_capture = _dynamo_graph_capture_for_export(
                            f, constraints=constraints, dynamic_shapes=dynamic_shapes
                        )
                    else:
                        dynamo_graph_capture = torch._dynamo.config.patch(
                            replay_side_effects=False
                        )(dynamo_graph_capture_for_export(f))
                    gm_torch_level = dynamo_graph_capture(*args, **kwargs)
                    # We can't serialize entire fake mode yet, so this is to make sure
                    # things like copy.deepcopy(ep.graph_module) not crash.
                    # see test_export.py::test_custom_tag_metadata_re_export
                    # Once we delete the old strict export, we can use this fake mode in the
                    # subsequent logic when lowering to aten IR.
                    del gm_torch_level.meta["fake_mode"]

                else:
                    gm_torch_level, _ = torch._dynamo.export(
                        f,
                        dynamic_shapes=dynamic_shapes,  # type: ignore[arg-type]
                        constraints=constraints,  # type: ignore[arg-type]
                        assume_static_by_default=True,
                        tracing_mode="symbolic",
                        disable_constraint_solver=disable_constraint_solver,
                        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
                        _log_export_usage=_log_export_usage,
                        same_signature=same_signature,
                    )(
                        *args,
                        **kwargs,
                    )
                # Attach the specs collected while tracing to the graph module.
                gm_torch_level.meta["module_call_specs"] = module_call_specs
        except (ConstraintViolationError, ValueRangeError) as e:
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
        except GuardOnDataDependentSymNode as e:
            raise UserError(  # noqa: B904
                UserErrorType.ANTI_PATTERN,
                f"Consider annotating your code using torch._check*(). {str(e)}",
                case_name="constrain_as_size_example",
            )

    if isinstance(f, torch.nn.Module) and restore_fqn:
        _restore_state_dict(f, gm_torch_level)

    return gm_torch_level
  834. def _aot_export_joint_with_descriptors(
  835. stack,
  836. mod,
  837. args,
  838. *,
  839. kwargs,
  840. decompositions,
  841. fake_params_buffers,
  842. _record_nn_module_stack=True,
  843. ):
  844. from torch._functorch._aot_autograd.graph_compile import aot_stage2_export
  845. from torch._functorch._aot_autograd.input_output_analysis import (
  846. create_graph_signature,
  847. )
  848. joint_with_descriptors = aot_export_joint_with_descriptors(
  849. stack,
  850. mod,
  851. args,
  852. kwargs=kwargs,
  853. decompositions=decompositions,
  854. _record_nn_module_stack=_record_nn_module_stack,
  855. )
  856. # Convert JointWithDescriptors to graph module and ViewAndMutationMeta
  857. gm, fw_metadata = aot_stage2_export(
  858. joint_with_descriptors._aot_state,
  859. joint_with_descriptors._aot_graph_capture,
  860. )
  861. if not isinstance(gm, torch.fx.GraphModule):
  862. raise AssertionError(f"expected gm to be torch.fx.GraphModule, got {type(gm)}")
  863. # Create GraphSignature from the metadata
  864. graph_signature = create_graph_signature(
  865. gm,
  866. fw_metadata,
  867. joint_with_descriptors.in_spec,
  868. joint_with_descriptors.out_spec,
  869. user_args_flat=pytree.tree_leaves((args, kwargs)),
  870. params_and_buffers_flat=list(fake_params_buffers.values()),
  871. param_names=joint_with_descriptors.params_spec,
  872. buffer_names=joint_with_descriptors.buffers_spec,
  873. trace_joint=False,
  874. num_user_fw_outs=None,
  875. loss_index=None,
  876. )
  877. return gm, graph_signature
def _export_to_aten_ir(
    mod: torch.nn.Module,
    fake_args,
    fake_kwargs,
    fake_params_buffers,
    constant_attrs: ConstantAttrMap,
    produce_guards_callback=None,
    *,
    transform=lambda x: x,  # TODO(zhxchen17) Revisit if this is needed later.
    pre_dispatch=False,
    decomp_table=None,
    _prettify_placeholder_names: bool = True,
    decompose_custom_triton_ops: bool = False,
) -> ATenExportArtifact:
    """
    Lower ``mod`` to an ATen-level graph and package it as an ATenExportArtifact.

    The module is traced via ``_aot_export_joint_with_descriptors`` (wrapped by
    ``transform``) with its params/buffers reparametrized to
    ``fake_params_buffers`` and under ``torch.no_grad()``. Graph-level and
    output-node metadata are then copied over from ``mod`` when it is a
    GraphModule, the optional guards callback is run, and the result is handed
    to ``_produce_aten_artifact``.

    Args:
        mod: Module to lower.
        fake_args, fake_kwargs: Inputs passed to the tracer and forwarded to
            ``_produce_aten_artifact``.
        fake_params_buffers: Mapping used to reparametrize ``mod`` for tracing.
        constant_attrs: Forwarded to ``_produce_aten_artifact``.
        produce_guards_callback: Optional callable invoked with the traced
            graph module before the artifact is produced.
        transform: Wrapper applied to the tracing function before it is called.
        pre_dispatch: Forwarded to ``_produce_aten_artifact``.
        decomp_table: Passed to the tracer as ``decompositions``.
        _prettify_placeholder_names: Forwarded to ``_produce_aten_artifact``.
        decompose_custom_triton_ops: When false, tracing runs inside
            ``_disable_custom_triton_op_functional_decomposition``; when true a
            ``nullcontext`` is used instead.

    Raises:
        UserError: ``CONSTRAINT_VIOLATION`` if the guards callback raises
            ``ConstraintViolationError`` or ``ValueRangeError``.
        AssertionError: If the last node of either graph is not an output node
            while copying metadata.
    """
    # Context manager *factory* (called below) selected by the flag.
    custom_triton_ops_decomposition_ctx = (
        nullcontext
        if decompose_custom_triton_ops
        else _disable_custom_triton_op_functional_decomposition
    )
    # This _reparameterize_module makes sure inputs and module.params/buffers have the same fake_mode,
    # otherwise aot_export_module will error out because it sees a mix of fake_modes.
    # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
    with ExitStack() as stack:
        stack.enter_context(
            torch.nn.utils.stateless._reparametrize_module(
                mod,
                fake_params_buffers,
                tie_weights=True,
                strict=True,
                stack_weights=True,
            )
        )
        stack.enter_context(_ignore_backend_decomps())
        stack.enter_context(_compiling_state_context())
        stack.enter_context(custom_triton_ops_decomposition_ctx())
        stack.enter_context(torch.no_grad())

        # The ExitStack itself is handed to the tracer as its first argument.
        gm, graph_signature = transform(_aot_export_joint_with_descriptors)(
            stack,
            mod,
            fake_args,
            kwargs=fake_kwargs,
            decompositions=decomp_table,
            fake_params_buffers=fake_params_buffers,
            _record_nn_module_stack=True,
        )

    # Copies graph-level meta and the output node's meta from old_gm onto
    # new_gm; a no-op unless old_gm is a GraphModule.
    def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm):
        if isinstance(old_gm, torch.fx.GraphModule):
            if hasattr(old_gm, "meta"):
                new_gm.meta.update(old_gm.meta)
            old_output_node = list(old_gm.graph.nodes)[-1]
            new_output_node = list(new_gm.graph.nodes)[-1]
            if old_output_node.op != "output" or new_output_node.op != "output":
                raise AssertionError(
                    f"expected both output nodes to have op='output', got old={old_output_node.op!r}, new={new_output_node.op!r}"
                )
            # make sure we don't override any meta
            if "desc" in new_output_node.meta:
                del new_output_node.meta["desc"]
            new_output_node.meta.update(old_output_node.meta)

    # TODO unfortunately preserving graph-level metadata and output node's meta
    # is not working well with aot_export. So we manually copy it.
    # (The node-level meta is addressed above.)
    _maybe_fixup_gm_and_output_node_meta(mod, gm)

    # Run produce guards before we handle runtime asserts.
    # This means we run the export solver before the runtime asserts pass.
    # Right now this doesn't mean much - the export solver is only there for suggested fixes,
    # and we won't even get to constraint solving if that's needed.
    # But if in future we want to control what runtime asserts are emitted for export,
    # or rely on produce_guards + solver for some simplification on runtime asserts, this probably makes sense.
    if produce_guards_callback:
        try:
            produce_guards_callback(gm)
        except (ConstraintViolationError, ValueRangeError) as e:
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904

    return _produce_aten_artifact(
        gm=gm,
        mod=mod,
        constant_attrs=constant_attrs,
        graph_signature=graph_signature,
        pre_dispatch=pre_dispatch,
        fake_args=fake_args,
        fake_kwargs=fake_kwargs,
        fake_params_buffers=fake_params_buffers,
        _prettify_placeholder_names=_prettify_placeholder_names,
    )
  963. def _get_forward_arg_names(
  964. mod: torch.nn.Module,
  965. args: tuple[Any, ...],
  966. kwargs: dict[str, Any] | None = None,
  967. ) -> list[str]:
  968. """
  969. Gets the argument names to forward that are used, for restoring the
  970. original signature when unlifting the exported program module.
  971. - Positional args: retain the original argument names, and enumerate
  972. *args as args_0, args_1, ...
  973. - Keyword args: retain the original kwarg names in the order specified
  974. by the user. This order seems to matter for the current state of
  975. export lifted modules.
  976. """
  977. sig = inspect.signature(mod.forward)
  978. _args = sig.bind_partial(*args).arguments
  979. names: list[str] = []
  980. for name, value in _args.items():
  981. # handle variable number of positional args
  982. if sig.parameters[name].kind == inspect._ParameterKind.VAR_POSITIONAL:
  983. names.extend([f"{name}_{i}" for i, _ in enumerate(value)])
  984. else:
  985. names.append(name)
  986. # order of kwargs matters for input spec
  987. if kwargs:
  988. names.extend([kwarg for kwarg, _ in kwargs.items()])
  989. return names
  990. def _get_non_persistent_buffers(mod: torch.nn.Module) -> set[str]:
  991. """
  992. Returns set of non-persistent buffers in a module and its submodules.
  993. """
  994. result: set[str] = set()
  995. for name, m in mod.named_modules(remove_duplicate=False):
  996. if name:
  997. result.update(f"{name}.{b}" for b in m._non_persistent_buffers_set)
  998. else:
  999. result.update(m._non_persistent_buffers_set)
  1000. return result
  1001. def _rewrite_dynamo_tensor_constants(
  1002. orig_mod_buffers: set[torch.Tensor],
  1003. traced_mod_buffers: dict[str, torch.Tensor],
  1004. graph_signature: ExportGraphSignature,
  1005. constants: dict[str, _ConstantAttributeType],
  1006. ) -> None:
  1007. """
  1008. Dynamo erroneously marks tensor attributes on modules as buffers.
  1009. Rewrite them to be tensor constants.
  1010. """
  1011. for spec in graph_signature.input_specs:
  1012. if spec.kind == InputKind.BUFFER:
  1013. if spec.target is None:
  1014. raise AssertionError("spec.target must not be None for BUFFER kind")
  1015. value = traced_mod_buffers[spec.target]
  1016. if value not in orig_mod_buffers:
  1017. # This was a tensor constant erroneously marked as a buffer.
  1018. # Convert it into a constant in the graph signature, and add its
  1019. # value to the constants table.
  1020. spec.kind = InputKind.CONSTANT_TENSOR
  1021. constants[spec.target] = value # type: ignore[arg-type]
  1022. def _move_non_persistent_buffers_to_tensor_constants(
  1023. orig_mod: torch.nn.Module,
  1024. graph_signature: ExportGraphSignature,
  1025. constants: dict[str, _ConstantAttributeType],
  1026. ) -> None:
  1027. """
  1028. Moves non-persistent buffers to tensor constants.
  1029. """
  1030. for spec in graph_signature.input_specs:
  1031. if spec.kind == InputKind.BUFFER and not spec.persistent:
  1032. if spec.target is None:
  1033. raise AssertionError(
  1034. "spec.target must not be None for non-persistent BUFFER kind"
  1035. )
  1036. if spec.target in constants:
  1037. raise AssertionError(
  1038. f"spec.target {spec.target!r} should not already be in constants"
  1039. )
  1040. constants[spec.target] = orig_mod.get_buffer(spec.target) # type: ignore[arg-type]
  1041. def _verify_nn_module_stack(graph_module: torch.fx.GraphModule) -> None:
  1042. """
  1043. Perform nn_module_stack checks on the graph.
  1044. Current constraints:
  1045. For the top level graph:
  1046. - populated for 'call_function', 'get_attr'
  1047. - None for 'placeholder', 'output'
  1048. For submodule graphs:
  1049. - None for 'placeholder', output'
  1050. TODO(pianpwk): make this a consistent node-level check once nn_module_stack is populated for cond submodules.
  1051. """
  1052. # Check top-level graph for all nodes, all graphs for placeholder & output nodes
  1053. for i, mod in enumerate([graph_module] + list(graph_module.modules())):
  1054. if not isinstance(mod, torch.fx.GraphModule):
  1055. continue
  1056. for node in mod.graph.nodes:
  1057. if node.op in ["call_function", "get_attr"]:
  1058. if i == 0:
  1059. if (
  1060. nn_module_stack := node.meta.get("nn_module_stack", None)
  1061. ) is None:
  1062. raise SpecViolationError(
  1063. f"Node {node} of type {node.op} is missing nn_module_stack metadata"
  1064. )
  1065. if not all(
  1066. isinstance(k, str)
  1067. and isinstance(v, tuple)
  1068. and len(v) == 2
  1069. and all(isinstance(x, str) for x in v)
  1070. for k, v in nn_module_stack.items()
  1071. ):
  1072. raise SpecViolationError(
  1073. f"Node {node} of type {node.op} has incorrect nn_module_stack metadata format"
  1074. f"expected Dict[str, Tuple[str, str]], but got {nn_module_stack}"
  1075. )
  1076. elif node.op in ["placeholder", "output"]:
  1077. if node.meta.get("nn_module_stack", None):
  1078. raise SpecViolationError(
  1079. f"Node {node} of type {node.op} contains nn_module_stack metadata, this should be None"
  1080. )
  1081. def _verify_stack_trace(graph_module: torch.fx.GraphModule) -> None:
  1082. """
  1083. Perform stack trace checks on the graph.
  1084. Constraints:
  1085. - None or non-empty str for 'call_function', 'get_attr'
  1086. - None for 'placeholder', 'output'
  1087. """
  1088. for mod in [graph_module, *graph_module.modules()]:
  1089. if not isinstance(mod, torch.fx.GraphModule):
  1090. continue
  1091. for node in graph_module.graph.nodes:
  1092. stack_trace = node.meta.get("stack_trace", None)
  1093. if node.op in ["call_function", "get_attr"]:
  1094. if not (stack_trace is None or isinstance(stack_trace, str)):
  1095. raise SpecViolationError(
  1096. f"Node {node} of type {node.op} has invalid stack_trace metadata, "
  1097. f"expected a string or None but instead found: {stack_trace}"
  1098. )
  1099. elif node.op in ["placeholder", "output"]:
  1100. if stack_trace:
  1101. raise SpecViolationError(
  1102. f"Node {node} of type {node.op} contains stack_trace metadata, "
  1103. f"expected None but instead found: {stack_trace}"
  1104. )
  1105. def _verify_placeholder_names(
  1106. gm: torch.fx.GraphModule, sig: ExportGraphSignature
  1107. ) -> None:
  1108. """
  1109. Performs a sanity check on the placeholder node names.
  1110. - User input nodes: no restrictions, should match the original forward() signature
  1111. - Params/buffers/constants/custom_obj/token nodes: should start with prefixes defined in <placeholder_prefixes>
  1112. """
  1113. name_to_kind = {spec.arg.name: spec.kind for spec in sig.input_specs}
  1114. for mod in gm.modules():
  1115. if not isinstance(mod, torch.fx.GraphModule):
  1116. continue
  1117. for node in mod.graph.nodes:
  1118. if node.op == "placeholder":
  1119. if node.name not in name_to_kind:
  1120. continue
  1121. node_kind = name_to_kind[node.name]
  1122. prefix = placeholder_prefixes[node_kind]
  1123. if not node.name.startswith(prefix):
  1124. raise SpecViolationError(
  1125. f"Placeholder node name {node.name} does not follow spec for {node_kind}, name should have prefix: {prefix}"
  1126. )
  1127. def get_ep_stats(ep: ExportedProgram) -> dict[str, Any]:
  1128. op_count = 0
  1129. op_set = set()
  1130. for m in ep.graph_module.modules():
  1131. if not isinstance(m, torch.fx.GraphModule):
  1132. continue
  1133. for node in m.graph.nodes:
  1134. if node.op != "call_function":
  1135. continue
  1136. op_count += 1
  1137. if not hasattr(node.target, "__module__"):
  1138. raise AssertionError(
  1139. f"node.target {node.target} must have __module__ attribute"
  1140. )
  1141. if not hasattr(node.target, "__name__"):
  1142. raise AssertionError(
  1143. f"node.target {node.target} must have __name__ attribute"
  1144. )
  1145. op_set.add(f"{node.target.__module__}.{node.target.__name__}")
  1146. return {"op_count": op_count, "op_set": op_set}
# Module-level state for the export call currently in flight.
# `_EXPORT_FLAGS` is forwarded as `flags` to log_export_usage by
# `_log_export_wrapper`, which resets both globals to None when the wrapped
# call finishes (whether it succeeded or raised).
_EXPORT_FLAGS: set[str] | None = None
# Iterated by `_make_module_call_graph` to enumerate module fqns; it must be
# non-None by then (`_get_module_call_graph` asserts this).
# NOTE(review): presumably populated from `_get_module_hierarchy` — the
# assignment is not in this part of the file; confirm.
_EXPORT_MODULE_HIERARCHY: dict[str, str] | None = None
  1149. def _log_export_wrapper(fn):
  1150. @functools.wraps(fn)
  1151. def wrapper(*args, **kwargs):
  1152. global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY
  1153. try:
  1154. start = time.time()
  1155. ep = fn(*args, **kwargs)
  1156. end = time.time()
  1157. log_export_usage(
  1158. event="export.time",
  1159. metrics=end - start,
  1160. flags=_EXPORT_FLAGS,
  1161. **get_ep_stats(ep),
  1162. )
  1163. except Exception as e:
  1164. t = type(e)
  1165. error_type = t.__module__ + "." + t.__qualname__
  1166. case_name = get_class_if_classified_error(e)
  1167. if case_name is not None:
  1168. log.error(exportdb_error_message(case_name))
  1169. log_export_usage(
  1170. event="export.error.classified",
  1171. type=error_type,
  1172. message=str(e),
  1173. flags=_EXPORT_FLAGS,
  1174. )
  1175. else:
  1176. log_export_usage(
  1177. event="export.error.unclassified",
  1178. type=error_type,
  1179. message=str(e),
  1180. flags=_EXPORT_FLAGS,
  1181. )
  1182. if hasattr(e, "partial_fx_graph"):
  1183. print(
  1184. e.partial_fx_graph,
  1185. file=sys.stderr,
  1186. )
  1187. raise e
  1188. finally:
  1189. _EXPORT_FLAGS = None
  1190. _EXPORT_MODULE_HIERARCHY = None
  1191. return ep
  1192. return wrapper
  1193. def _process_jit_trace_inputs_for_export(example_inputs, example_kwarg_inputs):
  1194. if not isinstance(example_inputs, (tuple, list, dict)):
  1195. example_inputs = (example_inputs,)
  1196. elif isinstance(example_inputs, list):
  1197. example_inputs = tuple(example_inputs)
  1198. elif (
  1199. isinstance(example_inputs, (torch.Tensor, dict))
  1200. and example_kwarg_inputs is None
  1201. ):
  1202. example_inputs = (example_inputs,)
  1203. if example_kwarg_inputs is None:
  1204. example_kwarg_inputs = {}
  1205. return example_inputs, example_kwarg_inputs
  1206. def _get_original_state_dict(mod: torch.nn.Module) -> dict[str, Any]:
  1207. # Explicitly not calling mode.state_dict() as we do not want the module state for serialization
  1208. # but the running module state so we can always match by id() the entries here with the graph inputs
  1209. named_parameters = dict(mod.named_parameters(remove_duplicate=False))
  1210. named_buffers = dict(mod.named_buffers(remove_duplicate=False))
  1211. original_state_dict = named_parameters | named_buffers
  1212. non_persistent_buffers = _get_non_persistent_buffers(mod)
  1213. for k in non_persistent_buffers:
  1214. original_state_dict.pop(k, None)
  1215. return original_state_dict
  1216. def _process_export_inputs(
  1217. mod: torch.nn.Module,
  1218. args: tuple[object, ...],
  1219. kwargs: dict[str, object] | None,
  1220. dynamic_shapes: _DynamicShapesSpec
  1221. | torch.export.AdditionalInputs
  1222. | torch.export.ShapesCollection
  1223. | None,
  1224. ) -> tuple[
  1225. tuple[object, ...],
  1226. dict[str, object],
  1227. TreeSpec,
  1228. _DynamicShapesSpec | None,
  1229. Callable[[ExportedProgram], None],
  1230. ]:
  1231. """
  1232. Process and validate export inputs for the torch.export API.
  1233. This function validates the input arguments, normalizes kwargs, computes input tree specs,
  1234. and handles special dynamic shapes cases like AdditionalInputs and ShapesCollection.
  1235. Args:
  1236. mod: The PyTorch module to be exported.
  1237. args: Tuple of example positional inputs for the module.
  1238. kwargs: Optional dictionary of example keyword inputs.
  1239. dynamic_shapes: Optional specification for dynamic shapes. Can be:
  1240. - dict mapping argument names to dynamic shape specifications
  1241. - tuple/list specifying dynamic shapes for each input in order
  1242. - torch.export.AdditionalInputs object with verification callback
  1243. - torch.export.ShapesCollection object
  1244. Returns:
  1245. A tuple containing:
  1246. - args: Validated tuple of positional inputs
  1247. - kwargs: Normalized dictionary of keyword inputs (empty dict if None was passed)
  1248. - original_in_spec: TreeSpec representing the flattened input structure
  1249. - dynamic_shapes: Processed dynamic shapes specification
  1250. - verify_additional_inputs: Callback function for additional input verification
  1251. Raises:
  1252. UserError: If args is not a tuple.
  1253. """
  1254. if not isinstance(args, tuple):
  1255. raise UserError(
  1256. UserErrorType.INVALID_INPUT,
  1257. f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
  1258. )
  1259. kwargs = kwargs if kwargs is not None else {}
  1260. if pytree.is_namedtuple_instance(args):
  1261. args = tuple(args)
  1262. _, original_in_spec = pytree.tree_flatten((args, kwargs))
  1263. verify_additional_inputs: Callable[[ExportedProgram], None]
  1264. out_dynamic_shapes: _DynamicShapesSpec | None
  1265. if isinstance(dynamic_shapes, torch.export.AdditionalInputs):
  1266. verify_additional_inputs = dynamic_shapes.verify # type: ignore[assignment]
  1267. out_dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) # type: ignore[assignment]
  1268. else:
  1269. verify_additional_inputs = lambda ep: None # noqa: E731
  1270. if isinstance(dynamic_shapes, torch.export.ShapesCollection):
  1271. out_dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) # type: ignore[assignment]
  1272. else:
  1273. out_dynamic_shapes = dynamic_shapes
  1274. return args, kwargs, original_in_spec, out_dynamic_shapes, verify_additional_inputs
  1275. def _get_module_call_graph(
  1276. export_artifact: ExportArtifact,
  1277. preserve_module_call_signature: tuple[str, ...],
  1278. strict_mode_export: bool,
  1279. forward_arg_names: list[str] | None = None,
  1280. ) -> tuple[torch.fx.GraphModule, list[ModuleCallEntry]]:
  1281. """
  1282. In-place modify the graph module in export_artifact, remove _export_tracepoint nodes and
  1283. return module_call_graph.
  1284. """
  1285. gm: torch.fx.GraphModule = export_artifact.aten.gm
  1286. export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
  1287. module_call_specs: dict[str, dict[str, TreeSpec]] = (
  1288. export_artifact.module_call_specs
  1289. )
  1290. in_spec: TreeSpec = export_artifact.in_spec
  1291. out_spec: TreeSpec = export_artifact.out_spec
  1292. # Make module signatures.
  1293. module_call_signatures: dict[str, ModuleCallSignature] = {}
  1294. for fqn, specs in module_call_specs.items():
  1295. mod_fqn = _strip_root(fqn) if not strict_mode_export else fqn
  1296. module_call_signatures[mod_fqn] = ModuleCallSignature(
  1297. inputs=[],
  1298. outputs=[],
  1299. in_spec=specs["in_spec"],
  1300. out_spec=specs["out_spec"],
  1301. forward_arg_names=None, # we only propagate forward_arg_names for the top level module
  1302. )
  1303. if len(preserve_module_call_signature) > 0:
  1304. if not strict_mode_export:
  1305. _rewrite_tracepoint_node(gm)
  1306. res = CollectTracepointsPass(module_call_signatures, export_graph_signature)(gm)
  1307. if res is None:
  1308. raise AssertionError("CollectTracepointsPass returned None")
  1309. gm = res.graph_module
  1310. if _EXPORT_MODULE_HIERARCHY is None:
  1311. raise AssertionError("_EXPORT_MODULE_HIERARCHY must not be None")
  1312. module_call_graph = _make_module_call_graph(
  1313. in_spec,
  1314. out_spec,
  1315. module_call_signatures,
  1316. forward_arg_names,
  1317. )
  1318. return gm, module_call_graph
  1319. def _get_range_constraints(
  1320. mod: torch.nn.Module,
  1321. export_artifact: ExportArtifact,
  1322. args,
  1323. kwargs,
  1324. dynamic_shapes,
  1325. ):
  1326. gm: torch.fx.GraphModule = export_artifact.aten.gm
  1327. export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
  1328. fake_mode: FakeTensorMode = export_artifact.fake_mode
  1329. num_lifted = next(
  1330. (
  1331. i
  1332. for i, s in enumerate(export_graph_signature.input_specs)
  1333. if s.kind == InputKind.USER_INPUT
  1334. ),
  1335. len(export_graph_signature.input_specs),
  1336. )
  1337. combined_args = _combine_args(mod, args, kwargs)
  1338. # This is because we trace based on the kwargs passed in from user
  1339. # not based on the signature. I feel it would be better to just enforce
  1340. # one ordering at the start of tracing to avoid confusions, but that is
  1341. # bigger refactor, so do this to unblock for now.
  1342. combined_args_traced_order = {}
  1343. for arg in combined_args:
  1344. if arg not in kwargs:
  1345. combined_args_traced_order[arg] = combined_args[arg]
  1346. for key in kwargs:
  1347. combined_args_traced_order[key] = kwargs[key]
  1348. combined_args = combined_args_traced_order
  1349. range_constraints = make_constraints(
  1350. fake_mode,
  1351. gm,
  1352. combined_args,
  1353. dynamic_shapes,
  1354. num_lifted,
  1355. )
  1356. return range_constraints
  1357. def _get_inline_constraints(fake_mode: FakeTensorMode):
  1358. if fake_mode.shape_env is None:
  1359. raise AssertionError("fake_mode.shape_env must not be None")
  1360. return {
  1361. k: v
  1362. for k, v in fake_mode.shape_env.var_to_range.items()
  1363. if free_unbacked_symbols(k)
  1364. }
  1365. @contextmanager
  1366. def patch_forward(obj: torch.nn.Module, new_method):
  1367. """Helper method to make it easier to cleanly torch.export() a method on a
  1368. module that is not `forward`.
  1369. """
  1370. # Save the original method
  1371. original_method = obj.forward
  1372. # Patch the method
  1373. obj.forward = new_method.__get__(obj, obj.__class__)
  1374. try:
  1375. yield
  1376. finally:
  1377. # Restore the original method
  1378. obj.forward = original_method
  1379. @contextmanager
  1380. def _temp_disable_texpr_fuser():
  1381. original_state = torch._C._jit_texpr_fuser_enabled()
  1382. torch._C._jit_set_texpr_fuser_enabled(False)
  1383. try:
  1384. yield
  1385. finally:
  1386. torch._C._jit_set_texpr_fuser_enabled(original_state)
def _strict_export(
    mod: torch.nn.Module,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None,
    preserve_module_call_signature: tuple[str, ...],
    orig_in_spec: TreeSpec,
    prefer_deferred_runtime_asserts_over_guards: bool,
    _to_aten_func: Callable,
) -> ExportArtifact:
    """
    Trace ``mod`` to a Torch-level graph module with ``_export_to_torch_ir``,
    lower it with ``_to_aten_func`` and return the result as an ``ExportArtifact``.

    _to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir`

    Args:
        mod: Module to export.
        args: Example positional inputs.
        kwargs: Example keyword inputs.
        dynamic_shapes: Dynamic shape specification, forwarded to
            ``_export_to_torch_ir``.
        preserve_module_call_signature: Forwarded to ``_export_to_torch_ir``.
        orig_in_spec: Stored unchanged as ``in_spec`` of the returned artifact.
        prefer_deferred_runtime_asserts_over_guards: Forwarded to
            ``_export_to_torch_ir``.
        _to_aten_func: Called with the Torch-level graph module, the fake
            inputs converted to positional form, an empty kwargs dict, the
            fake params/buffers and the constant attrs.

    Returns:
        An ``ExportArtifact`` holding the ATen artifact, ``orig_in_spec``, the
        output spec, the fake mode extracted from the traced graph and the
        ``module_call_specs`` entry of the Torch-level graph module's ``meta``.

    Raises:
        AssertionError: If a ``get_attr`` node needs a ``val`` but no fake mode
            was found, or if the pytree codegen has no ``out_spec``.
        RuntimeError: If the graph's codegen is neither ``_PyTreeCodeGen`` nor
            ``_DynamoBytecodeCodeGen``.
    """
    gm_torch_level = _export_to_torch_ir(
        # pyrefly: ignore [bad-argument-type]
        mod,
        args,
        kwargs,
        dynamic_shapes,
        preserve_module_call_signature=preserve_module_call_signature,
        restore_fqn=False,  # don't need to restore because we will do it later
        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
        _log_export_usage=False,
    )

    # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
    (
        fake_args,
        fake_kwargs,
        dynamo_fake_mode,
    ) = _extract_fake_inputs(gm_torch_level, args, kwargs)

    fake_params_buffers = _fakify_params_buffers(dynamo_fake_mode, gm_torch_level)

    # First, we want to pass through the graph to try populating
    # val field for getattr if there is anything missing.
    # This can happen when quantization adds extra params and forgets
    # to update "val"
    for node in gm_torch_level.graph.nodes:
        if node.op == "get_attr" and "val" not in node.meta:
            attr = getattr(gm_torch_level, node.target)
            # Checks if it is not a HigherOrderOp branch or a module
            if not isinstance(attr, torch.nn.Module):
                if dynamo_fake_mode is None:
                    raise AssertionError(
                        "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders."
                    )
                node.meta["val"] = dynamo_fake_mode.from_tensor(
                    attr, static_shapes=True
                )

    # Fix the graph output signature to be tuple if scalar
    wrap_tuple = False
    # Calling gm_torch_level._out_spec is not safe because gm_torch_level might be
    # a _LazyGraphModule, which does not populate _out_spec when calling recompile().
    # TODO: Fix recompile() in _LazyGraphModule. T207713214
    if isinstance(gm_torch_level.graph._codegen, torch.fx.graph._PyTreeCodeGen):
        out_spec = orig_out_spec = gm_torch_level.graph._codegen.pytree_info.out_spec
        orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args  # type: ignore[attr-defined]

        # Used to get rid of lint type error.
        if out_spec is None:
            raise AssertionError("out_spec must not be None")
        if out_spec.type not in (list, tuple):
            # aot_export expect the return type to always be a tuple.
            out_spec = pytree.treespec_tuple([out_spec])
            wrap_tuple = True
            gm_torch_level.graph._codegen.pytree_info = _PyTreeInfo(
                orig_arg_names,
                gm_torch_level._in_spec,
                out_spec,
            )
    elif isinstance(
        gm_torch_level.graph._codegen,
        torch._dynamo.functional_export._DynamoBytecodeCodeGen,
    ):
        # Since we're using bytecode codegen, we need to separately apply tuple
        # output instead of modifying pytree spec inplace.
        orig_arg_names = gm_torch_level.graph._codegen.orig_arg_names
        # No spec is available up front on this path; it is taken from the
        # ATen artifact's inferred_out_spec after lowering (see below).
        out_spec = orig_out_spec = None
        wrap_tuple = gm_torch_level.graph._codegen.wrap_tuple = True
    else:
        raise RuntimeError(f"Unknown codegen type: {gm_torch_level.graph._codegen}")

    gm_torch_level.recompile()

    _normalize_nn_module_stack(gm_torch_level, type(mod))

    # When aot_export lifts the params, we lose metadata (e.g. source_fn_stack, stack_trace)
    # from the param nodes as they are treated as fresh inputs
    # Therefore, we manually extract them before calling into aot_export
    params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level)

    constant_attrs = _gather_constant_attrs(mod)
    param_buffer_table: dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level)

    # Dynamo does not track which buffers were registered as non-persistent. This info
    # is available in the original module, so we transfer it to the traced module. Also,
    # since we didn't restore original param/buffer names yet, we must use traced names.
    non_persistent_buffers = _get_non_persistent_buffers(mod)
    # param_buffer_table is iterated as (traced, original) pairs; invert it so
    # original names can be mapped back to traced names.
    reverse_name_lookup = {orig: traced for traced, orig in param_buffer_table.items()}
    gm_torch_level._non_persistent_buffers_set = {
        reverse_name_lookup[name]
        for name in non_persistent_buffers
        if name in reverse_name_lookup
    }

    tx = TracingContext(dynamo_fake_mode)
    with (
        dynamo_fake_mode,
        tracing(tx),
        mock.patch.object(dynamo_fake_mode, "allow_non_fake_inputs", True),
    ):
        aten_export_artifact = _to_aten_func(
            gm_torch_level,
            # NOTE: graph module expects only positional args
            _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs),
            {},
            fake_params_buffers,
            constant_attrs,
        )

    # Decompose for readability.
    gm = aten_export_artifact.gm
    export_graph_signature = aten_export_artifact.sig
    constants = aten_export_artifact.constants

    _populate_param_buffer_metadata_to_new_gm(
        params_buffers_to_node_meta, gm, export_graph_signature
    )

    # Do some cleanups on the graph module to restore the state dict to the
    # expected form. Each of these steps should probably get fixed upstream.
    # 1. Remove tensor constants that were added as buffers.
    _rewrite_dynamo_tensor_constants(
        orig_mod_buffers=set(mod.buffers()),
        traced_mod_buffers=dict(gm_torch_level.named_buffers()),
        graph_signature=export_graph_signature,
        constants=constants,
    )
    # 2. Restore FQN of param/buffers
    _replace_param_buffer_names(param_buffer_table, export_graph_signature)

    # 3. Move non-persistent buffers to tensor constants
    _move_non_persistent_buffers_to_tensor_constants(
        mod, export_graph_signature, constants
    )

    # 4. Rewrite constants to have the same FQN as the original module.
    _remap_constants(constant_attrs, export_graph_signature, constants)

    # 5. Rename constants nodes in graph module from buffers to constants
    _rename_constants_nodes(gm, export_graph_signature)

    if orig_out_spec is None:
        # Bytecode-codegen path: use the spec inferred during lowering and undo
        # the tuple wrapping applied above by taking its first child.
        out_spec = aten_export_artifact.inferred_out_spec
        if wrap_tuple:
            out_spec = out_spec.children()[0]
    else:
        # Pytree-codegen path: report the spec as it was before any tuple wrapping.
        out_spec = orig_out_spec

    return ExportArtifact(
        aten=aten_export_artifact,
        in_spec=orig_in_spec,
        out_spec=out_spec,
        fake_mode=dynamo_fake_mode,
        module_call_specs=gm_torch_level.meta["module_call_specs"],
    )
  1537. def _export_to_aten_ir_make_fx(
  1538. mod: torch.nn.Module,
  1539. fake_args,
  1540. fake_kwargs,
  1541. fake_params_buffers,
  1542. constant_attrs: ConstantAttrMap,
  1543. produce_guards_callback=None,
  1544. transform=lambda x: x,
  1545. ) -> ATenExportArtifact:
  1546. def _make_fx_helper(stack, mod, args, kwargs, **flags):
  1547. kwargs = kwargs or {}
  1548. named_parameters = dict(mod.named_parameters(remove_duplicate=False))
  1549. named_buffers = dict(mod.named_buffers(remove_duplicate=False))
  1550. params_and_buffers = {**named_parameters, **named_buffers}
  1551. params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
  1552. params_and_buffers_flat = tuple(params_and_buffers_flat)
  1553. param_len = len(named_parameters)
  1554. buffer_len = len(named_buffers)
  1555. params_len = len(params_and_buffers)
  1556. functional_call = create_functional_call(
  1557. mod, params_spec, params_len, store_orig_mod=True
  1558. )
  1559. params_buffers_args: list[Any] = []
  1560. params_buffers_args.extend(params_and_buffers_flat)
  1561. params_buffers_args.extend(args)
  1562. flat_fn, out_spec = create_tree_flattened_fn(
  1563. functional_call, params_buffers_args, kwargs
  1564. )
  1565. flat_args, in_spec = pytree.tree_flatten((params_buffers_args, kwargs))
  1566. @functools.wraps(flat_fn)
  1567. def wrapped_fn(*args):
  1568. return tuple(flat_fn(*args))
  1569. with enable_python_dispatcher():
  1570. ctx = nullcontext()
  1571. non_strict_root = getattr(mod, "_export_root", None)
  1572. if non_strict_root is not None:
  1573. ctx = _detect_attribute_assignment(non_strict_root) # type: ignore[assignment]
  1574. # For any buffer that is assigned, we want to associate it to the final proxy node
  1575. # that it is assigned to. This node can then be copied into the buffer.
  1576. assigned_buffers: dict[str, str] = {}
  1577. hook = register_buffer_assignment_hook(
  1578. non_strict_root, assigned_buffers
  1579. )
  1580. def custom_getattribute(self, attr, *, original_getattr, attrs_to_proxy):
  1581. """
  1582. The idea here is that we override subclass getattr methods to proxy
  1583. inner tensors and metadata. Because of infinite loop shenanigans, we have
  1584. to manually construct the getattr proxy nodes without relying on torch function
  1585. system.
  1586. """
  1587. out = original_getattr(self, attr)
  1588. if attr in attrs_to_proxy:
  1589. if torch._C._is_torch_function_mode_enabled():
  1590. if isinstance(out, torch.Tensor):
  1591. # When we get here there is no guarantee that we will hit the
  1592. # PreDispatchTorchFunctionMode, so we manually peak into the torch
  1593. # function mode list and tweak the PreDispatchTorchFunctionMode.
  1594. # This has side effect of proxying stuff like
  1595. # proxy.node.meta["val"] = extract_val(val) because at that time, torch function
  1596. # mode is still active. It seems bad to turn it off inside proxy_tensor.py, so
  1597. # I guess we will just rely on DCE for now to remove extra stuff like detach
  1598. torch_function_mode_stack = (
  1599. torch.overrides._get_current_function_mode_stack()
  1600. )
  1601. for mode in torch_function_mode_stack:
  1602. if isinstance(mode, PreDispatchTorchFunctionMode):
  1603. tracer = mode.tracer
  1604. proxy = get_proxy_slot(self, tracer).proxy
  1605. inner_proxy = tracer.create_proxy(
  1606. "call_function",
  1607. torch.ops.export.access_subclass_inner_tensor.default,
  1608. (proxy, attr),
  1609. {},
  1610. )
  1611. track_tensor_tree(
  1612. out, inner_proxy, constant=None, tracer=tracer
  1613. )
  1614. return out
  1615. @contextmanager
  1616. def override_getattribute_for_subclasses(args):
  1617. """
  1618. Context manager that temporarily monkey patches
  1619. tensor.__getattribute__ so that we can intercept it at
  1620. torch_function layer.
  1621. """
  1622. # Dictionary that tracks subclass type to original getattr function
  1623. # and the attributes we can proxy.
  1624. tensor_type_to_old_getattribute: dict[
  1625. type[torch.Tensor], tuple[Callable, set[str]]
  1626. ] = {}
  1627. for arg in args:
  1628. subclass_types_to_instances: dict[
  1629. type[torch.Tensor], list[type[torch.Tensor]]
  1630. ] = get_subclass_typing_container(arg)
  1631. for subclass_type in subclass_types_to_instances:
  1632. if subclass_type not in tensor_type_to_old_getattribute:
  1633. if len(subclass_types_to_instances[subclass_type]) == 0:
  1634. raise AssertionError(
  1635. f"subclass_types_to_instances[{subclass_type}] must not be empty"
  1636. )
  1637. instance = subclass_types_to_instances[subclass_type][0]
  1638. # Query subclass specific attrs
  1639. attrs_to_proxy = set(dir(instance)) - set(dir(torch.Tensor))
  1640. tensor_type_to_old_getattribute[subclass_type] = (
  1641. subclass_type.__getattribute__, # type: ignore[attr-defined]
  1642. attrs_to_proxy,
  1643. )
  1644. try:
  1645. for k, (
  1646. old_getattr,
  1647. attrs_to_proxy,
  1648. ) in tensor_type_to_old_getattribute.items():
  1649. custom = functools.partialmethod(
  1650. custom_getattribute,
  1651. original_getattr=old_getattr,
  1652. attrs_to_proxy=attrs_to_proxy,
  1653. )
  1654. k.__getattribute__ = custom # type: ignore[assignment, attr-defined]
  1655. yield
  1656. finally:
  1657. for k, (old_getattr, _) in tensor_type_to_old_getattribute.items():
  1658. k.__getattribute__ = old_getattr # type: ignore[method-assign, attr-defined]
  1659. @contextmanager
  1660. def _maybe_restore_grad_state():
  1661. """
  1662. When pre-dispatch export accidentally change grad state, we restore it back.
  1663. This can happen when we are calling torch._C._set_grad_enabled directly in the
  1664. forward.
  1665. """
  1666. old_state = torch.is_grad_enabled()
  1667. try:
  1668. yield
  1669. finally:
  1670. torch._C._set_grad_enabled(old_state)
  1671. with (
  1672. ctx,
  1673. override_getattribute_for_subclasses(flat_args),
  1674. _maybe_restore_grad_state(),
  1675. ):
  1676. gm = make_fx(
  1677. wrapped_fn,
  1678. record_module_stack=True,
  1679. pre_dispatch=True,
  1680. )(*flat_args)
  1681. if non_strict_root is not None:
  1682. input_names = _graph_input_names(gm)
  1683. buffer_input_names = {
  1684. name: input_names[param_len + i]
  1685. for i, (name, buf) in enumerate(non_strict_root._buffers.items())
  1686. if buf is not None
  1687. }
  1688. output_node = list(gm.graph.nodes)[-1]
  1689. # We copy nodes corresponding to buffer assignments to buffers in the graph.
  1690. for buf, name in assigned_buffers.items(): # type: ignore[possibly-undefined]
  1691. buf_node = _find_node(gm, buffer_input_names[buf])
  1692. name_node = _find_node(gm, name)
  1693. with gm.graph.inserting_before(output_node):
  1694. new_node = gm.graph.create_node(
  1695. "call_function",
  1696. torch.ops.aten.copy_.default,
  1697. args=(buf_node, name_node),
  1698. )
  1699. new_node.meta = name_node.meta
  1700. hook.remove() # type: ignore[possibly-undefined]
  1701. def _is_impure(node):
  1702. if node.op == "call_function" and node.target in (
  1703. # In export, we ignore any op that is related to
  1704. # eager mode profiling call. The expectation is
  1705. # that either runtimes provide their own profiling
  1706. # OR user wrap the compiled region on a profiling in
  1707. # later stage.
  1708. torch.ops.profiler._record_function_enter.default,
  1709. torch.ops.profiler._record_function_enter_new.default,
  1710. torch.ops.profiler._record_function_exit._RecordFunction,
  1711. # In theory, we could fix this dead detach and getattr nodes
  1712. # from subclass tensors if we carefully rewrite track_tensor_tree
  1713. # in a way that it doesn't do any tensor methods.
  1714. torch.ops.aten.detach.default,
  1715. torch.ops.export.access_subclass_inner_tensor.default,
  1716. ):
  1717. return False
  1718. return True
  1719. gm.graph.eliminate_dead_code(_is_impure)
  1720. # create graph signature
  1721. if out_spec.spec is None:
  1722. raise AssertionError("out_spec.spec is None!")
  1723. input_names = _graph_input_names(gm)
  1724. output_names = _graph_output_names(gm)
  1725. sig = GraphSignature(
  1726. parameters=list(named_parameters),
  1727. buffers=list(named_buffers),
  1728. # pyrefly: ignore[bad-argument-type]
  1729. user_inputs=input_names[params_len:],
  1730. user_outputs=output_names,
  1731. # pyrefly: ignore[no-matching-overload]
  1732. inputs_to_parameters=dict(zip(input_names[0:param_len], named_parameters)),
  1733. # pyrefly: ignore[no-matching-overload]
  1734. inputs_to_buffers=dict(
  1735. zip(input_names[param_len : param_len + buffer_len], named_buffers)
  1736. ),
  1737. buffers_to_mutate={},
  1738. parameters_to_mutate={},
  1739. user_inputs_to_mutate={},
  1740. in_spec=in_spec,
  1741. out_spec=out_spec.spec,
  1742. backward_signature=None,
  1743. input_tokens=[],
  1744. output_tokens=[],
  1745. )
  1746. return gm, sig
  1747. # This _reparameterize_module makes sure inputs and module.params/buffers have the same fake_mode,
  1748. # otherwise aot_export_module will error out because it sees a mix of fake_modes.
  1749. # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
  1750. with ExitStack() as stack:
  1751. stack.enter_context(
  1752. torch.nn.utils.stateless._reparametrize_module(
  1753. mod,
  1754. fake_params_buffers,
  1755. tie_weights=True,
  1756. strict=True,
  1757. stack_weights=True,
  1758. )
  1759. )
  1760. stack.enter_context(_ignore_backend_decomps())
  1761. stack.enter_context(_compiling_state_context())
  1762. gm, graph_signature = transform(_make_fx_helper)(
  1763. stack,
  1764. mod,
  1765. fake_args,
  1766. trace_joint=False,
  1767. kwargs=fake_kwargs,
  1768. )
  1769. # [NOTE] In training IR, we don't run
  1770. # any DCE as a result we preserve constant
  1771. # nodes in the graph. make_fx invariant is that
  1772. # they don't guarantee every node gets a meta['val']
  1773. # field. Since the actual value is already hardcoded in
  1774. # graph, the node.meta here actually doesn't matter. But
  1775. # we do this to make spec verifier happy.
  1776. for node in gm.graph.nodes:
  1777. if (
  1778. node.op == "call_function"
  1779. and len(node.users) == 0
  1780. and "val" not in node.meta
  1781. ):
  1782. node.meta["val"] = None
  1783. if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
  1784. gm.meta.update(mod.meta)
  1785. # See comment in _export_to_aten_ir()
  1786. if produce_guards_callback:
  1787. try:
  1788. produce_guards_callback(gm)
  1789. except (ConstraintViolationError, ValueRangeError) as e:
  1790. raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904
  1791. return _produce_aten_artifact(
  1792. gm=gm,
  1793. mod=mod,
  1794. constant_attrs=constant_attrs,
  1795. graph_signature=graph_signature,
  1796. pre_dispatch=True,
  1797. fake_args=fake_args,
  1798. fake_kwargs=fake_kwargs,
  1799. fake_params_buffers=fake_params_buffers,
  1800. )
  1801. def set_missing_meta_vals(gm, flat_args, num_params_buffers):
  1802. # Sets missing metadata to address two problems:
  1803. # 1. aot_export adds symint metadata for placeholders with int values; since
  1804. # these become specialized, we replace such metadata with the original values.
  1805. # 2. any tensor attributes that are not params / buffers, i.e., are constants
  1806. # need to have their metadata set before lifting them because it is needed
  1807. # for computing the exported program's signature.
  1808. index = 0
  1809. for node in gm.graph.nodes:
  1810. if node.op == "placeholder":
  1811. if index >= num_params_buffers:
  1812. user_arg = flat_args[index - num_params_buffers]
  1813. if not isinstance(user_arg, torch.Tensor):
  1814. node.meta["val"] = user_arg
  1815. index += 1
  1816. def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node:
  1817. return next(iter(node for node in gm.graph.nodes if node.name == name))
def _non_strict_export(
    mod: torch.nn.Module,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None,
    preserve_module_call_signature: tuple[str, ...],
    orig_in_spec: TreeSpec,
    prefer_deferred_runtime_asserts_over_guards: bool,
    _to_aten_func: Callable,
) -> ExportArtifact:
    """
    Export ``mod`` without the strict front end: fakify the inputs and the
    module's params/buffers, lower with ``_to_aten_func`` through a wrapper
    module that flattens outputs to a tuple, and return an ``ExportArtifact``.

    _to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir`

    Args:
        mod: Module to export.
        args: Example positional inputs, passed to ``make_fake_inputs``.
        kwargs: Example keyword inputs, passed to ``make_fake_inputs``.
        dynamic_shapes: Dynamic shape specification, passed to
            ``make_fake_inputs``; the value it returns is used afterwards.
        preserve_module_call_signature: Submodule names; each is prefixed with
            ``"_export_root."`` and handed to ``_wrap_submodules`` when ``mod``
            is not a ``torch.fx.GraphModule``.
        orig_in_spec: Not read in this function; the returned ``in_spec`` is the
            one recorded by the wrapper's ``forward``.
        prefer_deferred_runtime_asserts_over_guards: Forwarded to
            ``make_fake_inputs``.
        _to_aten_func: Lowering function; called with the patched module, the
            fake inputs, the fake params/buffers, the fake constant attrs, a
            guard-producing callback and the output-tuplifying transform.

    Returns:
        An ``ExportArtifact`` with the ATen artifact, the in/out specs recorded
        while the wrapper ran, the fake mode and the collected module call specs.

    Raises:
        AssertionError: If ``out_spec`` or ``in_spec`` is still None after
            lowering (they are only assigned inside the wrapper's ``forward``).
    """
    # Both specs are filled in by Wrapper.forward via `nonlocal`.
    out_spec: TreeSpec | None = None
    in_spec: TreeSpec | None = None

    module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = {}

    def _tuplify_outputs(aot_export):
        # Decorates `aot_export` so that it traces a Wrapper around the module
        # (flat tuple outputs) and then strips the wrapper's root prefix from
        # the resulting signature and nn_module_stack metadata.
        def _aot_export_non_strict(stack, mod, args, *, kwargs=None, **flags):
            kwargs = kwargs or {}

            class Wrapper(torch.nn.Module):
                def __init__(self, mod):
                    super().__init__()
                    self._export_root = mod

                def forward(self, *args, **kwargs):
                    nonlocal out_spec
                    nonlocal in_spec
                    mod = self._export_root
                    _, in_spec = pytree.tree_flatten((args, kwargs))
                    if isinstance(mod, torch.fx.GraphModule):
                        # NOTE: We're going to run this graph module with an fx interpreter,
                        # which will not run any forward hooks. Thus, ideally, we should run
                        # all forward hooks here. But the general logic for running them is
                        # complicated (see nn/module.py), and probably not worth duplicating.
                        # Instead we only look for, and run, an export-specific forward hook.
                        if (
                            _check_input_constraints_pre_hook
                            in mod._forward_pre_hooks.values()
                        ):
                            _check_input_constraints_pre_hook(mod, args, kwargs)
                        with torch.fx.traceback.preserve_node_meta():
                            # The interpreter is run with positional inputs only,
                            # so kwarg values are appended in their dict order.
                            args = (*args, *kwargs.values())
                            tree_out = torch.fx.Interpreter(mod).run(*args)
                    else:
                        tree_out = mod(*args, **kwargs)
                    flat_outs, out_spec = pytree.tree_flatten(tree_out)
                    return tuple(flat_outs)

            wrapped_mod = Wrapper(mod)
            # Patch export_root to the signatures so that wrapper module correctly populates the
            # in/out spec
            new_preserved_call_signatures = [
                "_export_root." + i for i in preserve_module_call_signature
            ]
            ctx = nullcontext()
            if not isinstance(mod, torch.fx.GraphModule):
                ctx = _wrap_submodules(  # type: ignore[assignment]
                    wrapped_mod, new_preserved_call_signatures, module_call_specs
                )
            with ctx:
                gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags)
                log.debug("Exported program from AOTAutograd:\n%s", gm)

            # Apply _strip_root to every name-bearing field of the signature.
            sig.parameters = pytree.tree_map(_strip_root, sig.parameters)
            sig.buffers = pytree.tree_map(_strip_root, sig.buffers)
            sig.inputs_to_buffers = pytree.tree_map(_strip_root, sig.inputs_to_buffers)
            sig.inputs_to_parameters = pytree.tree_map(
                _strip_root, sig.inputs_to_parameters
            )
            sig.buffers_to_mutate = pytree.tree_map(_strip_root, sig.buffers_to_mutate)
            sig.parameters_to_mutate = pytree.tree_map(
                _strip_root, sig.parameters_to_mutate
            )

            for node in gm.graph.nodes:
                if "nn_module_stack" in node.meta:
                    nn_module_stack = node.meta["nn_module_stack"]
                    node.meta["nn_module_stack"] = {
                        _fixup_key(key): val
                        for key, val in pytree.tree_map(
                            _strip_root, nn_module_stack
                        ).items()
                    }

            return gm, sig

        return _aot_export_non_strict

    # NOTE: We need to enter _compiling_state_context() here so that FakeTensors
    # created for params/buffers are properly tracked for leak detection.
    # See detect_non_strict_fake_tensor_leaks config.
    # We only enter the context if leak detection is enabled to avoid changing
    # behavior when the config is OFF.
    _fakify_ctx = (
        _compiling_state_context()
        if torch._export.config.detect_non_strict_fake_tensor_leaks
        else nullcontext()
    )
    with _fakify_ctx:
        (
            fake_mode,
            fake_args,
            fake_kwargs,
            equalities_inputs,
            original_signature,
            dynamic_shapes,
        ) = make_fake_inputs(
            mod,
            args,
            kwargs,
            dynamic_shapes,
            prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,  # for shape env initialization
        )

        fake_params_buffers = _fakify_params_buffers(fake_mode, mod)

    def _produce_guards_callback(gm):
        # Passed to _to_aten_func as `produce_guards_callback`.
        return produce_guards_and_solve_constraints(
            fake_mode=fake_mode,
            gm=gm,
            dynamic_shapes=dynamic_shapes,
            equalities_inputs=equalities_inputs,
            original_signature=original_signature,
        )

    tx = TracingContext(fake_mode)

    # We also need to attach dynamo configs as these will be used in HOOs that
    # use torch.compile, like cond
    dynamo_config = dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)
    dynamo_config["do_not_emit_runtime_asserts"] = (
        False  # We want to emit runtime asserts
    )

    with (
        fake_mode,
        _NonStrictTorchFunctionHandler(),
        tracing(tx),
        torch._dynamo.config.patch(dynamo_config),
    ):
        with (
            _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as (
                patched_mod,
                new_fake_args,
                new_fake_kwargs,
                new_fake_constant_attrs,
                map_fake_to_real,
            ),
            _fakify_module_inputs(fake_args, fake_kwargs, fake_mode),
            _override_builtin_ops(),
        ):
            # _to_aten_func is _export_to_aten_ir when using the default non-strict export
            # We need to pass positional args correctly
            aten_export_artifact = _to_aten_func(
                patched_mod,
                new_fake_args,
                new_fake_kwargs,
                fake_params_buffers,
                new_fake_constant_attrs,
                produce_guards_callback=_produce_guards_callback,
                transform=_tuplify_outputs,
            )
            # aten_export_artifact.constants contains only fake script objects, we need to map them back
            aten_export_artifact.constants = {
                fqn: map_fake_to_real[obj] if isinstance(obj, FakeScriptObject) else obj
                for fqn, obj in aten_export_artifact.constants.items()
            }

    _move_non_persistent_buffers_to_tensor_constants(
        mod, aten_export_artifact.sig, aten_export_artifact.constants
    )

    if out_spec is None:
        raise AssertionError("out_spec must not be None")
    if in_spec is None:
        raise AssertionError("in_spec must not be None")

    return ExportArtifact(
        aten=aten_export_artifact,
        in_spec=in_spec,
        out_spec=out_spec,
        fake_mode=fake_mode,
        module_call_specs=module_call_specs,
    )
  1987. @_log_export_wrapper
  1988. @_disable_prexisiting_fake_mode
  1989. def _export_for_training(
  1990. mod: torch.nn.Module,
  1991. args: tuple[Any, ...],
  1992. kwargs: dict[str, Any] | None = None,
  1993. dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None = None,
  1994. *,
  1995. strict: bool = True,
  1996. preserve_module_call_signature: tuple[str, ...] = (),
  1997. prefer_deferred_runtime_asserts_over_guards: bool = False,
  1998. ) -> ExportedProgram:
  1999. global _EXPORT_MODULE_HIERARCHY
  2000. _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
  2001. (
  2002. args,
  2003. kwargs,
  2004. orig_in_spec,
  2005. dynamic_shapes,
  2006. verify_additional_inputs,
  2007. ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes)
  2008. original_state_dict = _get_original_state_dict(mod)
  2009. has_ambient_mode = False
  2010. if not strict:
  2011. flat_args, _ = pytree.tree_flatten((args, kwargs))
  2012. has_ambient_mode = torch._guards.detect_fake_mode(flat_args) is not None
  2013. # Call the appropriate export function based on the strictness of tracing.
  2014. export_func = _strict_export if strict else _non_strict_export
  2015. if not strict and torch._export.config.detect_non_strict_fake_tensor_leaks:
  2016. from torch._subclasses.fake_tensor import fake_tensor_tls
  2017. fake_tensor_tls.non_strict_export_fake_tensor_tracker.clear()
  2018. export_artifact = export_func(
  2019. mod=mod,
  2020. args=args,
  2021. kwargs=kwargs,
  2022. dynamic_shapes=dynamic_shapes,
  2023. preserve_module_call_signature=preserve_module_call_signature,
  2024. orig_in_spec=orig_in_spec,
  2025. prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
  2026. _to_aten_func=_export_to_aten_ir_make_fx,
  2027. )
  2028. # If we are tracing with fake inputs, it is expected to
  2029. # see fake tensor constants.
  2030. if not strict and not has_ambient_mode:
  2031. for const, val in export_artifact.aten.constants.items():
  2032. if isinstance(
  2033. val, torch._subclasses.fake_tensor.FakeTensor
  2034. ) and _is_bogus_const_name(const):
  2035. error_msg = (
  2036. f"We found a fake tensor in the exported program constant's list. "
  2037. f"This typically means our tracing system encountered an op that "
  2038. f"we can't trace through. For the potential source, you can refer to "
  2039. f"following model attribute: {const}. "
  2040. f"Please file an issue on github. "
  2041. )
  2042. if torch._export.config.error_on_lifted_constant_tensors:
  2043. raise RuntimeError(error_msg)
  2044. else:
  2045. warnings.warn(error_msg, stacklevel=2)
  2046. export_graph_signature = export_artifact.aten.sig
  2047. forward_arg_names = _get_forward_arg_names(mod, args, kwargs)
  2048. inline_constraints = _get_inline_constraints(export_artifact.fake_mode)
  2049. # The unbacked symint symbols are updated in aot_export
  2050. # so we serialize them here instead of inside dynamo.
  2051. # Note: _get_range_constraints depends on "inline_constraints" to be set.
  2052. export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints
  2053. range_constraints = _get_range_constraints(
  2054. mod,
  2055. export_artifact,
  2056. args,
  2057. kwargs,
  2058. dynamic_shapes,
  2059. )
  2060. # The returned the gm is in-place modified
  2061. gm, module_call_graph = _get_module_call_graph(
  2062. export_artifact,
  2063. preserve_module_call_signature,
  2064. strict,
  2065. forward_arg_names,
  2066. )
  2067. _verify_nn_module_stack(gm)
  2068. _verify_stack_trace(gm)
  2069. _verify_placeholder_names(gm, export_graph_signature)
  2070. _update_gm_meta_if_possible(gm, mod)
  2071. from torch._export.verifier import TrainingIRVerifier
  2072. exported_program = ExportedProgram(
  2073. root=gm,
  2074. graph=gm.graph,
  2075. graph_signature=export_graph_signature,
  2076. state_dict=original_state_dict,
  2077. range_constraints=range_constraints,
  2078. module_call_graph=module_call_graph,
  2079. example_inputs=(args, kwargs),
  2080. constants=export_artifact.aten.constants,
  2081. verifiers=[TrainingIRVerifier],
  2082. )
  2083. verify_additional_inputs(exported_program)
  2084. if not strict and torch._export.config.detect_non_strict_fake_tensor_leaks:
  2085. # See NOTE [export non-strict fake tensor leak detection]
  2086. from torch._subclasses.fake_tensor import fake_tensor_tls
  2087. from torch.fx.experimental.proxy_tensor import (
  2088. _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT,
  2089. )
  2090. active_fakes = fake_tensor_tls.non_strict_export_fake_tensor_tracker
  2091. legit_leak: weakref.WeakSet = find_legit_leaks_from_referrers(active_fakes)
  2092. leak_sources: list[str] = []
  2093. if len(legit_leak) > 0:
  2094. for fake_val in legit_leak:
  2095. if id(fake_val) in _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT:
  2096. node = _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT[id(fake_val)]
  2097. stack_trace = node.meta.get("stack_trace")
  2098. node_name = node.name
  2099. # If no stack trace on this node (e.g., placeholder), look at users
  2100. if stack_trace is None:
  2101. for user in node.users:
  2102. user_stack = user.meta.get("stack_trace")
  2103. if user_stack is not None:
  2104. stack_trace = f"Used by '{user.name}':\n{user_stack}"
  2105. break
  2106. stack_trace = (
  2107. "<no stack trace available>"
  2108. if stack_trace is None
  2109. else stack_trace
  2110. )
  2111. # Get shape and dtype info
  2112. shape_info = f"shape={fake_val.shape}, dtype={fake_val.dtype}"
  2113. leak_info = f"FakeTensor({shape_info}) from node '{node_name}':\n{stack_trace}"
  2114. leak_sources.append(leak_info)
  2115. else:
  2116. # Fallback: no proxy mapping found, show basic info
  2117. shape_info = f"shape={fake_val.shape}, dtype={fake_val.dtype}"
  2118. leak_info = f"FakeTensor({shape_info}): <no proxy mapping found>"
  2119. leak_sources.append(leak_info)
  2120. # Format the warning message more nicely
  2121. leak_details = "\n ".join(leak_sources)
  2122. warnings.warn(
  2123. f"Detected {len(legit_leak)} fake tensors that are still alive after export.\n"
  2124. f"This is likely result of torch.export.export not being able to track side effects "
  2125. f"that is happening outside of model scope.\n\n"
  2126. f"Leaked tensors:\n {leak_details}\n\n"
  2127. f"Alternatively, please file a bug report to PyTorch team for further debugging help.",
  2128. stacklevel=2,
  2129. )
  2130. del legit_leak
  2131. return exported_program
  2132. @_log_export_wrapper
  2133. @_disable_prexisiting_fake_mode
  2134. @compile_time_strobelight_meta(phase_name="export")
  2135. def _export(
  2136. mod: torch.nn.Module,
  2137. args: tuple[Any, ...],
  2138. kwargs: dict[str, Any] | None = None,
  2139. dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None = None,
  2140. *,
  2141. strict: bool = True,
  2142. preserve_module_call_signature: tuple[str, ...] = (),
  2143. pre_dispatch: bool = False,
  2144. prefer_deferred_runtime_asserts_over_guards: bool = False,
  2145. ) -> ExportedProgram:
  2146. """
  2147. Traces either an nn.Module's forward function or just a callable with PyTorch
  2148. operations inside and produce a ExportedProgram.
  2149. Args:
  2150. mod: the `nn.Module` to trace.
  2151. args: example positional inputs.
  2152. kwargs: optional example keyword inputs.
  2153. dynamic_shapes:
  2154. An optional argument where the type should either be:
  2155. 1) a dict from argument names of ``f`` to their dynamic shape specifications,
  2156. 2) a tuple that specifies dynamic shape specifications for each input in original order.
  2157. If you are specifying dynamism on keyword args, you will need to pass them in the order that
  2158. is defined in the original function signature.
  2159. The dynamic shape of a tensor argument can be specified as either
  2160. (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
  2161. not required to include static dimension indices in this dict, but when they are,
  2162. they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
  2163. where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
  2164. are denoted by None. Arguments that are dicts or tuples / lists of tensors are
  2165. recursively specified by using mappings or sequences of contained specifications.
  2166. preserve_module_call_signature: A list of submodule paths for which the original
  2167. calling conventions are preserved as metadata.
  2168. prefer_deferred_runtime_asserts_over_guards:
  2169. With the current dynamic shapes language for dims and derived dims, we can run into constraints
  2170. that are not expressible with the language. For example, flattening a matrix and adding to a vector,
  2171. both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible.
  2172. By default, we either raise a constraint violation error or specialize to static values.
  2173. If this flag is set to True, we avoid erroring out and instead allow complex constraints to exist as runtime
  2174. assertions in the graph. The sympy interpreter (torch/utils/_sympy/interp.py) will produce the math ops
  2175. required to compute and assert the value of the guard (e.g. sym_size_int, eq, _assert_scalar).
  2176. Additionally, if TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 is specified, we will allow complex constraints
  2177. while not emitting runtime asserts, returning a cleaner graph with lesser guarantees around dynamic shapes.
  2178. Returns:
  2179. An ExportedProgram containing the traced module.
  2180. """
  2181. from torch._utils_internal import export_training_ir_rollout_check
  2182. global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY
  2183. _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
  2184. flags = set()
  2185. flags.add("strict" if strict else "non_strict")
  2186. flags.add("pre_dispatch" if pre_dispatch else "aot_dispatch")
  2187. _EXPORT_FLAGS = flags
  2188. log_export_usage(event="export.enter", flags=_EXPORT_FLAGS)
  2189. dtrace_structured("export", payload_fn=lambda: "start!")
  2190. # NOTE Export training IR rollout
  2191. # Old export calls export._trace(pre_dispatch=True)
  2192. # and there are still lot of internal/OSS callsites that
  2193. # use export._trace(pre_dispatch=True) directly. Therefore,
  2194. # it makes more sense to do the switch here.
  2195. # export_training_ir_rollout_check returns True in OSS
  2196. # while internally it returns False UNLESS otherwise specified.
  2197. if pre_dispatch and export_training_ir_rollout_check():
  2198. ep = _export_for_training(
  2199. mod,
  2200. args,
  2201. kwargs,
  2202. dynamic_shapes,
  2203. strict=strict,
  2204. preserve_module_call_signature=preserve_module_call_signature,
  2205. prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
  2206. )
  2207. dtrace_structured("exported_program", payload_fn=lambda: str(ep))
  2208. return ep
  2209. (
  2210. args,
  2211. kwargs,
  2212. original_in_spec,
  2213. dynamic_shapes,
  2214. verify_additional_inputs,
  2215. ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes)
  2216. original_state_dict = _get_original_state_dict(mod)
  2217. # Call the appropriate export function based on the strictness of tracing.
  2218. export_func = _strict_export if strict else _non_strict_export
  2219. export_artifact = export_func( # type: ignore[operator]
  2220. mod=mod,
  2221. args=args,
  2222. kwargs=kwargs,
  2223. dynamic_shapes=dynamic_shapes,
  2224. preserve_module_call_signature=preserve_module_call_signature,
  2225. orig_in_spec=original_in_spec,
  2226. prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
  2227. _to_aten_func=functools.partial(
  2228. _export_to_aten_ir,
  2229. pre_dispatch=pre_dispatch,
  2230. ),
  2231. )
  2232. export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
  2233. forward_arg_names = _get_forward_arg_names(mod, args, kwargs)
  2234. inline_constraints = _get_inline_constraints(export_artifact.fake_mode)
  2235. # The unbacked symint symbols are updated in aot_export
  2236. # so we serialize them here instead of inside dynamo.
  2237. # Note: this step must be before _get_range_constraints.
  2238. export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints
  2239. range_constraints = _get_range_constraints(
  2240. mod,
  2241. export_artifact,
  2242. args,
  2243. kwargs,
  2244. dynamic_shapes,
  2245. )
  2246. gm, module_call_graph = _get_module_call_graph(
  2247. export_artifact,
  2248. preserve_module_call_signature,
  2249. strict,
  2250. forward_arg_names,
  2251. )
  2252. _verify_nn_module_stack(gm)
  2253. _verify_stack_trace(gm)
  2254. _verify_placeholder_names(gm, export_graph_signature)
  2255. # Remove Proxy because they cannot be deepcopied or pickled.
  2256. torch._export.utils.remove_proxy_from_state_dict(original_state_dict, in_place=True)
  2257. from torch._export.verifier import Verifier
  2258. _update_gm_meta_if_possible(gm, mod)
  2259. exported_program = ExportedProgram(
  2260. root=gm,
  2261. graph=gm.graph,
  2262. graph_signature=export_graph_signature,
  2263. state_dict=original_state_dict,
  2264. range_constraints=range_constraints,
  2265. module_call_graph=module_call_graph,
  2266. example_inputs=(args, kwargs),
  2267. constants=export_artifact.aten.constants,
  2268. verifiers=[Verifier],
  2269. )
  2270. dtrace_structured("exported_program", payload_fn=lambda: str(exported_program))
  2271. verify_additional_inputs(exported_program)
  2272. return exported_program