| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594 |
- # mypy: allow-untyped-decorators
- # mypy: allow-untyped-defs
- import dataclasses
- import functools
- import inspect
- import logging
- import re
- import sys
- import time
- import warnings
- from collections.abc import Callable
- from contextlib import contextmanager, ExitStack, nullcontext
- from itertools import chain
- from typing import Any, TYPE_CHECKING, TypeAlias
- from unittest import mock
- if TYPE_CHECKING:
- import weakref
- import torch
- import torch._dynamo
- import torch.fx
- import torch.utils._pytree as pytree
- from torch._dispatch.python import enable_python_dispatcher
- from torch._dynamo.exc import UserError, UserErrorType
- from torch._export.db.logging import (
- exportdb_error_message,
- get_class_if_classified_error,
- )
- from torch._export.non_strict_utils import (
- _fakify_module_inputs,
- _fakify_script_objects,
- _gather_constant_attrs,
- _NonStrictTorchFunctionHandler,
- _override_builtin_ops,
- make_constraints,
- make_fake_inputs,
- produce_guards_and_solve_constraints,
- )
- from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
- from torch._export.passes.lift_constants_pass import (
- _materialize_and_lift_constants,
- ConstantAttrMap,
- )
- from torch._export.utils import (
- _collect_param_buffer_metadata,
- _compiling_state_context,
- _fakify_params_buffers,
- _populate_param_buffer_metadata_to_new_gm,
- _update_gm_meta_if_possible,
- apply_runtime_assertion_pass,
- placeholder_naming_pass,
- placeholder_prefixes,
- )
- from torch._export.verifier import SpecViolationError
- from torch._export.wrappers import _wrap_submodules
- from torch._functorch._aot_autograd.graph_capture_wrappers import create_functional_call
- from torch._functorch._aot_autograd.input_output_analysis import (
- _graph_input_names,
- _graph_output_names,
- )
- from torch._functorch._aot_autograd.schemas import GraphSignature
- from torch._functorch._aot_autograd.subclass_utils import get_subclass_typing_container
- from torch._functorch._aot_autograd.utils import (
- create_tree_flattened_fn,
- register_buffer_assignment_hook,
- )
- from torch._functorch.aot_autograd import (
- _detect_attribute_assignment,
- aot_export_joint_with_descriptors,
- )
- from torch._guards import detect_fake_mode, tracing, TracingContext
- from torch._library.fake_class_registry import FakeScriptObject
- from torch._logging import dtrace_structured
- from torch._subclasses.fake_tensor import FakeTensorMode
- from torch._utils_internal import compile_time_strobelight_meta, log_export_usage
- from torch.export._leakage_detection_utils import find_legit_leaks_from_referrers
- from torch.export._unlift import _check_input_constraints_pre_hook
- from torch.export.dynamic_shapes import (
- _check_dynamic_shapes,
- _combine_args,
- _DimHintType,
- _IntWrapper,
- _process_dynamic_shapes,
- )
- from torch.export.exported_program import OutputKind
- from torch.fx._symbolic_trace import _ConstantAttributeType
- from torch.fx.experimental.proxy_tensor import (
- get_proxy_slot,
- make_fx,
- PreDispatchTorchFunctionMode,
- track_tensor_tree,
- )
- from torch.fx.experimental.symbolic_shapes import (
- ConstraintViolationError,
- free_unbacked_symbols,
- GuardOnDataDependentSymNode,
- ShapeEnv,
- )
- from torch.fx.graph import _PyTreeInfo
- from torch.utils._pytree import TreeSpec
- from torch.utils._sympy.value_ranges import ValueRangeError
- from .exported_program import (
- _disable_prexisiting_fake_mode,
- ExportedProgram,
- InputKind,
- ModuleCallEntry,
- ModuleCallSignature,
- )
- from .graph_signature import _convert_to_export_graph_signature, ExportGraphSignature
- log = logging.getLogger(__name__)
- # Type alias for dynamic shapes specification
- _DynamicShapesSpec: TypeAlias = dict[str, Any] | tuple[Any, ...] | list[Any]
@dataclasses.dataclass
class ExportDynamoConfig:
    """
    Manage Export-specific configurations of Dynamo.

    A plain dataclass of settings. ``DEFAULT_EXPORT_DYNAMO_CONFIG`` in this
    module is an instance whose ``reorderable_logging_functions`` is populated
    with the standard logging functions, ``print`` and ``warnings.warn``.
    Field order defines the positional order of the generated ``__init__``,
    so keep it stable.
    """

    allow_rnn: bool = True
    # Set of logging-like callables (presumably calls Dynamo may reorder —
    # confirm against the Dynamo config). Each instance gets its own empty
    # set via default_factory, avoiding a shared mutable default.
    reorderable_logging_functions: set[Callable] = dataclasses.field(
        default_factory=set
    )
    # Emit runtime asserts after AOTAutograd instead.
    # This isn't really necessary, and isn't much more efficient since the runtime asserts pass does CSE,
    # but if we want to reason more about what guards/runtime asserts to emit,
    # this makes it a bit cleaner to do from the export side. Also no real point in running this twice.
    do_not_emit_runtime_asserts: bool = True
    specialize_int: bool = True
    specialize_float: bool = True
    assume_static_by_default: bool = False
    automatic_dynamic_shapes: bool = False
    capture_dynamic_output_shape_ops: bool = True
    capture_scalar_outputs: bool = True
    prefer_deferred_runtime_asserts_over_guards: bool = False
    replay_side_effects: bool = False
    # NOTE(review): the accepted values for this policy string are not
    # visible here; "warn" is only the default.
    side_effect_replay_policy: str = "warn"
@dataclasses.dataclass
class ATenExportArtifact:
    """
    Result of lowering a module to an ATen-level graph for export.

    Built by ``_produce_aten_artifact`` from the final graph module, its export
    graph signature, the lifted constants table and the output spec inferred
    during tracing.
    """

    # The ATen-level graph module.
    gm: torch.fx.GraphModule
    # Export graph signature describing the graph's inputs and outputs.
    sig: ExportGraphSignature
    # Lifted constants, keyed by their target name.
    constants: dict[str, _ConstantAttributeType]
    # Output pytree spec taken from the AOT graph signature's ``out_spec``.
    inferred_out_spec: TreeSpec
@dataclasses.dataclass(frozen=True)
class ExportArtifact:
    """
    Frozen bundle of export tracing results (fields cannot be reassigned):
    the ATen-level artifact plus the input/output pytree specs, the fake
    mode used for tracing, and the per-module call specs.
    """

    # ATen-level graph module, signature and lifted constants.
    aten: ATenExportArtifact
    # Pytree spec of the example inputs (presumably the (args, kwargs) pair —
    # confirm at the construction site).
    in_spec: TreeSpec
    # Pytree spec of the outputs.
    out_spec: TreeSpec
    # FakeTensorMode used while tracing.
    fake_mode: FakeTensorMode
    # Per-module call specs, presumably keyed by module FQN; the inner keys
    # are not visible here (presumably "in_spec"/"out_spec" — confirm).
    module_call_specs: dict[str, dict[str, pytree.TreeSpec]]
# Shared default Dynamo configuration for export. It equals a freshly built
# ExportDynamoConfig except that the standard logging functions, print and
# warnings.warn are registered as reorderable logging functions.
DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig(
    reorderable_logging_functions={
        logging.critical,
        logging.debug,
        logging.error,
        logging.exception,
        logging.info,
        logging.log,
        logging.warning,
        print,
        warnings.warn,
    },
)
- @contextmanager
- def _ignore_backend_decomps():
- orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False)
- orig_nnpack_flag = torch.backends.nnpack.set_flags(False)
- orig_cudnn_flag = torch.backends.cudnn.set_flags(False)
- try:
- yield
- finally:
- torch.backends.mkldnn.set_flags(*orig_mkldnn_flag)
- torch.backends.nnpack.set_flags(*orig_nnpack_flag)
- torch.backends.cudnn.set_flags(*orig_cudnn_flag)
@contextmanager
def _disable_custom_triton_op_functional_decomposition():
    """
    Temporarily set ``torch._functorch.config.decompose_custom_triton_ops``
    to ``False``.

    Yields the (now disabled) config value. The previous value is restored
    on exit, even if the body raises.
    """
    functorch_config = torch._functorch.config
    saved_value = functorch_config.decompose_custom_triton_ops
    try:
        # pyrefly: ignore [bad-assignment]
        functorch_config.decompose_custom_triton_ops = False
        yield functorch_config.decompose_custom_triton_ops
    finally:
        functorch_config.decompose_custom_triton_ops = saved_value
def custom_triton_ops_decomposition_disabled():
    """Return True when functorch is configured not to decompose custom Triton ops."""
    decompose_enabled = torch._functorch.config.decompose_custom_triton_ops
    return not decompose_enabled
def _fixup_key(x):
    """Return ``"L__self__"`` followed by ``x`` with any ``_export_root`` prefix stripped."""
    stripped = _strip_root(x)
    return "L__self__" + stripped


def _strip_root(x):
    """
    Drop a leading ``_export_root`` from a string path, plus the ``.`` that
    follows it if present.

    Non-string values, and strings without that prefix, are returned
    unchanged.
    """
    if not (isinstance(x, str) and x.startswith("_export_root")):
        return x
    return x.removeprefix("_export_root").removeprefix(".")
def _is_bogus_const_name(name: str) -> bool:
    """
    Return True if the last dotted component of ``name`` starts with
    ``lifted_tensor``.

    A last component always exists (an empty name yields ``""``). The old
    ``len(name.split(".")) < 1`` guard could therefore never fire, and
    ``rpartition`` reads the last component directly.
    """
    return name.rpartition(".")[2].startswith("lifted_tensor")
def _rewrite_tracepoint_node(gm: torch.fx.GraphModule):
    """
    Rewrite export tracepoint nodes of ``gm`` in place so that their ``path``
    kwarg no longer carries the ``_export_root`` prefix.

    Each ``_export_tracepoint`` node that has a ``path`` kwarg is handled as
    follows:
    - A new ``call_function`` node is created with the same target and args,
      the original ``kind`` kwarg, and the stripped ``path``.
    - The new node takes over the old node's meta.
    - All uses are redirected to the new node, and the old node is erased.

    Kwargs other than ``path``/``kind`` are not carried over. The module is
    not recompiled here.
    """
    graph = gm.graph
    for old in graph.nodes:
        if old.target is not torch.ops.higher_order._export_tracepoint:
            continue
        if "path" not in old.kwargs:
            continue
        stripped_path = _strip_root(old.kwargs["path"])
        with graph.inserting_before(old):
            replacement = graph.create_node(
                "call_function",
                torch.ops.higher_order._export_tracepoint,
                args=old.args,
                kwargs={
                    "path": stripped_path,
                    "kind": old.kwargs["kind"],
                },
            )
            replacement.meta = old.meta
            old.replace_all_uses_with(replacement)
            graph.erase_node(old)
- def detect_shape_env(inputs: Any = None):
- shape_envs = []
- for i, flat_input in enumerate(inputs):
- if isinstance(flat_input, torch.SymInt):
- shape_envs.append((flat_input.node.shape_env, "symint input", i))
- if shape_envs:
- shape_env, desc1, i1 = shape_envs[0]
- for m, desc2, i2 in shape_envs[1:]:
- if shape_env is not m:
- raise AssertionError(
- f"shape env ({shape_env}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
- f"shape env from {desc1} {i1} allocated at:\n{shape_env.stack}\n"
- f"shape env from {desc2} {i2} allocated at:\n{m.stack}"
- )
- return shape_env
- else:
- return None
def _extract_fake_inputs(gm, args, kwargs):
    """
    Given a graph module, extract fakified input tensors from the metadata of
    its placeholders, and map them to the structure of given args and kwargs.
    Also return the fake mode used to fakify those inputs.

    Args:
        gm: graph module whose placeholder nodes carry ``meta["val"]``. Other
            nodes' ``meta["example_value"]`` is used only to detect the fake
            mode / shape env. If ``gm`` has a ``_dynamo_bytecode_flatten``
            attribute, it is used to map placeholders back to input leaves.
        args: positional example inputs. On the Dynamo-flattened path they are
            concatenated with a tuple, so a tuple is expected there.
        kwargs: keyword example inputs.

    Returns:
        A ``(fake_args, fake_kwargs, fake_mode)`` triple:
        - ``fake_args``/``fake_kwargs``: ``args``/``kwargs`` with int and
          Tensor leaves replaced by the corresponding fake values.
        - ``fake_mode``: the detected FakeTensorMode, or a newly created one.

    Raises:
        AssertionError: if the detected shape env is not the detected fake
            mode's shape env, or (from ``detect_shape_env``) if SymInt values
            disagree on their shape env.
    """
    fake_inps: list[Any] = []
    fake_vals: list[Any] = []
    # Placeholders supply the fake inputs. Every other node's example value is
    # collected only to help detect the fake mode / shape env below.
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            fake_inps.append(node.meta.get("val"))
        else:
            fake_vals.append(node.meta.get("example_value"))

    if dynamo_bytecode_flatten := getattr(gm, "_dynamo_bytecode_flatten", None):
        # In _extract_fake_inputs, the goal is to make real inputs into
        # fake (and symbolic) inputs. The way currently it's implemented
        # is by looking at the node.meta["val"] of the placeholder nodes.
        # This doesn't work when the graph is Dynamo flattened, because now
        # placeholder nodes don't have the ordering like pytree inputs do.
        # Instead, we need to look at how the inputs are shuffled, and map
        # the inputs to their actual fake inputs and symbolic inputs.
        # Since inputs can also contain symints, we cannot simply use the
        # FakeTensorMode memo to look up tensors only there.
        fake_inps = []
        positions = {}
        idx = 0

        def mark_inputs(x):
            # x can be a tensor or symbolic integer or a normal constant.
            # Record x at leaf position idx and return a sentinel whose id()
            # maps back to that position. Tensors are their own sentinel;
            # everything else gets a fresh object().
            nonlocal idx
            fake_inps.append(x)
            if isinstance(x, torch.Tensor):
                ret = x
            else:
                ret = object()
            # A tensor appearing at several leaves keeps its first position.
            if id(ret) not in positions:
                positions[id(ret)] = idx
            idx += 1
            return ret

        # dummy_args holds the sentinels, which keeps their id()s from being
        # reused while positions is consulted below.
        dummy_args = pytree.tree_map(mark_inputs, args + tuple(kwargs.values()))
        # Replay Dynamo's input shuffling on the sentinels to learn which leaf
        # each placeholder corresponds to.
        shuffled_args = dynamo_bytecode_flatten(*dummy_args)
        for node, shuffled_arg in zip(
            gm.graph.find_nodes(op="placeholder"), shuffled_args
        ):
            if id(shuffled_arg) in positions:
                fake_inps[positions[id(shuffled_arg)]] = node.meta.get("val")

    # We get both because now we might have a combination of symint and tensor
    # inputs, and we want to check that the shape env is consistent between
    # both. We can't see which fake mode a shape env is attached to, so
    # instead we check that the detected shape env is the detected fake
    # mode's shape env.
    detected_fake_mode = detect_fake_mode(fake_inps + fake_vals)
    detected_shape_env = detect_shape_env(fake_inps + fake_vals)

    if detected_fake_mode:
        if detected_shape_env:
            if detected_shape_env is not detected_fake_mode.shape_env:
                raise AssertionError(
                    "Detected shape env does not match fake mode's shape env"
                )
        fake_mode = detected_fake_mode
    elif detected_shape_env:
        # No fake mode found, but SymInt values carry a shape env: wrap it.
        fake_mode = FakeTensorMode(shape_env=detected_shape_env, export=True)
    else:
        # Neither found: start from a fresh ShapeEnv.
        fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True)

    count = 0

    def lookup_fake(x):
        # Replace int/Tensor leaves with the fake value at the same leaf
        # position; other leaves pass through unchanged.
        # NOTE(review): count advances for every leaf, not only int/Tensor
        # leaves. So fake_inps is assumed to have one entry per leaf position.
        # That holds on the flattened path (mark_inputs appends per leaf). On
        # the plain path it assumes placeholders line up with leaves —
        # confirm.
        nonlocal count
        val = fake_inps[count] if isinstance(x, (int, torch.Tensor)) else x
        count += 1
        return val

    fake_args = pytree.tree_map(lookup_fake, args)
    fake_kwargs = pytree.tree_map(lookup_fake, kwargs)

    return fake_args, fake_kwargs, fake_mode
def _replace_param_buffer_names(param_buffer_table, sig):
    """
    Rename, in place, the targets in ``sig`` through ``param_buffer_table``.

    The following specs have ``spec.target`` replaced by
    ``param_buffer_table[spec.target]``:
    - input specs of kind PARAMETER/BUFFER;
    - output specs of kind BUFFER_MUTATION/GRADIENT_TO_PARAMETER.

    A target missing from the table raises ``KeyError``.
    """

    def _retarget(specs, kinds):
        for spec in specs:
            if spec.kind in kinds:
                spec.target = param_buffer_table[spec.target]

    _retarget(sig.input_specs, (InputKind.PARAMETER, InputKind.BUFFER))
    _retarget(
        sig.output_specs,
        (OutputKind.BUFFER_MUTATION, OutputKind.GRADIENT_TO_PARAMETER),
    )
def _convert_to_positional_args(orig_arg_names, args, kwargs):
    """
    Flatten ``args`` and ``kwargs`` into one positional tuple ordered by
    ``orig_arg_names``.

    The first ``len(args)`` names are taken to be covered by ``args``; each
    remaining name is looked up in ``kwargs``.

    Raises:
        AssertionError: if the number of names differs from the total number
            of positional and keyword arguments.
        KeyError: if a remaining name is missing from ``kwargs``.
    """
    n_positional = len(args)
    if len(orig_arg_names) != n_positional + len(kwargs):
        raise AssertionError(
            f"Total number of arg names is expected to be {len(orig_arg_names)} "
            f"but got {len(args)} positional args, {len(kwargs)} kwargs."
        )
    keyword_names = orig_arg_names[n_positional:]
    return tuple(args) + tuple(kwargs[name] for name in keyword_names)
def _normalize_nn_module_stack(gm_torch_level, root_cls):
    """
    Prepend a root-module entry to the ``nn_module_stack`` metadata of nodes,
    and strip the ``L['self']`` prefix from the stack's paths.

    Applies to every node other than placeholders/outputs, in every
    GraphModule under ``gm_torch_level``.

    Nodes whose stack already starts with the root (path ``L['self']`` and
    class ``root_cls``) are left untouched. The added root entry records
    ``root_cls`` as a ``"module.qualname"`` string. A first-entry class that
    is not an ``nn.Module`` subclass must be a string, which can happen after
    deserialization; otherwise an AssertionError is raised.
    """
    root = "L['self']"
    root_key = re.sub(r"[^a-zA-Z0-9]", "_", root)
    self_prefix = root + "."

    def _strip_self(path):
        if path == root:
            return ""
        return path.removeprefix(self_prefix)

    for submod in gm_torch_level.modules():
        if not isinstance(submod, torch.fx.GraphModule):
            continue
        for node in submod.graph.nodes:
            if node.op in ("placeholder", "output"):
                continue
            stack = node.meta.get("nn_module_stack", {})
            if stack:
                first_path, ty = next(iter(stack.values()))
                # After deserializing, the class `ty` might not exist anymore,
                # so it could be a string.
                if inspect.isclass(ty) and issubclass(ty, torch.nn.Module):
                    # TODO Figure out why sometimes we have root sometimes we don't.
                    if first_path == root and ty is root_cls:
                        continue
                elif not isinstance(ty, str):
                    raise AssertionError(f"expected ty to be str, got {type(ty)}")
            with_root = {
                root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__),
                **stack,
            }
            node.meta["nn_module_stack"] = {
                key: (_strip_self(entry_path), entry_ty)
                for key, (entry_path, entry_ty) in with_root.items()
            }
def _get_param_buffer_mapping(
    original_module: torch.nn.Module,
    traced_module: torch.nn.Module,
) -> dict[str, str]:
    """
    Returns a mapping of parameter/buffer names from the new module to the
    original model. This is to help with restoring the FQN for parameter/buffers
    of a traced module to what the original module contains.

    Tensors are matched by object identity. A tied (aliased) parameter or
    buffer maps to the first name under which it appears in
    ``original_module``. This is the same name the default, de-duplicating
    ``named_parameters()``/``named_buffers()`` report. Buffers previously
    kept the last alias instead, which was inconsistent with parameters.

    Raises:
        AssertionError: if a traced-module name occurs twice.
    """
    param_lookup: dict[int, str] = {}
    buffer_lookup: dict[int, str] = {}
    # remove_duplicate=False also yields aliases of tied tensors. setdefault
    # keeps only the first name, to guarantee parity of the original and
    # traced module.
    for name, param in original_module.named_parameters(remove_duplicate=False):
        param_lookup.setdefault(id(param), name)
    for name, buffer in original_module.named_buffers(remove_duplicate=False):
        buffer_lookup.setdefault(id(buffer), name)

    param_buffer_table: dict[str, str] = {}
    for dynamo_name, dynamo_param in traced_module.named_parameters(
        remove_duplicate=False
    ):
        if dynamo_name in param_buffer_table:
            raise AssertionError(
                f"dynamo_name {dynamo_name!r} already exists in param_buffer_table"
            )
        if id(dynamo_param) in param_lookup:
            param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)]

    for dynamo_name, dynamo_buffer in traced_module.named_buffers(
        remove_duplicate=False
    ):
        if dynamo_name in param_buffer_table:
            raise AssertionError(
                f"dynamo_name {dynamo_name!r} already exists in param_buffer_table for buffer"
            )
        if id(dynamo_buffer) in buffer_lookup:
            param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)]

    return param_buffer_table
- def _preserve_requires_grad_pass(
- gm: torch.fx.GraphModule,
- sig: ExportGraphSignature,
- fake_params_buffers: dict[str, torch.Tensor],
- constants: dict[str, _ConstantAttributeType],
- flat_fake_args: list[Any],
- ):
- placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
- if len(sig.input_specs) != len(placeholders):
- raise AssertionError(
- f"input_specs length {len(sig.input_specs)} does not match placeholders length {len(placeholders)}"
- )
- i = 0
- for node, spec in zip(placeholders, sig.input_specs):
- if spec.kind in (
- InputKind.PARAMETER,
- InputKind.BUFFER,
- ):
- if spec.target is None:
- raise AssertionError(
- f"spec.target must not be None for kind {spec.kind}"
- )
- node.meta["val"].requires_grad = fake_params_buffers[
- spec.target
- ].requires_grad
- elif spec.kind == InputKind.USER_INPUT:
- fake_arg = flat_fake_args[i]
- if isinstance(fake_arg, torch.Tensor):
- node.meta["val"].requires_grad = fake_arg.requires_grad
- i += 1
- elif spec.kind == InputKind.CONSTANT_TENSOR:
- if spec.target is None:
- raise AssertionError(
- "spec.target must not be None for CONSTANT_TENSOR kind"
- )
- constant = constants[spec.target]
- if isinstance(constant, torch.Tensor):
- # If the tensor is not leaf, it should already have a correct requires grad field
- if node.meta["val"].is_leaf:
- node.meta["val"].requires_grad = constant.requires_grad
- else:
- if node.meta["val"].requires_grad != constant.requires_grad:
- raise AssertionError(
- f"node requires_grad {node.meta['val'].requires_grad} does not match "
- f"constant requires_grad {constant.requires_grad}"
- )
- elif spec.kind in (InputKind.CUSTOM_OBJ, InputKind.TOKEN):
- continue
- else:
- raise AssertionError(spec.kind)
def _remap_constants(
    orig_constant_attrs: ConstantAttrMap,
    graph_signature: ExportGraphSignature,
    constants: dict[str, _ConstantAttributeType],
) -> None:
    """
    Rewrite the graph signature and constants table to use the FQN from the
    original module.

    For every CONSTANT_TENSOR / CUSTOM_OBJ input spec:
    - ``spec.target`` becomes the first name recorded for that constant in
      ``orig_constant_attrs``, or is left as-is if none is recorded;
    - the constant is re-keyed in ``constants`` under every recorded name.

    Both ``graph_signature`` and ``constants`` are mutated in place.

    Raises:
        AssertionError: if a relevant spec has no target.
    """
    remap_table: dict[str, list[str]] = {
        name: orig_constant_attrs[value]
        for name, value in constants.items()
        if value in orig_constant_attrs
    }

    lifted_kinds = (InputKind.CONSTANT_TENSOR, InputKind.CUSTOM_OBJ)
    for spec in graph_signature.input_specs:
        if spec.kind not in lifted_kinds:
            continue
        orig_target = spec.target
        if orig_target is None:
            raise AssertionError(
                f"spec.target must not be None for kind {spec.kind}"
            )
        targets = remap_table.get(orig_target, [orig_target])
        spec.target = targets[0]
        constant = constants.pop(orig_target)
        for target in targets:
            constants[target] = constant
- def _replace_unbacked_bindings(gm: torch.fx.GraphModule) -> None:
- """
- When we run an interpreter-based pass over a GraphModule, execution of data-dependent operators
- will produce example values with new unbacked symbols. To track that the new/old symbols are equivalent,
- we used to rely on the unbacked_renamings mapping. This led to problematic metadata where the unbacked_bindings
- keys mapped new symbols (u2) to paths containing old symbols (u0) in the example values, or worse, backed symbols
- or constants (e.g. if the original unbacked was replaced/specialized). Additionally this created problems with
- de/serialized programs, since we didn't comprehensively serialize ShapeEnv/unbacked renamings/node bindings.
- This pass attempts a simpler way of handling these for export, by throwing away the previously computed bindings, and re-running
- the pattern match used in compute_unbacked_bindings. This ensures we keep the original symbols contained in the example values,
- or delete bindings if they've been replaced/specialized.
- """
- from torch._export.utils import _get_shape_env_from_gm
- from torch.fx.experimental.symbolic_shapes import _free_unbacked_symbols_with_path
- from torch.utils._sympy.symbol import symbol_is_type, SymT
- if (shape_env := _get_shape_env_from_gm(gm)) is None:
- return
- base_unbacked_symbols = {
- symbol
- for symbol in shape_env.var_to_range
- if symbol_is_type(symbol, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))
- and symbol not in shape_env.unbacked_renamings
- }
- for node in gm.graph.nodes:
- node.meta.pop("unbacked_bindings", None)
- if (val := node.meta.get("val")) is not None and (
- unbacked_bindings := _free_unbacked_symbols_with_path(
- val,
- (),
- shape_env=shape_env,
- pending=base_unbacked_symbols,
- simplify=True,
- )
- ):
- node.meta["unbacked_bindings"] = unbacked_bindings
def _produce_aten_artifact(
    *,
    gm: torch.fx.GraphModule,
    mod,
    constant_attrs,
    graph_signature,
    pre_dispatch,
    fake_args,
    fake_kwargs,
    fake_params_buffers,
    _prettify_placeholder_names=True,
) -> ATenExportArtifact:
    """
    This is a helper function that is shared between export_to_aten_ir and export_to_aten_ir_make_fx
    to produce the aten artifact. (export compatible graph module + signature)

    It does:
    1. Applies runtime assertion pass
    2. Recompute unbacked_bindings pass
    3. Populate meta val when missing
    4. Lift constants as placeholders
    5. Replace raw autograd and autocast ops with HOPs
    6. Prettify names for placeholders
    7. Preserve requires_grad value on node meta val

    Args:
        gm: the traced graph module. The passes may return a new one, and the
            final module is what gets returned.
        mod: the original module. Its non-persistent buffers feed the export
            signature, and it is passed to placeholder naming.
        constant_attrs: passed to ``_materialize_and_lift_constants``
            (presumably a ConstantAttrMap of the original module — confirm).
        graph_signature: AOT graph signature. Its ``parameters``,
            ``buffers``, ``input_tokens`` and ``out_spec`` are read here.
        pre_dispatch: when truthy, also run the set_grad / autocast HOP
            replacement passes.
        fake_args: fake example positional inputs.
        fake_kwargs: fake example keyword inputs. The leaves of
            ``(fake_args, fake_kwargs)`` are the flat user inputs.
        fake_params_buffers: fake parameters/buffers keyed by target name.
        _prettify_placeholder_names: run ``placeholder_naming_pass`` when True.

    Returns:
        ATenExportArtifact holding the final graph module, export signature,
        lifted constants and the AOT signature's ``out_spec``.
    """
    # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature.
    # Overwrite output specs afterwards.
    flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs))
    gm, graph_signature = apply_runtime_assertion_pass(gm, graph_signature)

    # Simplify unbacked_bindings by recomputing them.
    # Useful for any pass that's interpreter-based and might call rebind_unbacked(),
    # e.g. AOTAutograd in this case.
    _replace_unbacked_bindings(gm)

    # Parameters, buffers and input tokens precede the user inputs; presumably
    # set_missing_meta_vals uses this offset to align placeholders with
    # flat_fake_args (confirm in its definition).
    total_non_user_inputs = (
        len(graph_signature.parameters)
        + len(graph_signature.buffers)
        + len(graph_signature.input_tokens)
    )
    set_missing_meta_vals(gm, flat_fake_args, total_non_user_inputs)

    export_graph_signature: ExportGraphSignature | None
    export_graph_signature = _convert_to_export_graph_signature(
        graph_signature, gm, _get_non_persistent_buffers(mod)
    )

    # script objects are always stored in constants no matter whether they're initial inputs or
    # they're lifted in AOT, before rewrite_script_object_meta
    constants = _materialize_and_lift_constants(
        gm, export_graph_signature, constant_attrs
    )

    if pre_dispatch:
        from torch._export.passes.replace_autocast_with_hop_pass import (
            replace_autocast_with_hop_pass,
        )
        from torch._export.passes.replace_set_grad_with_hop_pass import (
            replace_set_grad_with_hop_pass,
        )

        # Note: replace_set_grad_with_hop_pass need to be after lift_constant_pass because
        # a getattr of a constant tensor doesn't have meta["val"] until after lift_constant_pass.
        # If replace_set_grad_with_hop_pass is before lift_constant_pass,
        # and the constant_tensor is passed as input of the set grad hop, the placeholder's
        # meta["val"] will be None and fails our verifier for placeholder.
        gm, export_graph_signature = replace_set_grad_with_hop_pass(
            gm, export_graph_signature
        )

        gm, export_graph_signature = replace_autocast_with_hop_pass(
            gm, export_graph_signature
        )

    # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
    # (Output nodes are stripped as well, including in nested GraphModules.)
    for _mod in gm.modules():
        if not isinstance(_mod, torch.fx.GraphModule):
            continue
        for node in _mod.graph.nodes:
            if node.op in ["placeholder", "output"]:
                node.meta.pop("nn_module_stack", None)
                node.meta.pop("stack_trace", None)

    # Prettify names for placeholder nodes.
    if export_graph_signature is None:
        raise AssertionError("export_graph_signature must not be None")
    if _prettify_placeholder_names:
        placeholder_naming_pass(
            gm,
            export_graph_signature,
            mod,
            fake_args,
            fake_kwargs,
            fake_params_buffers,
            constants,
        )

    _preserve_requires_grad_pass(
        gm, export_graph_signature, fake_params_buffers, constants, flat_fake_args
    )

    # The inferred output spec comes from the AOT graph signature (as returned
    # by apply_runtime_assertion_pass), not from the export signature.
    return ATenExportArtifact(
        gm,
        export_graph_signature,
        constants,
        inferred_out_spec=graph_signature.out_spec,
    )
def _rename_constants_nodes(
    gm: torch.fx.GraphModule,
    graph_signature: ExportGraphSignature,
) -> None:
    """
    For strict mode, rename constants nodes that were previously annotated as buffers.

    Every CONSTANT_TENSOR input spec whose placeholder name lacks the constant
    prefix is renamed:
    - a buffer-prefixed name has that prefix swapped for the constant prefix;
    - any other name (a lifted constant) just gains the constant prefix.

    New names are made unique against the node names of ``gm``'s top-level
    graph by appending ``_<n>``. Matching output specs, and matching nodes in
    every GraphModule under ``gm``, are renamed too, and each such module is
    recompiled. ``gm`` and ``graph_signature`` are mutated in place.
    """
    taken_names = {node.name for node in gm.graph.nodes}

    def unique_name(candidate):
        # Handle name collisions with existing node names.
        chosen = candidate
        suffix = 1
        while chosen in taken_names:
            chosen = f"{candidate}_{suffix}"
            suffix += 1
        taken_names.add(chosen)
        return chosen

    buffer_prefix = placeholder_prefixes[InputKind.BUFFER]
    const_prefix = placeholder_prefixes[InputKind.CONSTANT_TENSOR]

    # Use input specs to map names from buffers to constants.
    renames = {}
    for spec in graph_signature.input_specs:
        if spec.kind != InputKind.CONSTANT_TENSOR:
            continue
        old_name = spec.arg.name
        if old_name.startswith(const_prefix):
            continue
        # Buffer-style names swap prefixes; lifted constants keep their full name.
        new_name = unique_name(const_prefix + old_name.removeprefix(buffer_prefix))
        renames[old_name] = new_name
        spec.arg.name = new_name

    for spec in graph_signature.output_specs:
        new_name = renames.get(spec.arg.name)
        if new_name is not None:
            spec.arg.name = new_name

    # Rename constants nodes for all modules.
    for submod in gm.modules():
        if not isinstance(submod, torch.fx.GraphModule):
            continue
        for node in submod.graph.nodes:
            new_name = renames.get(node.name)
            if new_name is not None:
                node.name = node.target = new_name
        submod.recompile()
def _restore_state_dict(
    original_module: torch.nn.Module, traced_module: torch.fx.GraphModule
) -> None:
    """
    Restores the state dict of the traced module to that of the original module.

    Parameters/buffers registered on ``traced_module`` under traced names are
    re-registered under the FQNs they have in ``original_module``, and every
    ``get_attr`` node is retargeted to the restored name. ``traced_module`` is
    modified in place and recompiled.
    """
    # traced attribute name -> original FQN
    param_buffer_table = _get_param_buffer_mapping(original_module, traced_module)
    # Don't want to change the convention of previous call.
    param_buffer_table_reverse = {v: k for k, v in param_buffer_table.items()}

    # Replace state dict attr names with the fqn. The names are snapshotted with
    # list() because the traced module may share submodules with the original,
    # so assigning attributes below could otherwise mutate what is being iterated.
    for name, _ in list(
        chain(
            original_module.named_parameters(remove_duplicate=False),
            # pyrefly: ignore [bad-argument-type]
            original_module.named_buffers(remove_duplicate=False),
        )
    ):
        dynamo_name = param_buffer_table_reverse.get(name)
        # Skip names without a traced counterpart, and names that already match:
        # assigning and then deleting the same attribute would drop it entirely.
        if dynamo_name is None or dynamo_name == name:
            continue
        param = torch.fx.graph_module._get_attr(traced_module, dynamo_name)
        torch.fx.graph_module._assign_attr(param, traced_module, name)
        torch.fx.graph_module._del_attr(traced_module, dynamo_name)

    # Replace graph getattr nodes with the correct name
    for node in traced_module.graph.nodes:
        if node.op == "get_attr":
            attr_name = node.target
            if attr_name in param_buffer_table:
                node.target = param_buffer_table[attr_name]

    traced_module.recompile()
- def _get_module_hierarchy(mod: torch.nn.Module) -> dict[str, str]:
- return {
- name: type(m).__name__ for name, m in mod.named_modules(remove_duplicate=False)
- }
def _make_module_call_graph(
    in_spec: TreeSpec,
    out_spec: TreeSpec,
    module_call_signatures: dict[str, ModuleCallSignature],
    forward_arg_names: list[str] | None = None,
) -> list[ModuleCallEntry]:
    """
    Assemble the module call graph entries.

    The result holds:
    1. One entry per FQN in ``_EXPORT_MODULE_HIERARCHY``, in hierarchy order,
       each using ``module_call_signatures.get(fqn)`` (possibly ``None``).
    2. An entry for every signature whose FQN is not in the hierarchy.

    The root entry (FQN ``""``), which must come first, is given a fresh
    signature built from ``in_spec``, ``out_spec`` and ``forward_arg_names``.

    Raises:
        AssertionError: if the first hierarchy FQN is not ``""``.
    """
    hierarchy = _EXPORT_MODULE_HIERARCHY
    entries = [
        ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn))
        for fqn in hierarchy  # type: ignore[union-attr]
    ]

    root = entries[0]
    if root.fqn != "":
        raise AssertionError(
            f"expected first fqn to be empty string, got {root.fqn!r}"
        )
    root.signature = ModuleCallSignature(
        inputs=[],
        outputs=[],
        in_spec=in_spec,
        out_spec=out_spec,
        forward_arg_names=forward_arg_names,
    )

    # Signatures recorded for FQNs outside the hierarchy go at the end.
    entries.extend(
        ModuleCallEntry(fqn=fqn, signature=signature)
        for fqn, signature in module_call_signatures.items()
        if fqn not in hierarchy  # type: ignore[operator]
    )
    return entries
- class _ExportModuleSpecTrackerDict(dict):
- pass
def _export_to_torch_ir(
    f: Callable,
    args: tuple[Any, ...],
    kwargs: dict[str, Any] | None = None,
    dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None = None,
    *,
    preserve_module_call_signature: tuple[str, ...] = (),
    disable_constraint_solver: bool = False,
    prefer_deferred_runtime_asserts_over_guards: bool = False,
    restore_fqn: bool = True,
    _log_export_usage: bool = True,
    same_signature: bool = True,
) -> torch.fx.GraphModule:
    """
    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside and produce a torch.fx.GraphModule in torch IR.

    Args:
        f: Module or callable to trace.
        args: Example positional inputs; must be a tuple.
        kwargs: Example keyword inputs. A falsy value is treated as ``{}``.
        dynamic_shapes: Dynamic-shape spec. It is validated against the
            combined inputs and turned into dynamo constraints.
        preserve_module_call_signature: Submodule FQNs handed to
            ``_wrap_submodules``. The collected call specs end up in
            ``gm.meta["module_call_specs"]``.
        disable_constraint_solver: Forwarded to ``torch._dynamo.export``
            (legacy tracer path only).
        prefer_deferred_runtime_asserts_over_guards: Written into the dynamo
            config and forwarded to ``torch._dynamo.export``.
        restore_fqn: If True and ``f`` is an nn.Module, rename the traced
            params/buffers back to the original FQNs.
        _log_export_usage: Log an ``export.private_api`` usage event. Also
            forwarded to ``torch._dynamo.export``.
        same_signature: Forwarded to ``torch._dynamo.export`` (legacy tracer
            path only).

    Returns:
        The traced torch-level GraphModule.

    Raises:
        UserError: One of the following.
            - INVALID_INPUT if ``args`` is not a tuple.
            - CONSTRAINT_VIOLATION on constraint or value-range violations
              during tracing.
            - ANTI_PATTERN when tracing guards on a data-dependent symbol.
    """

    if _log_export_usage:
        log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"})

    if not isinstance(args, tuple):
        raise UserError(
            UserErrorType.INVALID_INPUT,
            f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
        )

    kwargs = kwargs or {}

    # Map ints to a wrapper structure to help us mark it as dynamic, if it is
    # dynamic. We will unwrap ints in fakify later.
    args, kwargs = pytree.tree_map_only(int, _IntWrapper, (args, kwargs))

    combined_args = _combine_args(f, args, kwargs)
    _check_dynamic_shapes(combined_args, dynamic_shapes)
    constraints = _process_dynamic_shapes(combined_args, dynamic_shapes)

    # Unwrap static ints -- in the case where we have an empty graph
    # containing just integer computation, dynamo will run its generated
    # bytecode with these args/kwargs, which will error because we cannot
    # directly apply int operations on IntWrapper. So we will just unwrap
    # them here.
    args, kwargs = pytree.tree_map_only(
        _IntWrapper,
        lambda a: a.val
        if a.dynamism is None or a.dynamism.type == _DimHintType.STATIC
        else a,
        (args, kwargs),
    )

    dynamo_cfg = dataclasses.replace(
        DEFAULT_EXPORT_DYNAMO_CONFIG,
        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
    )

    def use_legacy_dynamo_graph_capture() -> bool:
        # With the experimental tracer enabled, these features still route
        # through _dynamo_graph_capture_for_export rather than the new capture.
        return bool(
            constraints  # dynamic shape
            or dynamic_shapes  # dynamic shape
            or isinstance(f, torch.fx.GraphModule)  # retracing
            or preserve_module_call_signature  # unflatten
            or torch._functorch.config.fake_tensor_propagate_real_tensors  # draft
            or torch._export.config.use_legacy_dynamo_graph_capture
        )

    with torch._dynamo.config.patch(dataclasses.asdict(dynamo_cfg)):
        try:
            module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = (
                _ExportModuleSpecTrackerDict()
            )
            ctx = nullcontext()
            # Retraced GraphModules are not wrapped; otherwise
            # _wrap_submodules is handed module_call_specs to fill in.
            if not isinstance(f, torch.fx.GraphModule):
                ctx = _wrap_submodules(  # type: ignore[assignment]
                    f, preserve_module_call_signature, module_call_specs
                )
            with ctx, _ignore_backend_decomps():
                if torch._export.config.use_new_tracer_experimental:
                    from torch._dynamo.functional_export import (
                        _dynamo_graph_capture_for_export,
                        dynamo_graph_capture_for_export,
                    )

                    if use_legacy_dynamo_graph_capture():
                        dynamo_graph_capture = _dynamo_graph_capture_for_export(
                            f, constraints=constraints, dynamic_shapes=dynamic_shapes
                        )
                    else:
                        dynamo_graph_capture = torch._dynamo.config.patch(
                            replay_side_effects=False
                        )(dynamo_graph_capture_for_export(f))
                    # Run the selected capture function on the example inputs
                    # (static ints already unwrapped above).
                    gm_torch_level = dynamo_graph_capture(*args, **kwargs)
                    # We can't serialize entire fake mode yet, so this is to make sure
                    # things like copy.deepcopy(ep.graph_module) not crash.
                    # see test_export.py::test_custom_tag_metadata_re_export
                    # Once we delete the old strict export, we can use this fake mode in the
                    # subsequent logic when lowering to aten IR.
                    del gm_torch_level.meta["fake_mode"]
                else:
                    # Legacy path: torch._dynamo.export returns a
                    # (graph_module, guards) pair; only the graph is kept.
                    gm_torch_level, _ = torch._dynamo.export(
                        f,
                        dynamic_shapes=dynamic_shapes,  # type: ignore[arg-type]
                        constraints=constraints,  # type: ignore[arg-type]
                        assume_static_by_default=True,
                        tracing_mode="symbolic",
                        disable_constraint_solver=disable_constraint_solver,
                        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
                        _log_export_usage=_log_export_usage,
                        same_signature=same_signature,
                    )(
                        *args,
                        **kwargs,
                    )
            # Attach the submodule call specs collected during tracing (if any).
            gm_torch_level.meta["module_call_specs"] = module_call_specs
        except (ConstraintViolationError, ValueRangeError) as e:
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
        except GuardOnDataDependentSymNode as e:
            raise UserError(  # noqa: B904
                UserErrorType.ANTI_PATTERN,
                f"Consider annotating your code using torch._check*(). {str(e)}",
                case_name="constrain_as_size_example",
            )

    if isinstance(f, torch.nn.Module) and restore_fqn:
        _restore_state_dict(f, gm_torch_level)

    return gm_torch_level
def _aot_export_joint_with_descriptors(
    stack,
    mod,
    args,
    *,
    kwargs,
    decompositions,
    fake_params_buffers,
    _record_nn_module_stack=True,
):
    """
    Trace ``mod`` with ``aot_export_joint_with_descriptors`` and turn the result
    into a forward-only graph module plus its graph signature.

    ``stack`` is passed straight through to ``aot_export_joint_with_descriptors``
    (presumably the caller's ExitStack of tracing contexts; confirm with callers).
    ``fake_params_buffers`` must be a mapping; its values are handed over as the
    flat params-and-buffers list.

    Returns:
        ``(gm, graph_signature)``.

    Raises:
        AssertionError: if stage-2 export does not yield a torch.fx.GraphModule.
    """
    from torch._functorch._aot_autograd.graph_compile import aot_stage2_export
    from torch._functorch._aot_autograd.input_output_analysis import (
        create_graph_signature,
    )

    joint = aot_export_joint_with_descriptors(
        stack,
        mod,
        args,
        kwargs=kwargs,
        decompositions=decompositions,
        _record_nn_module_stack=_record_nn_module_stack,
    )

    # Stage 2 converts the JointWithDescriptors into a graph module plus its
    # ViewAndMutationMeta.
    gm, fw_metadata = aot_stage2_export(joint._aot_state, joint._aot_graph_capture)
    if not isinstance(gm, torch.fx.GraphModule):
        raise AssertionError(f"expected gm to be torch.fx.GraphModule, got {type(gm)}")

    user_leaves = pytree.tree_leaves((args, kwargs))
    param_buffer_values = list(fake_params_buffers.values())
    graph_signature = create_graph_signature(
        gm,
        fw_metadata,
        joint.in_spec,
        joint.out_spec,
        user_args_flat=user_leaves,
        params_and_buffers_flat=param_buffer_values,
        param_names=joint.params_spec,
        buffer_names=joint.buffers_spec,
        trace_joint=False,
        num_user_fw_outs=None,
        loss_index=None,
    )
    return gm, graph_signature
def _export_to_aten_ir(
    mod: torch.nn.Module,
    fake_args,
    fake_kwargs,
    fake_params_buffers,
    constant_attrs: ConstantAttrMap,
    produce_guards_callback=None,
    *,
    transform=lambda x: x,  # TODO(zhxchen17) Revisit if this is needed later.
    pre_dispatch=False,
    decomp_table=None,
    _prettify_placeholder_names: bool = True,
    decompose_custom_triton_ops: bool = False,
) -> ATenExportArtifact:
    """
    Lower ``mod`` to an ATen-level graph and package it as an ATenExportArtifact.

    Tracing runs with ``mod``'s parameters/buffers temporarily swapped for
    ``fake_params_buffers`` (tied weights kept tied, strict matching) and under
    ``torch.no_grad()``. The graph module and signature come from
    ``transform(_aot_export_joint_with_descriptors)``.

    Args:
        mod: Module (possibly a torch-level GraphModule) to lower.
        fake_args: Example positional inputs used for tracing.
        fake_kwargs: Example keyword inputs used for tracing.
        fake_params_buffers: Mapping used to reparametrize ``mod`` during
            tracing (presumably FQN -> fake tensor; confirm with callers).
        constant_attrs: Forwarded to ``_produce_aten_artifact``.
        produce_guards_callback: Optional callable invoked with the traced graph
            module before the artifact is produced.
        transform: Wrapper applied to ``_aot_export_joint_with_descriptors``
            before it is called (identity by default).
        pre_dispatch: Forwarded to ``_produce_aten_artifact``.
        decomp_table: Decompositions passed to the AOT export.
        _prettify_placeholder_names: Forwarded to ``_produce_aten_artifact``.
        decompose_custom_triton_ops: When False, custom triton op functional
            decomposition is disabled while tracing.

    Returns:
        The ATenExportArtifact built by ``_produce_aten_artifact``.

    Raises:
        UserError: CONSTRAINT_VIOLATION if ``produce_guards_callback`` raises
            ConstraintViolationError or ValueRangeError.
    """
    custom_triton_ops_decomposition_ctx = (
        nullcontext
        if decompose_custom_triton_ops
        else _disable_custom_triton_op_functional_decomposition
    )
    # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
    # otherwise aot_export_module will error out because it sees a mix of fake_modes.
    # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
    with ExitStack() as stack:
        stack.enter_context(
            torch.nn.utils.stateless._reparametrize_module(
                mod,
                fake_params_buffers,
                tie_weights=True,
                strict=True,
                stack_weights=True,
            )
        )
        stack.enter_context(_ignore_backend_decomps())
        stack.enter_context(_compiling_state_context())
        stack.enter_context(custom_triton_ops_decomposition_ctx())
        stack.enter_context(torch.no_grad())

        # The stack itself is handed down so the helper can presumably enter
        # further contexts that share this lifetime (confirm in
        # aot_export_joint_with_descriptors).
        gm, graph_signature = transform(_aot_export_joint_with_descriptors)(
            stack,
            mod,
            fake_args,
            kwargs=fake_kwargs,
            decompositions=decomp_table,
            fake_params_buffers=fake_params_buffers,
            _record_nn_module_stack=True,
        )

    def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm):
        # Only applies when the traced input was itself a GraphModule: copy its
        # graph-level meta and its output node's meta onto the new graph.
        if isinstance(old_gm, torch.fx.GraphModule):
            if hasattr(old_gm, "meta"):
                new_gm.meta.update(old_gm.meta)
            old_output_node = list(old_gm.graph.nodes)[-1]
            new_output_node = list(new_gm.graph.nodes)[-1]
            if old_output_node.op != "output" or new_output_node.op != "output":
                raise AssertionError(
                    f"expected both output nodes to have op='output', got old={old_output_node.op!r}, new={new_output_node.op!r}"
                )
            # make sure we don't override any meta
            if "desc" in new_output_node.meta:
                del new_output_node.meta["desc"]
            new_output_node.meta.update(old_output_node.meta)

    # TODO unfortunately preserving graph-level metadata and output node's meta
    # is not working well with aot_export. So we manually copy it.
    # (The node-level meta is addressed above.)
    _maybe_fixup_gm_and_output_node_meta(mod, gm)

    # Run produce guards before we handle runtime asserts.
    # This means we run the export solver before the runtime asserts pass.
    # Right now this doesn't mean much - the export solver is only there for suggested fixes,
    # and we won't even get to constraint solving if that's needed.
    # But if in future we want to control what runtime asserts are emitted for export,
    # or rely on produce_guards + solver for some simplification on runtime asserts, this probably makes sense.
    if produce_guards_callback:
        try:
            produce_guards_callback(gm)
        except (ConstraintViolationError, ValueRangeError) as e:
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904

    return _produce_aten_artifact(
        gm=gm,
        mod=mod,
        constant_attrs=constant_attrs,
        graph_signature=graph_signature,
        pre_dispatch=pre_dispatch,
        fake_args=fake_args,
        fake_kwargs=fake_kwargs,
        fake_params_buffers=fake_params_buffers,
        _prettify_placeholder_names=_prettify_placeholder_names,
    )
- def _get_forward_arg_names(
- mod: torch.nn.Module,
- args: tuple[Any, ...],
- kwargs: dict[str, Any] | None = None,
- ) -> list[str]:
- """
- Gets the argument names to forward that are used, for restoring the
- original signature when unlifting the exported program module.
- - Positional args: retain the original argument names, and enumerate
- *args as args_0, args_1, ...
- - Keyword args: retain the original kwarg names in the order specified
- by the user. This order seems to matter for the current state of
- export lifted modules.
- """
- sig = inspect.signature(mod.forward)
- _args = sig.bind_partial(*args).arguments
- names: list[str] = []
- for name, value in _args.items():
- # handle variable number of positional args
- if sig.parameters[name].kind == inspect._ParameterKind.VAR_POSITIONAL:
- names.extend([f"{name}_{i}" for i, _ in enumerate(value)])
- else:
- names.append(name)
- # order of kwargs matters for input spec
- if kwargs:
- names.extend([kwarg for kwarg, _ in kwargs.items()])
- return names
- def _get_non_persistent_buffers(mod: torch.nn.Module) -> set[str]:
- """
- Returns set of non-persistent buffers in a module and its submodules.
- """
- result: set[str] = set()
- for name, m in mod.named_modules(remove_duplicate=False):
- if name:
- result.update(f"{name}.{b}" for b in m._non_persistent_buffers_set)
- else:
- result.update(m._non_persistent_buffers_set)
- return result
def _rewrite_dynamo_tensor_constants(
    orig_mod_buffers: set[torch.Tensor],
    traced_mod_buffers: dict[str, torch.Tensor],
    graph_signature: ExportGraphSignature,
    constants: dict[str, _ConstantAttributeType],
) -> None:
    """
    Dynamo erroneously marks tensor attributes on modules as buffers; rewrite
    those specs to be tensor constants.

    A BUFFER input spec whose traced tensor is not among ``orig_mod_buffers``
    is switched to CONSTANT_TENSOR, and the tensor is recorded in ``constants``
    under the spec's target. Both ``graph_signature`` and ``constants`` are
    mutated in place.

    Raises:
        AssertionError: if a BUFFER spec has no target.
    """
    for spec in graph_signature.input_specs:
        if spec.kind != InputKind.BUFFER:
            continue
        target = spec.target
        if target is None:
            raise AssertionError("spec.target must not be None for BUFFER kind")
        traced_value = traced_mod_buffers[target]
        if traced_value in orig_mod_buffers:
            # A genuine buffer of the original module.
            continue
        # Not a buffer of the original module, so it must be a plain tensor attribute.
        spec.kind = InputKind.CONSTANT_TENSOR
        constants[target] = traced_value  # type: ignore[arg-type]
def _move_non_persistent_buffers_to_tensor_constants(
    orig_mod: torch.nn.Module,
    graph_signature: ExportGraphSignature,
    constants: dict[str, _ConstantAttributeType],
) -> None:
    """
    Moves non-persistent buffers to tensor constants.

    For every non-persistent BUFFER input spec, the buffer fetched from
    ``orig_mod`` under the spec's target is stored in ``constants`` under that
    same target. The spec itself is left unchanged.

    Raises:
        AssertionError: if such a spec has no target, or its target is already
            present in ``constants``.
    """
    for spec in graph_signature.input_specs:
        if spec.kind != InputKind.BUFFER or spec.persistent:
            continue
        target = spec.target
        if target is None:
            raise AssertionError(
                "spec.target must not be None for non-persistent BUFFER kind"
            )
        if target in constants:
            raise AssertionError(
                f"spec.target {target!r} should not already be in constants"
            )
        constants[target] = orig_mod.get_buffer(target)  # type: ignore[arg-type]
- def _verify_nn_module_stack(graph_module: torch.fx.GraphModule) -> None:
- """
- Perform nn_module_stack checks on the graph.
- Current constraints:
- For the top level graph:
- - populated for 'call_function', 'get_attr'
- - None for 'placeholder', 'output'
- For submodule graphs:
- - None for 'placeholder', output'
- TODO(pianpwk): make this a consistent node-level check once nn_module_stack is populated for cond submodules.
- """
- # Check top-level graph for all nodes, all graphs for placeholder & output nodes
- for i, mod in enumerate([graph_module] + list(graph_module.modules())):
- if not isinstance(mod, torch.fx.GraphModule):
- continue
- for node in mod.graph.nodes:
- if node.op in ["call_function", "get_attr"]:
- if i == 0:
- if (
- nn_module_stack := node.meta.get("nn_module_stack", None)
- ) is None:
- raise SpecViolationError(
- f"Node {node} of type {node.op} is missing nn_module_stack metadata"
- )
- if not all(
- isinstance(k, str)
- and isinstance(v, tuple)
- and len(v) == 2
- and all(isinstance(x, str) for x in v)
- for k, v in nn_module_stack.items()
- ):
- raise SpecViolationError(
- f"Node {node} of type {node.op} has incorrect nn_module_stack metadata format"
- f"expected Dict[str, Tuple[str, str]], but got {nn_module_stack}"
- )
- elif node.op in ["placeholder", "output"]:
- if node.meta.get("nn_module_stack", None):
- raise SpecViolationError(
- f"Node {node} of type {node.op} contains nn_module_stack metadata, this should be None"
- )
- def _verify_stack_trace(graph_module: torch.fx.GraphModule) -> None:
- """
- Perform stack trace checks on the graph.
- Constraints:
- - None or non-empty str for 'call_function', 'get_attr'
- - None for 'placeholder', 'output'
- """
- for mod in [graph_module, *graph_module.modules()]:
- if not isinstance(mod, torch.fx.GraphModule):
- continue
- for node in graph_module.graph.nodes:
- stack_trace = node.meta.get("stack_trace", None)
- if node.op in ["call_function", "get_attr"]:
- if not (stack_trace is None or isinstance(stack_trace, str)):
- raise SpecViolationError(
- f"Node {node} of type {node.op} has invalid stack_trace metadata, "
- f"expected a string or None but instead found: {stack_trace}"
- )
- elif node.op in ["placeholder", "output"]:
- if stack_trace:
- raise SpecViolationError(
- f"Node {node} of type {node.op} contains stack_trace metadata, "
- f"expected None but instead found: {stack_trace}"
- )
- def _verify_placeholder_names(
- gm: torch.fx.GraphModule, sig: ExportGraphSignature
- ) -> None:
- """
- Performs a sanity check on the placeholder node names.
- - User input nodes: no restrictions, should match the original forward() signature
- - Params/buffers/constants/custom_obj/token nodes: should start with prefixes defined in <placeholder_prefixes>
- """
- name_to_kind = {spec.arg.name: spec.kind for spec in sig.input_specs}
- for mod in gm.modules():
- if not isinstance(mod, torch.fx.GraphModule):
- continue
- for node in mod.graph.nodes:
- if node.op == "placeholder":
- if node.name not in name_to_kind:
- continue
- node_kind = name_to_kind[node.name]
- prefix = placeholder_prefixes[node_kind]
- if not node.name.startswith(prefix):
- raise SpecViolationError(
- f"Placeholder node name {node.name} does not follow spec for {node_kind}, name should have prefix: {prefix}"
- )
- def get_ep_stats(ep: ExportedProgram) -> dict[str, Any]:
- op_count = 0
- op_set = set()
- for m in ep.graph_module.modules():
- if not isinstance(m, torch.fx.GraphModule):
- continue
- for node in m.graph.nodes:
- if node.op != "call_function":
- continue
- op_count += 1
- if not hasattr(node.target, "__module__"):
- raise AssertionError(
- f"node.target {node.target} must have __module__ attribute"
- )
- if not hasattr(node.target, "__name__"):
- raise AssertionError(
- f"node.target {node.target} must have __name__ attribute"
- )
- op_set.add(f"{node.target.__module__}.{node.target.__name__}")
- return {"op_count": op_count, "op_set": op_set}
# Usage flags passed as ``flags=`` to the export log events emitted by
# ``_log_export_wrapper``. The wrapper resets it to ``None`` when the wrapped
# call finishes. Where it is set is not visible in this section (presumably the
# export entry point; confirm).
_EXPORT_FLAGS: set[str] | None = None
# FQN -> class-name map of the module being exported (presumably the output of
# ``_get_module_hierarchy``; confirm where it is assigned). Iterated by
# ``_make_module_call_graph`` and reset to ``None`` by ``_log_export_wrapper``.
_EXPORT_MODULE_HIERARCHY: dict[str, str] | None = None
def _log_export_wrapper(fn):
    """
    Decorate an export entry point with usage and error logging.

    On success, logs an ``export.time`` event with the elapsed time of ``fn``
    plus the op statistics from ``get_ep_stats`` on its result.

    On failure:
    - Logs an ``export.error.classified`` or ``export.error.unclassified``
      event (plus an exportdb hint for classified errors).
    - Prints ``e.partial_fx_graph`` to stderr when the exception carries one.
    - Re-raises the original exception.

    In every case, the module-level ``_EXPORT_FLAGS`` and
    ``_EXPORT_MODULE_HIERARCHY`` are reset to ``None`` afterwards.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY
        try:
            # perf_counter is monotonic, so the measured duration cannot be
            # skewed by system clock adjustments the way time.time() can.
            start = time.perf_counter()
            ep = fn(*args, **kwargs)
            end = time.perf_counter()
            log_export_usage(
                event="export.time",
                metrics=end - start,
                flags=_EXPORT_FLAGS,
                **get_ep_stats(ep),
            )
        except Exception as e:
            t = type(e)
            error_type = t.__module__ + "." + t.__qualname__
            case_name = get_class_if_classified_error(e)
            if case_name is not None:
                log.error(exportdb_error_message(case_name))
                log_export_usage(
                    event="export.error.classified",
                    type=error_type,
                    message=str(e),
                    flags=_EXPORT_FLAGS,
                )
            else:
                log_export_usage(
                    event="export.error.unclassified",
                    type=error_type,
                    message=str(e),
                    flags=_EXPORT_FLAGS,
                )

            # Some errors carry the partially traced graph; surface it for debugging.
            if hasattr(e, "partial_fx_graph"):
                print(
                    e.partial_fx_graph,
                    file=sys.stderr,
                )

            # A bare raise re-raises the active exception without adding this
            # wrapper frame to its traceback (``raise e`` would add it).
            raise
        finally:
            # Per-export global state must not leak into the next export call.
            _EXPORT_FLAGS = None
            _EXPORT_MODULE_HIERARCHY = None

        return ep

    return wrapper
- def _process_jit_trace_inputs_for_export(example_inputs, example_kwarg_inputs):
- if not isinstance(example_inputs, (tuple, list, dict)):
- example_inputs = (example_inputs,)
- elif isinstance(example_inputs, list):
- example_inputs = tuple(example_inputs)
- elif (
- isinstance(example_inputs, (torch.Tensor, dict))
- and example_kwarg_inputs is None
- ):
- example_inputs = (example_inputs,)
- if example_kwarg_inputs is None:
- example_kwarg_inputs = {}
- return example_inputs, example_kwarg_inputs
def _get_original_state_dict(mod: torch.nn.Module) -> dict[str, Any]:
    """
    Return ``mod``'s live parameters and persistent buffers keyed by FQN.

    Shared tensors appear under every FQN (duplicates are not removed).
    Non-persistent buffers are excluded.
    """
    # Explicitly not calling mod.state_dict() as we do not want the module state for serialization
    # but the running module state so we can always match by id() the entries here with the graph inputs
    state = dict(mod.named_parameters(remove_duplicate=False))
    state.update(mod.named_buffers(remove_duplicate=False))
    skip = _get_non_persistent_buffers(mod)
    return {name: value for name, value in state.items() if name not in skip}
def _process_export_inputs(
    mod: torch.nn.Module,
    args: tuple[object, ...],
    kwargs: dict[str, object] | None,
    dynamic_shapes: _DynamicShapesSpec
    | torch.export.AdditionalInputs
    | torch.export.ShapesCollection
    | None,
) -> tuple[
    tuple[object, ...],
    dict[str, object],
    TreeSpec,
    _DynamicShapesSpec | None,
    Callable[[ExportedProgram], None],
]:
    """
    Validate and normalize the example inputs given to the torch.export API.

    Args:
        mod: The module being exported. It is passed to
            ``AdditionalInputs``/``ShapesCollection`` when they resolve dynamic
            shapes.
        args: Example positional inputs. Must be a tuple; a namedtuple is
            converted to a plain tuple.
        kwargs: Example keyword inputs; ``None`` becomes ``{}``.
        dynamic_shapes: One of:
            - ``None``, or a dict/tuple/list spec (returned unchanged).
            - ``torch.export.AdditionalInputs``: resolved to a spec, and its
              ``verify`` is returned as the callback.
            - ``torch.export.ShapesCollection``: resolved to a spec.

    Returns:
        ``(args, kwargs, original_in_spec, dynamic_shapes, verify_additional_inputs)``.
        ``original_in_spec`` is the pytree spec of ``(args, kwargs)``.
        ``verify_additional_inputs`` is ``AdditionalInputs.verify`` or a no-op.

    Raises:
        UserError: If args is not a tuple.
    """
    if not isinstance(args, tuple):
        raise UserError(
            UserErrorType.INVALID_INPUT,
            f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
        )

    if kwargs is None:
        kwargs = {}
    if pytree.is_namedtuple_instance(args):
        args = tuple(args)
    original_in_spec = pytree.tree_flatten((args, kwargs))[1]

    def _no_verification(ep):
        return None

    verify_additional_inputs: Callable[[ExportedProgram], None] = _no_verification
    out_dynamic_shapes: _DynamicShapesSpec | None = dynamic_shapes  # type: ignore[assignment]
    if isinstance(dynamic_shapes, torch.export.AdditionalInputs):
        verify_additional_inputs = dynamic_shapes.verify  # type: ignore[assignment]
        out_dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs)  # type: ignore[assignment]
    elif isinstance(dynamic_shapes, torch.export.ShapesCollection):
        out_dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs)  # type: ignore[assignment]

    return args, kwargs, original_in_spec, out_dynamic_shapes, verify_additional_inputs
def _get_module_call_graph(
    export_artifact: ExportArtifact,
    preserve_module_call_signature: tuple[str, ...],
    strict_mode_export: bool,
    forward_arg_names: list[str] | None = None,
) -> tuple[torch.fx.GraphModule, list[ModuleCallEntry]]:
    """
    In-place modify the graph module in export_artifact, remove _export_tracepoint nodes and
    return module_call_graph.

    A ModuleCallSignature is built for every recorded submodule call spec. In
    non-strict mode the FQNs go through ``_strip_root``. When any signatures are
    to be preserved, tracepoints are collected by ``CollectTracepointsPass``
    (after ``_rewrite_tracepoint_node`` in non-strict mode), and the pass's
    resulting graph module replaces ``gm``.

    Returns:
        ``(gm, module_call_graph)``.

    Raises:
        AssertionError: if the tracepoint pass returns None, or
            ``_EXPORT_MODULE_HIERARCHY`` has not been set.
    """
    aten = export_artifact.aten
    gm: torch.fx.GraphModule = aten.gm
    export_graph_signature: ExportGraphSignature = aten.sig

    # Make module signatures; only the top-level module gets forward_arg_names.
    module_call_signatures: dict[str, ModuleCallSignature] = {
        (fqn if strict_mode_export else _strip_root(fqn)): ModuleCallSignature(
            inputs=[],
            outputs=[],
            in_spec=specs["in_spec"],
            out_spec=specs["out_spec"],
            forward_arg_names=None,
        )
        for fqn, specs in export_artifact.module_call_specs.items()
    }

    if preserve_module_call_signature:
        if not strict_mode_export:
            _rewrite_tracepoint_node(gm)
        collected = CollectTracepointsPass(
            module_call_signatures, export_graph_signature
        )(gm)
        if collected is None:
            raise AssertionError("CollectTracepointsPass returned None")
        gm = collected.graph_module

    if _EXPORT_MODULE_HIERARCHY is None:
        raise AssertionError("_EXPORT_MODULE_HIERARCHY must not be None")

    return gm, _make_module_call_graph(
        export_artifact.in_spec,
        export_artifact.out_spec,
        module_call_signatures,
        forward_arg_names,
    )
def _get_range_constraints(
    mod: torch.nn.Module,
    export_artifact: ExportArtifact,
    args,
    kwargs,
    dynamic_shapes,
):
    """
    Build the range constraints of the exported program via ``make_constraints``.

    ``kwargs`` must support ``in``, iteration and indexing (a dict in practice).
    The combined inputs are reordered to match tracing, as described in the
    comments below.
    """
    gm: torch.fx.GraphModule = export_artifact.aten.gm
    input_specs = export_artifact.aten.sig.input_specs
    fake_mode: FakeTensorMode = export_artifact.fake_mode

    # Number of lifted inputs, i.e. the specs preceding the first user input
    # (all of them if there is no user input).
    num_lifted = len(input_specs)
    for idx, spec in enumerate(input_specs):
        if spec.kind == InputKind.USER_INPUT:
            num_lifted = idx
            break

    # This is because we trace based on the kwargs passed in from user
    # not based on the signature. I feel it would be better to just enforce
    # one ordering at the start of tracing to avoid confusions, but that is
    # bigger refactor, so do this to unblock for now.
    # The result is non-kwarg inputs in signature order, then kwargs in caller order.
    all_args = _combine_args(mod, args, kwargs)
    traced_order = {
        name: value for name, value in all_args.items() if name not in kwargs
    }
    traced_order.update(kwargs)

    return make_constraints(
        fake_mode,
        gm,
        traced_order,
        dynamic_shapes,
        num_lifted,
    )
- def _get_inline_constraints(fake_mode: FakeTensorMode):
- if fake_mode.shape_env is None:
- raise AssertionError("fake_mode.shape_env must not be None")
- return {
- k: v
- for k, v in fake_mode.shape_env.var_to_range.items()
- if free_unbacked_symbols(k)
- }
@contextmanager
def patch_forward(obj: torch.nn.Module, new_method):
    """Helper method to make it easier to cleanly torch.export() a method on a
    module that is not `forward`.

    Inside the context, ``obj.forward`` is ``new_method`` bound to ``obj``. On
    exit (including on error), ``obj`` returns to its previous state:
    - An instance-level ``forward`` that existed before is put back.
    - Otherwise the instance attribute is removed, so lookups fall through to
      the class again.
    """
    missing = object()
    # Save only the instance-level attribute. Re-assigning the bound method
    # obtained through the class would leave a stale instance attribute behind
    # (and an obj -> bound method -> obj reference cycle).
    previous = vars(obj).get("forward", missing)

    # Patch the method
    obj.forward = new_method.__get__(obj, obj.__class__)
    try:
        yield
    finally:
        # Restore the original state
        if previous is missing:
            vars(obj).pop("forward", None)
        else:
            obj.forward = previous
- @contextmanager
- def _temp_disable_texpr_fuser():
- original_state = torch._C._jit_texpr_fuser_enabled()
- torch._C._jit_set_texpr_fuser_enabled(False)
- try:
- yield
- finally:
- torch._C._jit_set_texpr_fuser_enabled(original_state)
- def _strict_export(
- mod: torch.nn.Module,
- args: tuple[Any, ...],
- kwargs: dict[str, Any],
- dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None,
- preserve_module_call_signature: tuple[str, ...],
- orig_in_spec: TreeSpec,
- prefer_deferred_runtime_asserts_over_guards: bool,
- _to_aten_func: Callable,
- ) -> ExportArtifact:
- """
- _to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir`
- """
- gm_torch_level = _export_to_torch_ir(
- # pyrefly: ignore [bad-argument-type]
- mod,
- args,
- kwargs,
- dynamic_shapes,
- preserve_module_call_signature=preserve_module_call_signature,
- restore_fqn=False, # don't need to restore because we will do it later
- prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
- _log_export_usage=False,
- )
- # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
- (
- fake_args,
- fake_kwargs,
- dynamo_fake_mode,
- ) = _extract_fake_inputs(gm_torch_level, args, kwargs)
- fake_params_buffers = _fakify_params_buffers(dynamo_fake_mode, gm_torch_level)
- # First, we want to pass through the graph to try populating
- # val field for getattr if there is anything missing.
- # This can happen when quantization adds extra params and forgets
- # to update "val"
- for node in gm_torch_level.graph.nodes:
- if node.op == "get_attr" and "val" not in node.meta:
- attr = getattr(gm_torch_level, node.target)
- # Checks if it is not a HigherOrderOp branch or a module
- if not isinstance(attr, torch.nn.Module):
- if dynamo_fake_mode is None:
- raise AssertionError(
- "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders."
- )
- node.meta["val"] = dynamo_fake_mode.from_tensor(
- attr, static_shapes=True
- )
- # Fix the graph output signature to be tuple if scalar
- wrap_tuple = False
- # Calling gm_torch_level._out_spec is not safe because gm_torch_level might be
- # a _LazyGraphModule, which does not populate _out_spec when calling recompile().
- # TODO: Fix recompile() in _LazyGraphModule. T207713214
- if isinstance(gm_torch_level.graph._codegen, torch.fx.graph._PyTreeCodeGen):
- out_spec = orig_out_spec = gm_torch_level.graph._codegen.pytree_info.out_spec
- orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined]
- # Used to get rid of lint type error.
- if out_spec is None:
- raise AssertionError("out_spec must not be None")
- if out_spec.type not in (list, tuple):
- # aot_export expect the return type to always be a tuple.
- out_spec = pytree.treespec_tuple([out_spec])
- wrap_tuple = True
- gm_torch_level.graph._codegen.pytree_info = _PyTreeInfo(
- orig_arg_names,
- gm_torch_level._in_spec,
- out_spec,
- )
- elif isinstance(
- gm_torch_level.graph._codegen,
- torch._dynamo.functional_export._DynamoBytecodeCodeGen,
- ):
- # Since we're using bytecode codegen, we need to separately apply tuple
- # output instead of modifying pytree spec inplace.
- orig_arg_names = gm_torch_level.graph._codegen.orig_arg_names
- out_spec = orig_out_spec = None
- wrap_tuple = gm_torch_level.graph._codegen.wrap_tuple = True
- else:
- raise RuntimeError(f"Unknown codegen type: {gm_torch_level.graph._codegen}")
- gm_torch_level.recompile()
- _normalize_nn_module_stack(gm_torch_level, type(mod))
- params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level)
- # When aot_export lifts the params, we lose metadata (e.g. source_fn_stack, stack_trace)
- # from the param nodes as they are treated as fresh inputs
- # Therefore, we manually extract them before calling into aot_export
- # params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level)
- constant_attrs = _gather_constant_attrs(mod)
- param_buffer_table: dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level)
- # Dynamo does not track which buffers were registered as non-persistent. This info
- # is available in the original module, so we transfer it to the traced module. Also,
- # since we didn't restore original param/buffer names yet, we must use traced names.
- non_persistent_buffers = _get_non_persistent_buffers(mod)
- reverse_name_lookup = {orig: traced for traced, orig in param_buffer_table.items()}
- gm_torch_level._non_persistent_buffers_set = {
- reverse_name_lookup[name]
- for name in non_persistent_buffers
- if name in reverse_name_lookup
- }
- tx = TracingContext(dynamo_fake_mode)
- with (
- dynamo_fake_mode,
- tracing(tx),
- mock.patch.object(dynamo_fake_mode, "allow_non_fake_inputs", True),
- ):
- aten_export_artifact = _to_aten_func(
- gm_torch_level,
- # NOTE: graph module expects only positional args
- _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs),
- {},
- fake_params_buffers,
- constant_attrs,
- )
- # Decompose for readability.
- gm = aten_export_artifact.gm
- export_graph_signature = aten_export_artifact.sig
- constants = aten_export_artifact.constants
- _populate_param_buffer_metadata_to_new_gm(
- params_buffers_to_node_meta, gm, export_graph_signature
- )
- # Do some cleanups on the graph module to restore the state dict to the
- # expected form. Each of these steps should probably get fixed upstream.
- # 1. Remove tensor constants that were added as buffers.
- _rewrite_dynamo_tensor_constants(
- orig_mod_buffers=set(mod.buffers()),
- traced_mod_buffers=dict(gm_torch_level.named_buffers()),
- graph_signature=export_graph_signature,
- constants=constants,
- )
- # 2. Restore FQN of param/buffers
- _replace_param_buffer_names(param_buffer_table, export_graph_signature)
- # 3. Move non-persistent buffers to tensor constants
- _move_non_persistent_buffers_to_tensor_constants(
- mod, export_graph_signature, constants
- )
- # 4. Rewrite constants to have the same FQN as the original module.
- _remap_constants(constant_attrs, export_graph_signature, constants)
- # 5. Rename constants nodes in graph module from buffers to constants
- _rename_constants_nodes(gm, export_graph_signature)
- if orig_out_spec is None:
- out_spec = aten_export_artifact.inferred_out_spec
- if wrap_tuple:
- out_spec = out_spec.children()[0]
- else:
- out_spec = orig_out_spec
- return ExportArtifact(
- aten=aten_export_artifact,
- in_spec=orig_in_spec,
- out_spec=out_spec,
- fake_mode=dynamo_fake_mode,
- module_call_specs=gm_torch_level.meta["module_call_specs"],
- )
def _export_to_aten_ir_make_fx(
    mod: torch.nn.Module,
    fake_args,
    fake_kwargs,
    fake_params_buffers,
    constant_attrs: ConstantAttrMap,
    produce_guards_callback=None,
    transform=lambda x: x,
) -> ATenExportArtifact:
    """
    Trace ``mod`` into a pre-dispatch ATen graph using ``make_fx``.

    ``mod`` is temporarily reparametrized with ``fake_params_buffers`` so that the
    fake inputs and the module's parameters/buffers share the same fake mode, then
    traced with ``make_fx(..., pre_dispatch=True)``.

    Args:
        mod: module to trace.
        fake_args: fake positional example inputs.
        fake_kwargs: fake keyword example inputs.
        fake_params_buffers: fake replacements for ``mod``'s parameters/buffers,
            installed via ``torch.nn.utils.stateless._reparametrize_module``.
        constant_attrs: forwarded unchanged to ``_produce_aten_artifact``.
        produce_guards_callback: optional callable invoked with the traced graph
            module. ``ConstraintViolationError`` / ``ValueRangeError`` raised by
            it are re-raised as ``UserError(CONSTRAINT_VIOLATION)``.
        transform: applied to the inner tracing helper before it is called
            (identity by default). Non-strict export passes ``_tuplify_outputs``.

    Returns:
        The ``ATenExportArtifact`` built by ``_produce_aten_artifact``.
    """

    def _make_fx_helper(stack, mod, args, kwargs, **flags):
        """
        Trace ``mod`` functionally and build its ``GraphSignature``.

        ``stack`` and ``flags`` (e.g. ``trace_joint``) are accepted so that
        ``transform`` wrappers can share one calling convention; they are not
        used here.
        """
        kwargs = kwargs or {}

        named_parameters = dict(mod.named_parameters(remove_duplicate=False))
        named_buffers = dict(mod.named_buffers(remove_duplicate=False))

        params_and_buffers = {**named_parameters, **named_buffers}
        params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
        params_and_buffers_flat = tuple(params_and_buffers_flat)

        param_len = len(named_parameters)
        buffer_len = len(named_buffers)
        params_len = len(params_and_buffers)

        functional_call = create_functional_call(
            mod, params_spec, params_len, store_orig_mod=True
        )

        # Graph inputs are laid out as: parameters, then buffers, then user args.
        params_buffers_args: list[Any] = []
        params_buffers_args.extend(params_and_buffers_flat)
        params_buffers_args.extend(args)

        flat_fn, out_spec = create_tree_flattened_fn(
            functional_call, params_buffers_args, kwargs
        )
        flat_args, in_spec = pytree.tree_flatten((params_buffers_args, kwargs))

        @functools.wraps(flat_fn)
        def wrapped_fn(*args):
            return tuple(flat_fn(*args))

        with enable_python_dispatcher():
            ctx = nullcontext()
            non_strict_root = getattr(mod, "_export_root", None)
            # For any buffer that is assigned, we want to associate it to the final proxy node
            # that it is assigned to. This node can then be copied into the buffer.
            assigned_buffers: dict[str, str] = {}
            if non_strict_root is not None:
                ctx = _detect_attribute_assignment(non_strict_root)  # type: ignore[assignment]

            def custom_getattribute(self, attr, *, original_getattr, attrs_to_proxy):
                """
                The idea here is that we override subclass getattr methods to proxy
                inner tensors and metadata. Because of infinite loop shenanigans, we have
                to manually construct the getattr proxy nodes without relying on torch function
                system.
                """
                out = original_getattr(self, attr)
                if attr in attrs_to_proxy:
                    if torch._C._is_torch_function_mode_enabled():
                        if isinstance(out, torch.Tensor):
                            # When we get here there is no guarantee that we will hit the
                            # PreDispatchTorchFunctionMode, so we manually peek into the torch
                            # function mode list and tweak the PreDispatchTorchFunctionMode.
                            # This has side effect of proxying stuff like
                            # proxy.node.meta["val"] = extract_val(val) because at that time, torch function
                            # mode is still active. It seems bad to turn it off inside proxy_tensor.py, so
                            # I guess we will just rely on DCE for now to remove extra stuff like detach
                            torch_function_mode_stack = (
                                torch.overrides._get_current_function_mode_stack()
                            )
                            for mode in torch_function_mode_stack:
                                if isinstance(mode, PreDispatchTorchFunctionMode):
                                    tracer = mode.tracer
                                    proxy = get_proxy_slot(self, tracer).proxy
                                    inner_proxy = tracer.create_proxy(
                                        "call_function",
                                        torch.ops.export.access_subclass_inner_tensor.default,
                                        (proxy, attr),
                                        {},
                                    )
                                    track_tensor_tree(
                                        out, inner_proxy, constant=None, tracer=tracer
                                    )
                return out

            @contextmanager
            def override_getattribute_for_subclasses(args):
                """
                Context manager that temporarily monkey patches
                tensor.__getattribute__ so that we can intercept it at
                torch_function layer.
                """

                # Dictionary that tracks subclass type to original getattr function
                # and the attributes we can proxy.
                tensor_type_to_old_getattribute: dict[
                    type[torch.Tensor], tuple[Callable, set[str]]
                ] = {}
                for arg in args:
                    subclass_types_to_instances: dict[
                        type[torch.Tensor], list[type[torch.Tensor]]
                    ] = get_subclass_typing_container(arg)
                    for subclass_type in subclass_types_to_instances:
                        if subclass_type not in tensor_type_to_old_getattribute:
                            if len(subclass_types_to_instances[subclass_type]) == 0:
                                raise AssertionError(
                                    f"subclass_types_to_instances[{subclass_type}] must not be empty"
                                )
                            instance = subclass_types_to_instances[subclass_type][0]
                            # Query subclass specific attrs
                            attrs_to_proxy = set(dir(instance)) - set(dir(torch.Tensor))
                            tensor_type_to_old_getattribute[subclass_type] = (
                                subclass_type.__getattribute__,  # type: ignore[attr-defined]
                                attrs_to_proxy,
                            )

                try:
                    for k, (
                        old_getattr,
                        attrs_to_proxy,
                    ) in tensor_type_to_old_getattribute.items():
                        custom = functools.partialmethod(
                            custom_getattribute,
                            original_getattr=old_getattr,
                            attrs_to_proxy=attrs_to_proxy,
                        )
                        k.__getattribute__ = custom  # type: ignore[assignment, attr-defined]
                    yield
                finally:
                    # Restore every patched subclass, even if patching failed midway.
                    for k, (old_getattr, _) in tensor_type_to_old_getattribute.items():
                        k.__getattribute__ = old_getattr  # type: ignore[method-assign, attr-defined]

            @contextmanager
            def _maybe_restore_grad_state():
                """
                When pre-dispatch export accidentally change grad state, we restore it back.
                This can happen when we are calling torch._C._set_grad_enabled directly in the
                forward.
                """
                old_state = torch.is_grad_enabled()
                try:
                    yield
                finally:
                    torch._C._set_grad_enabled(old_state)

            # Register the buffer-assignment hook immediately before the guarded
            # region so that it is always removed, even if tracing raises.
            hook = (
                register_buffer_assignment_hook(non_strict_root, assigned_buffers)
                if non_strict_root is not None
                else None
            )
            try:
                with (
                    ctx,
                    override_getattribute_for_subclasses(flat_args),
                    _maybe_restore_grad_state(),
                ):
                    gm = make_fx(
                        wrapped_fn,
                        record_module_stack=True,
                        pre_dispatch=True,
                    )(*flat_args)

                if non_strict_root is not None:
                    input_names = _graph_input_names(gm)
                    # named_buffers() skips buffers that are None, so the graph has
                    # no placeholder for them. Drop None buffers *before* enumerating
                    # so that ``param_len + i`` lines up with the buffer's placeholder.
                    live_root_buffer_names = [
                        name
                        for name, buf in non_strict_root._buffers.items()
                        if buf is not None
                    ]
                    buffer_input_names = {
                        name: input_names[param_len + i]
                        for i, name in enumerate(live_root_buffer_names)
                    }
                    output_node = list(gm.graph.nodes)[-1]
                    # We copy nodes corresponding to buffer assignments to buffers in the graph.
                    for buf, name in assigned_buffers.items():
                        buf_node = _find_node(gm, buffer_input_names[buf])
                        name_node = _find_node(gm, name)
                        with gm.graph.inserting_before(output_node):
                            new_node = gm.graph.create_node(
                                "call_function",
                                torch.ops.aten.copy_.default,
                                args=(buf_node, name_node),
                            )
                            # NOTE: shares (does not copy) the meta dict of name_node.
                            new_node.meta = name_node.meta
            finally:
                if hook is not None:
                    hook.remove()

            def _is_impure(node):
                # Every node is treated as impure (kept by DCE) except the ops
                # listed below, so DCE only removes dead instances of them.
                if node.op == "call_function" and node.target in (
                    # In export, we ignore any op that is related to
                    # eager mode profiling call. The expectation is
                    # that either runtimes provide their own profiling
                    # OR user wrap the compiled region on a profiling in
                    # later stage.
                    torch.ops.profiler._record_function_enter.default,
                    torch.ops.profiler._record_function_enter_new.default,
                    torch.ops.profiler._record_function_exit._RecordFunction,
                    # In theory, we could fix this dead detach and getattr nodes
                    # from subclass tensors if we carefully rewrite track_tensor_tree
                    # in a way that it doesn't do any tensor methods.
                    torch.ops.aten.detach.default,
                    torch.ops.export.access_subclass_inner_tensor.default,
                ):
                    return False
                return True

            gm.graph.eliminate_dead_code(_is_impure)

        # create graph signature
        if out_spec.spec is None:
            raise AssertionError("out_spec.spec is None!")
        input_names = _graph_input_names(gm)
        output_names = _graph_output_names(gm)

        sig = GraphSignature(
            parameters=list(named_parameters),
            buffers=list(named_buffers),
            # pyrefly: ignore[bad-argument-type]
            user_inputs=input_names[params_len:],
            user_outputs=output_names,
            # pyrefly: ignore[no-matching-overload]
            inputs_to_parameters=dict(zip(input_names[0:param_len], named_parameters)),
            # pyrefly: ignore[no-matching-overload]
            inputs_to_buffers=dict(
                zip(input_names[param_len : param_len + buffer_len], named_buffers)
            ),
            buffers_to_mutate={},
            parameters_to_mutate={},
            user_inputs_to_mutate={},
            in_spec=in_spec,
            out_spec=out_spec.spec,
            backward_signature=None,
            input_tokens=[],
            output_tokens=[],
        )
        return gm, sig

    # This _reparameterize_module makes sure inputs and module.params/buffers have the same fake_mode,
    # otherwise aot_export_module will error out because it sees a mix of fake_modes.
    # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
    with ExitStack() as stack:
        stack.enter_context(
            torch.nn.utils.stateless._reparametrize_module(
                mod,
                fake_params_buffers,
                tie_weights=True,
                strict=True,
                stack_weights=True,
            )
        )
        stack.enter_context(_ignore_backend_decomps())
        stack.enter_context(_compiling_state_context())
        gm, graph_signature = transform(_make_fx_helper)(
            stack,
            mod,
            fake_args,
            trace_joint=False,
            kwargs=fake_kwargs,
        )

        # [NOTE] In training IR, we don't run
        # any DCE as a result we preserve constant
        # nodes in the graph. make_fx invariant is that
        # they don't guarantee every node gets a meta['val']
        # field. Since the actual value is already hardcoded in
        # graph, the node.meta here actually doesn't matter. But
        # we do this to make spec verifier happy.
        for node in gm.graph.nodes:
            if (
                node.op == "call_function"
                and len(node.users) == 0
                and "val" not in node.meta
            ):
                node.meta["val"] = None

        if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
            gm.meta.update(mod.meta)

    # See comment in _export_to_aten_ir()
    if produce_guards_callback:
        try:
            produce_guards_callback(gm)
        except (ConstraintViolationError, ValueRangeError) as e:
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904

    return _produce_aten_artifact(
        gm=gm,
        mod=mod,
        constant_attrs=constant_attrs,
        graph_signature=graph_signature,
        pre_dispatch=True,
        fake_args=fake_args,
        fake_kwargs=fake_kwargs,
        fake_params_buffers=fake_params_buffers,
    )
def set_missing_meta_vals(gm, flat_args, num_params_buffers):
    """
    Copy non-tensor user inputs into ``node.meta["val"]`` of their placeholders.

    Placeholders are visited in graph order. The first ``num_params_buffers``
    of them (parameters/buffers) are left untouched. Each remaining placeholder
    is paired positionally with ``flat_args``. If the paired user argument is
    not a ``torch.Tensor`` (e.g. a specialized int), it overwrites the
    placeholder's ``"val"`` metadata. Tensor arguments and non-placeholder
    nodes are not modified.
    """
    placeholders = (n for n in gm.graph.nodes if n.op == "placeholder")
    for position, placeholder in enumerate(placeholders):
        # Leading placeholders belong to lifted params/buffers.
        if position < num_params_buffers:
            continue
        user_value = flat_args[position - num_params_buffers]
        if isinstance(user_value, torch.Tensor):
            continue
        placeholder.meta["val"] = user_value
- def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node:
- return next(iter(node for node in gm.graph.nodes if node.name == name))
def _non_strict_export(
    mod: torch.nn.Module,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None,
    preserve_module_call_signature: tuple[str, ...],
    orig_in_spec: TreeSpec,
    prefer_deferred_runtime_asserts_over_guards: bool,
    _to_aten_func: Callable,
) -> ExportArtifact:
    """
    _to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir`

    Non-strict export pipeline:
      1. Build fake inputs (``make_fake_inputs``) and fake params/buffers.
      2. Under the fake mode, a ``TracingContext`` and a patched dynamo config,
         call ``_to_aten_func`` on ``mod`` (with script objects fakified).
         ``transform=_tuplify_outputs`` wraps ``mod`` in a ``Wrapper``, which
         records the input/output pytree specs while tracing.
      3. Map fake script-object constants back to the real objects and move
         non-persistent buffers to tensor constants.

    ``orig_in_spec`` is accepted for signature parity with the strict path; it
    is not read here.

    Returns:
        An ``ExportArtifact`` whose in/out specs were captured during tracing.

    Raises:
        AssertionError: if tracing finished without recording ``in_spec`` or
            ``out_spec``.
    """
    # Written by Wrapper.forward (through ``nonlocal``) while the module is being
    # traced; read back when building the ExportArtifact at the end.
    out_spec: TreeSpec | None = None
    in_spec: TreeSpec | None = None
    # Passed to _wrap_submodules, which presumably records the call specs of the
    # preserved submodules into it (TODO confirm); returned as module_call_specs.
    module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = {}

    def _tuplify_outputs(aot_export):
        # Wraps a tracing function (``aot_export``) so that it traces a Wrapper
        # around ``mod`` which flattens outputs into a tuple, then rewrites the
        # resulting signature / nn_module_stack metadata with _strip_root
        # (presumably removing the "_export_root" prefix — TODO confirm).
        def _aot_export_non_strict(stack, mod, args, *, kwargs=None, **flags):
            kwargs = kwargs or {}

            class Wrapper(torch.nn.Module):
                def __init__(self, mod):
                    super().__init__()
                    self._export_root = mod

                def forward(self, *args, **kwargs):
                    nonlocal out_spec
                    nonlocal in_spec
                    mod = self._export_root
                    # Record the input pytree structure seen during tracing.
                    _, in_spec = pytree.tree_flatten((args, kwargs))
                    if isinstance(mod, torch.fx.GraphModule):
                        # NOTE: We're going to run this graph module with an fx interpreter,
                        # which will not run any forward hooks. Thus, ideally, we should run
                        # all forward hooks here. But the general logic for running them is
                        # complicated (see nn/module.py), and probably not worth duplicating.
                        # Instead we only look for, and run, an export-specific forward hook.
                        if (
                            _check_input_constraints_pre_hook
                            in mod._forward_pre_hooks.values()
                        ):
                            _check_input_constraints_pre_hook(mod, args, kwargs)
                        with torch.fx.traceback.preserve_node_meta():
                            # kwargs are appended positionally in dict order.
                            args = (*args, *kwargs.values())
                            tree_out = torch.fx.Interpreter(mod).run(*args)
                    else:
                        tree_out = mod(*args, **kwargs)
                    # Record the output pytree structure and return a flat tuple.
                    flat_outs, out_spec = pytree.tree_flatten(tree_out)
                    return tuple(flat_outs)

            wrapped_mod = Wrapper(mod)
            # Patch export_root to the signatures so that wrapper module correctly populates the
            # in/out spec
            new_preserved_call_signatures = [
                "_export_root." + i for i in preserve_module_call_signature
            ]
            ctx = nullcontext()
            if not isinstance(mod, torch.fx.GraphModule):
                ctx = _wrap_submodules(  # type: ignore[assignment]
                    wrapped_mod, new_preserved_call_signatures, module_call_specs
                )
            with ctx:
                gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags)
                log.debug("Exported program from AOTAutograd:\n%s", gm)

            # Rewrite every FQN-bearing field of the signature with _strip_root.
            sig.parameters = pytree.tree_map(_strip_root, sig.parameters)
            sig.buffers = pytree.tree_map(_strip_root, sig.buffers)
            sig.inputs_to_buffers = pytree.tree_map(_strip_root, sig.inputs_to_buffers)
            sig.inputs_to_parameters = pytree.tree_map(
                _strip_root, sig.inputs_to_parameters
            )
            sig.buffers_to_mutate = pytree.tree_map(_strip_root, sig.buffers_to_mutate)
            sig.parameters_to_mutate = pytree.tree_map(
                _strip_root, sig.parameters_to_mutate
            )

            # Apply the same rewrite to nn_module_stack metadata (values via
            # _strip_root, keys via _fixup_key).
            for node in gm.graph.nodes:
                if "nn_module_stack" in node.meta:
                    nn_module_stack = node.meta["nn_module_stack"]
                    node.meta["nn_module_stack"] = {
                        _fixup_key(key): val
                        for key, val in pytree.tree_map(
                            _strip_root, nn_module_stack
                        ).items()
                    }

            return gm, sig

        return _aot_export_non_strict

    # NOTE: We need to enter _compiling_state_context() here so that FakeTensors
    # created for params/buffers are properly tracked for leak detection.
    # See detect_non_strict_fake_tensor_leaks config.
    # We only enter the context if leak detection is enabled to avoid changing
    # behavior when the config is OFF.
    _fakify_ctx = (
        _compiling_state_context()
        if torch._export.config.detect_non_strict_fake_tensor_leaks
        else nullcontext()
    )
    with _fakify_ctx:
        (
            fake_mode,
            fake_args,
            fake_kwargs,
            equalities_inputs,
            original_signature,
            dynamic_shapes,
        ) = make_fake_inputs(
            mod,
            args,
            kwargs,
            dynamic_shapes,
            prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,  # for shape env initialization
        )

        fake_params_buffers = _fakify_params_buffers(fake_mode, mod)

    # Handed to _to_aten_func, which calls it on the traced graph module.
    def _produce_guards_callback(gm):
        return produce_guards_and_solve_constraints(
            fake_mode=fake_mode,
            gm=gm,
            dynamic_shapes=dynamic_shapes,
            equalities_inputs=equalities_inputs,
            original_signature=original_signature,
        )

    tx = TracingContext(fake_mode)

    # We also need to attach dynamo configs as these will be used in HOOs that
    # use torch.compile, like cond
    dynamo_config = dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)
    dynamo_config["do_not_emit_runtime_asserts"] = (
        False  # We want to emit runtime asserts
    )

    with (
        fake_mode,
        _NonStrictTorchFunctionHandler(),
        tracing(tx),
        torch._dynamo.config.patch(dynamo_config),
    ):
        with (
            _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as (
                patched_mod,
                new_fake_args,
                new_fake_kwargs,
                new_fake_constant_attrs,
                map_fake_to_real,
            ),
            _fakify_module_inputs(fake_args, fake_kwargs, fake_mode),
            _override_builtin_ops(),
        ):
            # _to_aten_func is _export_to_aten_ir when using the default non-strict export
            # We need to pass positional args correctly
            aten_export_artifact = _to_aten_func(
                patched_mod,
                new_fake_args,
                new_fake_kwargs,
                fake_params_buffers,
                new_fake_constant_attrs,
                produce_guards_callback=_produce_guards_callback,
                transform=_tuplify_outputs,
            )
            # aten_export_artifact.constants contains only fake script objects, we need to map them back
            aten_export_artifact.constants = {
                fqn: map_fake_to_real[obj] if isinstance(obj, FakeScriptObject) else obj
                for fqn, obj in aten_export_artifact.constants.items()
            }

            _move_non_persistent_buffers_to_tensor_constants(
                mod, aten_export_artifact.sig, aten_export_artifact.constants
            )

    # Both specs must have been recorded by Wrapper.forward during tracing.
    if out_spec is None:
        raise AssertionError("out_spec must not be None")
    if in_spec is None:
        raise AssertionError("in_spec must not be None")

    return ExportArtifact(
        aten=aten_export_artifact,
        in_spec=in_spec,
        out_spec=out_spec,
        fake_mode=fake_mode,
        module_call_specs=module_call_specs,
    )
@_log_export_wrapper
@_disable_prexisiting_fake_mode
def _export_for_training(
    mod: torch.nn.Module,
    args: tuple[Any, ...],
    kwargs: dict[str, Any] | None = None,
    dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None = None,
    *,
    strict: bool = True,
    preserve_module_call_signature: tuple[str, ...] = (),
    prefer_deferred_runtime_asserts_over_guards: bool = False,
) -> ExportedProgram:
    """
    Export ``mod`` to a training-IR ``ExportedProgram``.

    Tracing goes through ``_strict_export`` or ``_non_strict_export`` (chosen by
    ``strict``). In both cases ``_export_to_aten_ir_make_fx`` is used for lowering.
    The result is checked with ``TrainingIRVerifier`` and ``verify_additional_inputs``.

    Side effects:
        - Sets the module-global ``_EXPORT_MODULE_HIERARCHY``.
        - In non-strict mode with ``detect_non_strict_fake_tensor_leaks`` enabled,
          clears the fake-tensor tracker before tracing and warns about fake
          tensors still alive afterwards.

    Raises:
        RuntimeError: in non-strict mode without ambient fake inputs, if a fake
            tensor constant with a bogus name is found and
            ``torch._export.config.error_on_lifted_constant_tensors`` is set
            (otherwise a warning is emitted).
    """
    global _EXPORT_MODULE_HIERARCHY
    _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)

    (
        args,
        kwargs,
        orig_in_spec,
        dynamic_shapes,
        verify_additional_inputs,
    ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes)

    original_state_dict = _get_original_state_dict(mod)

    # Non-strict only: remember whether the caller already passed fake inputs,
    # in which case fake tensor constants are expected (see check below).
    has_ambient_mode = False
    if not strict:
        flat_args, _ = pytree.tree_flatten((args, kwargs))
        has_ambient_mode = torch._guards.detect_fake_mode(flat_args) is not None

    # Call the appropriate export function based on the strictness of tracing.
    export_func = _strict_export if strict else _non_strict_export

    if not strict and torch._export.config.detect_non_strict_fake_tensor_leaks:
        from torch._subclasses.fake_tensor import fake_tensor_tls

        # Start leak tracking from a clean slate for this export.
        fake_tensor_tls.non_strict_export_fake_tensor_tracker.clear()

    export_artifact = export_func(
        mod=mod,
        args=args,
        kwargs=kwargs,
        dynamic_shapes=dynamic_shapes,
        preserve_module_call_signature=preserve_module_call_signature,
        orig_in_spec=orig_in_spec,
        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
        _to_aten_func=_export_to_aten_ir_make_fx,
    )

    # If we are tracing with fake inputs, it is expected to
    # see fake tensor constants.
    if not strict and not has_ambient_mode:
        for const, val in export_artifact.aten.constants.items():
            if isinstance(
                val, torch._subclasses.fake_tensor.FakeTensor
            ) and _is_bogus_const_name(const):
                error_msg = (
                    f"We found a fake tensor in the exported program constant's list. "
                    f"This typically means our tracing system encountered an op that "
                    f"we can't trace through. For the potential source, you can refer to "
                    f"following model attribute: {const}. "
                    f"Please file an issue on github. "
                )
                if torch._export.config.error_on_lifted_constant_tensors:
                    raise RuntimeError(error_msg)
                else:
                    warnings.warn(error_msg, stacklevel=2)

    export_graph_signature = export_artifact.aten.sig

    forward_arg_names = _get_forward_arg_names(mod, args, kwargs)
    inline_constraints = _get_inline_constraints(export_artifact.fake_mode)
    # The unbacked symint symbols are updated in aot_export
    # so we serialize them here instead of inside dynamo.
    # Note: _get_range_constraints depends on "inline_constraints" to be set.
    export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints
    range_constraints = _get_range_constraints(
        mod,
        export_artifact,
        args,
        kwargs,
        dynamic_shapes,
    )
    # The returned the gm is in-place modified
    gm, module_call_graph = _get_module_call_graph(
        export_artifact,
        preserve_module_call_signature,
        strict,
        forward_arg_names,
    )

    _verify_nn_module_stack(gm)
    _verify_stack_trace(gm)
    _verify_placeholder_names(gm, export_graph_signature)

    _update_gm_meta_if_possible(gm, mod)

    from torch._export.verifier import TrainingIRVerifier

    exported_program = ExportedProgram(
        root=gm,
        graph=gm.graph,
        graph_signature=export_graph_signature,
        state_dict=original_state_dict,
        range_constraints=range_constraints,
        module_call_graph=module_call_graph,
        example_inputs=(args, kwargs),
        constants=export_artifact.aten.constants,
        verifiers=[TrainingIRVerifier],
    )

    verify_additional_inputs(exported_program)

    if not strict and torch._export.config.detect_non_strict_fake_tensor_leaks:
        # See NOTE [export non-strict fake tensor leak detection]
        from torch._subclasses.fake_tensor import fake_tensor_tls
        from torch.fx.experimental.proxy_tensor import (
            _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT,
        )

        active_fakes = fake_tensor_tls.non_strict_export_fake_tensor_tracker
        legit_leak: weakref.WeakSet = find_legit_leaks_from_referrers(active_fakes)
        leak_sources: list[str] = []
        if len(legit_leak) > 0:
            # Build one human-readable description per leaked fake tensor,
            # pointing at the graph node (and its stack trace) that produced it
            # when a proxy mapping is available.
            for fake_val in legit_leak:
                if id(fake_val) in _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT:
                    node = _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT[id(fake_val)]
                    stack_trace = node.meta.get("stack_trace")
                    node_name = node.name

                    # If no stack trace on this node (e.g., placeholder), look at users
                    if stack_trace is None:
                        for user in node.users:
                            user_stack = user.meta.get("stack_trace")
                            if user_stack is not None:
                                stack_trace = f"Used by '{user.name}':\n{user_stack}"
                                break

                    stack_trace = (
                        "<no stack trace available>"
                        if stack_trace is None
                        else stack_trace
                    )

                    # Get shape and dtype info
                    shape_info = f"shape={fake_val.shape}, dtype={fake_val.dtype}"
                    leak_info = f"FakeTensor({shape_info}) from node '{node_name}':\n{stack_trace}"
                    leak_sources.append(leak_info)
                else:
                    # Fallback: no proxy mapping found, show basic info
                    shape_info = f"shape={fake_val.shape}, dtype={fake_val.dtype}"
                    leak_info = f"FakeTensor({shape_info}): <no proxy mapping found>"
                    leak_sources.append(leak_info)

            # Format the warning message more nicely
            leak_details = "\n ".join(leak_sources)
            warnings.warn(
                f"Detected {len(legit_leak)} fake tensors that are still alive after export.\n"
                f"This is likely result of torch.export.export not being able to track side effects "
                f"that is happening outside of model scope.\n\n"
                f"Leaked tensors:\n {leak_details}\n\n"
                f"Alternatively, please file a bug report to PyTorch team for further debugging help.",
                stacklevel=2,
            )

        # Drop the local reference to the leaked set before returning.
        del legit_leak

    return exported_program
@_log_export_wrapper
@_disable_prexisiting_fake_mode
@compile_time_strobelight_meta(phase_name="export")
def _export(
    mod: torch.nn.Module,
    args: tuple[Any, ...],
    kwargs: dict[str, Any] | None = None,
    dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None = None,
    *,
    strict: bool = True,
    preserve_module_call_signature: tuple[str, ...] = (),
    pre_dispatch: bool = False,
    prefer_deferred_runtime_asserts_over_guards: bool = False,
) -> ExportedProgram:
    """
    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside and produces an ExportedProgram.

    Args:
        mod: the `nn.Module` to trace.

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes:
         An optional argument where the type should either be:
         1) a dict from argument names of ``f`` (the traced callable) to their dynamic shape specifications,
         2) a tuple that specifies dynamic shape specifications for each input in original order.
         If you are specifying dynamism on keyword args, you will need to pass them in the order that
         is defined in the original function signature.

         The dynamic shape of a tensor argument can be specified as either
         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
         not required to include static dimension indices in this dict, but when they are,
         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
         recursively specified by using mappings or sequences of contained specifications.

        strict: if True, trace with ``_strict_export``; otherwise with ``_non_strict_export``.

        preserve_module_call_signature: A list of submodule paths for which the original
         calling conventions are preserved as metadata.

        pre_dispatch: if True and ``export_training_ir_rollout_check()`` returns True,
         the call is routed to ``_export_for_training``. Otherwise the flag is
         forwarded to ``_export_to_aten_ir``.

        prefer_deferred_runtime_asserts_over_guards:
         With the current dynamic shapes language for dims and derived dims, we can run into constraints
         that are not expressible with the language. For example, flattening a matrix and adding to a vector,
         both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible.
         By default, we either raise a constraint violation error or specialize to static values.
         If this flag is set to True, we avoid erroring out and instead allow complex constraints to exist as runtime
         assertions in the graph. The sympy interpreter (torch/utils/_sympy/interp.py) will produce the math ops
         required to compute and assert the value of the guard (e.g. sym_size_int, eq, _assert_scalar).
         Additionally, if TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 is specified, we will allow complex constraints
         while not emitting runtime asserts, returning a cleaner graph with lesser guarantees around dynamic shapes.

    Returns:
        An ExportedProgram containing the traced module.

    Side effects:
        Sets the module globals ``_EXPORT_FLAGS`` and ``_EXPORT_MODULE_HIERARCHY``
        and emits usage/structured-trace logging.
    """

    from torch._utils_internal import export_training_ir_rollout_check

    global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY
    _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)

    # Record which export flavor is running, for usage logging.
    flags = set()
    flags.add("strict" if strict else "non_strict")
    flags.add("pre_dispatch" if pre_dispatch else "aot_dispatch")
    _EXPORT_FLAGS = flags

    log_export_usage(event="export.enter", flags=_EXPORT_FLAGS)

    dtrace_structured("export", payload_fn=lambda: "start!")

    # NOTE Export training IR rollout
    # Old export calls export._trace(pre_dispatch=True)
    # and there are still lot of internal/OSS callsites that
    # use export._trace(pre_dispatch=True) directly. Therefore,
    # it makes more sense to do the switch here.
    # export_training_ir_rollout_check returns True in OSS
    # while internally it returns False UNLESS otherwise specified.
    if pre_dispatch and export_training_ir_rollout_check():
        ep = _export_for_training(
            mod,
            args,
            kwargs,
            dynamic_shapes,
            strict=strict,
            preserve_module_call_signature=preserve_module_call_signature,
            prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
        )
        dtrace_structured("exported_program", payload_fn=lambda: str(ep))
        return ep

    (
        args,
        kwargs,
        original_in_spec,
        dynamic_shapes,
        verify_additional_inputs,
    ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes)

    original_state_dict = _get_original_state_dict(mod)

    # Call the appropriate export function based on the strictness of tracing.
    export_func = _strict_export if strict else _non_strict_export

    # Lowering goes through _export_to_aten_ir with the caller's pre_dispatch flag.
    export_artifact = export_func(  # type: ignore[operator]
        mod=mod,
        args=args,
        kwargs=kwargs,
        dynamic_shapes=dynamic_shapes,
        preserve_module_call_signature=preserve_module_call_signature,
        orig_in_spec=original_in_spec,
        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
        _to_aten_func=functools.partial(
            _export_to_aten_ir,
            pre_dispatch=pre_dispatch,
        ),
    )
    export_graph_signature: ExportGraphSignature = export_artifact.aten.sig

    forward_arg_names = _get_forward_arg_names(mod, args, kwargs)
    inline_constraints = _get_inline_constraints(export_artifact.fake_mode)
    # The unbacked symint symbols are updated in aot_export
    # so we serialize them here instead of inside dynamo.
    # Note: this step must be before _get_range_constraints.
    export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints
    range_constraints = _get_range_constraints(
        mod,
        export_artifact,
        args,
        kwargs,
        dynamic_shapes,
    )
    gm, module_call_graph = _get_module_call_graph(
        export_artifact,
        preserve_module_call_signature,
        strict,
        forward_arg_names,
    )

    _verify_nn_module_stack(gm)
    _verify_stack_trace(gm)
    _verify_placeholder_names(gm, export_graph_signature)

    # Remove Proxy because they cannot be deepcopied or pickled.
    torch._export.utils.remove_proxy_from_state_dict(original_state_dict, in_place=True)

    from torch._export.verifier import Verifier

    _update_gm_meta_if_possible(gm, mod)

    exported_program = ExportedProgram(
        root=gm,
        graph=gm.graph,
        graph_signature=export_graph_signature,
        state_dict=original_state_dict,
        range_constraints=range_constraints,
        module_call_graph=module_call_graph,
        example_inputs=(args, kwargs),
        constants=export_artifact.aten.constants,
        verifiers=[Verifier],
    )

    dtrace_structured("exported_program", payload_fn=lambda: str(exported_program))

    verify_additional_inputs(exported_program)
    return exported_program
|