| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899 |
- # mypy: allow-untyped-defs
- import abc
- import copy
- import logging
- import operator
- import re
- from collections import defaultdict
- from collections.abc import Callable
- from contextlib import contextmanager
- from copy import deepcopy
- from dataclasses import dataclass
- from enum import Enum
- from typing import Any, cast
- import torch
- import torch.fx._pytree as fx_pytree
- import torch.utils._pytree as pytree
- from torch._library.fake_class_registry import FakeScriptObject
- from torch.export import ExportedProgram
- from torch.export._tree_utils import reorder_kwargs
- from torch.export.exported_program import (
- ConstantArgument,
- ExportGraphSignature,
- InputKind,
- ModuleCallSignature,
- SymBoolArgument,
- SymFloatArgument,
- SymIntArgument,
- TensorArgument,
- )
- from torch.fx._symbolic_trace import is_fx_symbolic_tracing
- from torch.fx.graph_module import _get_attr, _get_attr_via_attr_list, _print_readable
- from torch.utils._pytree import GetAttrKey, SequenceKey
- from ._remove_effect_tokens_pass import _remove_effect_tokens
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)

# Public names exported by ``from <module> import *``.
__all__ = [
    "FlatArgsAdapter",
    "InterpreterModule",
    "InterpreterModuleDispatcher",
    "UnflattenedModule",
    "unflatten",
]
class _AttrKind(Enum):
    """Kind of attribute being installed on a module by ``_assign_attr``."""

    # Registered through ``register_parameter``.
    PARAMETER = "parameter"
    # Registered through ``register_buffer``.
    BUFFER = "buffer"
    # Tensor / script-object constants, installed with plain ``setattr``.
    CONSTANT = "constant"
    # Whole submodules, installed with plain ``setattr``.
    MODULE = "module"
@dataclass(frozen=True)
class _TensorID:
    """Hashable tensor identifier.

    Bundles a tensor's untyped storage together with its stride, size and
    storage offset so it can be used as a dictionary key.
    """

    # Storage object backing the tensor.
    untyped_storage: torch.UntypedStorage
    # Value of ``tensor.stride()``.
    stride: tuple
    # Value of ``tensor.size()``.
    size: tuple
    # Value of ``tensor.storage_offset()``.
    storage_offset: int
# Module-wide default; read into ``_run_with_interpreter`` when the modules in
# this file are constructed.
RUN_WITH_INTERPRETER = True


@contextmanager
def _disable_interpreter():
    """Temporarily set ``RUN_WITH_INTERPRETER`` to ``False``.

    The value that was in effect on entry is restored on exit, including when
    the managed body raises.
    """
    global RUN_WITH_INTERPRETER
    previous = RUN_WITH_INTERPRETER
    RUN_WITH_INTERPRETER = False
    try:
        yield
    finally:
        RUN_WITH_INTERPRETER = previous
# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module'.
# This installs empty Modules where none exist yet if they are subpaths of target.
def _assign_attr(
    from_obj: torch.Tensor | torch.ScriptObject | torch.nn.Module,
    to_module: torch.nn.Module,
    target: str,
    attr_kind: _AttrKind,
    persistent: bool = True,
):
    """Install ``from_obj`` under the dotted name ``target`` beneath ``to_module``.

    Args:
        from_obj: Object to install.
        to_module: Root module that ``target`` is resolved against.
        target: Dotted attribute path; the last component is the attribute
            name, the preceding components name (possibly missing) submodules.
        attr_kind: Selects how the object is installed (parameter, buffer,
            constant or module).
        persistent: Forwarded to ``register_buffer`` for ``BUFFER`` kind.

    Raises:
        AssertionError: If ``from_obj`` does not have the type required by
            ``attr_kind``.
    """
    *path, leaf = target.split(".")

    # Walk the path one component at a time, tracking every submodule whose
    # registered name matches the component according to `_is_call_name`
    # (i.e. the component itself and its call-name variants). Each of them
    # receives `from_obj` at `leaf` so that they all share this attribute.
    # For example, if target is foo.bar.f, foo has another call name foo@1,
    # and bar has other call names bar@1, bar@2, then f is assigned to
    # foo.bar, foo.bar@1, foo.bar@2, foo@1.bar, foo@1.bar@1, foo@1.bar@2.
    owners = {to_module}
    for component in path:
        matched: set[torch.nn.Module] = set()
        for parent in owners:
            # Create an empty placeholder module when the path does not exist yet.
            if not hasattr(parent, component):
                setattr(parent, component, torch.nn.Module())
            matched.update(
                child  # type: ignore[misc]
                for child_name, child in parent._modules.items()
                if _is_call_name(child_name, component)
            )
        owners = matched

    for owner in owners:
        if attr_kind == _AttrKind.PARAMETER:
            if not isinstance(from_obj, torch.nn.Parameter):
                raise AssertionError(
                    f"expected torch.nn.Parameter for PARAMETER attr_kind, got {type(from_obj)}"
                )
            owner.register_parameter(leaf, from_obj)
        elif attr_kind == _AttrKind.BUFFER:
            if not isinstance(from_obj, torch.Tensor):
                raise AssertionError(
                    f"expected torch.Tensor for BUFFER attr_kind, got {type(from_obj)}"
                )
            owner.register_buffer(leaf, from_obj, persistent=persistent)
        elif attr_kind == _AttrKind.CONSTANT:
            if isinstance(from_obj, FakeScriptObject):
                raise AssertionError(
                    "FakeScriptObject should only exist during tracing."
                )
            if not isinstance(from_obj, (torch.Tensor, torch.ScriptObject)):
                raise AssertionError(
                    f"expected torch.Tensor or torch.ScriptObject for CONSTANT attr_kind, got {type(from_obj)}"
                )
            setattr(owner, leaf, from_obj)
        elif attr_kind == _AttrKind.MODULE:
            if not isinstance(from_obj, torch.nn.Module):
                raise AssertionError(
                    f"expected torch.nn.Module for MODULE attr_kind, got {type(from_obj)}"
                )
            setattr(owner, leaf, from_obj)
- class _SubmoduleBase:
- _ty: str | None
- def type_name(self) -> str | None:
- """
- Subclass of this class - InterpreterModule, InterpreterModuleDispatcher, represents
- corresponding model in eager model. To get this type information for those modules
- in eager model we need to use this method.
- """
- return self._ty
class InterpreterModule(_SubmoduleBase, torch.nn.Module):
    """A module that uses torch.fx.Interpreter to execute instead of the usual
    codegen that GraphModule uses. This provides better stack trace information
    and makes it easier to debug execution.

    ``finalize()`` must be called once the graph is fully constructed and
    before the module is run.
    """

    # Populated by finalize(); stored directly in __dict__ so it is not
    # registered as a submodule.
    graph_module: torch.fx.GraphModule | None

    def __init__(
        self,
        graph: torch.fx.Graph,
        ty: str | None = None,
    ):
        """
        Args:
            graph: The FX graph this module executes. Its ``owning_module`` is
                set to this module.
            ty: Type name reported by ``type_name()``.
        """
        super().__init__()
        self.graph = graph
        self._ty = ty
        self.graph.owning_module = self  # type: ignore[assignment]
        # Snapshot of the module-level flag at construction time.
        self._run_with_interpreter = RUN_WITH_INTERPRETER

    def forward(self, *args, **kwargs):
        """Run the graph, through the interpreter or the generated forward.

        Raises:
            AssertionError: If ``finalize()`` has not been called, or if the
                given kwargs do not line up with the graph's placeholders.
        """
        # `graph_module` only exists after finalize() puts it in __dict__; a
        # plain attribute access before that raises AttributeError from
        # nn.Module.__getattr__, which would mask the intended error below.
        graph_module = getattr(self, "graph_module", None)
        if graph_module is None:
            raise AssertionError("Didn't finalize this InterpreterModule")
        if not is_fx_symbolic_tracing() and (
            torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter
        ):
            # Dynamo cannot trace through torch.fx.Interpreter, so fall back to
            # GraphModule codegen in this instance.
            # Patch the codegened forward to run with this InterpreterModule,
            # so attribute accesses, etc. are on this module instead.
            return type(graph_module).forward(self, *args, **kwargs)
        else:
            if kwargs:
                # Handle **kwargs. FX only natively supports positional
                # arguments (through placeholders). So in order to pass in
                # kwargs, we must correspond the names of the placeholders with
                # the keys in the kwarg dict.
                arg_list = list(args)
                kwarg_names = self.arg_names[len(arg_list) :]
                arg_list.extend(
                    kwargs[kwarg_name]
                    for kwarg_name in kwarg_names
                    if kwarg_name in kwargs
                )

                # Assert that the kwargs passed in exactly match the positional
                # arguments specified by the GraphModule. This should be
                # guaranteed by the unflattening process.
                if len(kwarg_names) != len(kwargs):
                    raise AssertionError(
                        f"kwarg_names length {len(kwarg_names)} does not match kwargs length {len(kwargs)}"
                    )
                if len(arg_list) != len(self.arg_names):
                    raise AssertionError(
                        f"arg_list length {len(arg_list)} does not match arg_names length {len(self.arg_names)}"
                    )
                args = tuple(arg_list)

            return torch.fx.Interpreter(self, graph=self.graph).run(
                *args, enable_io_processing=False
            )

    def finalize(self):
        """Build ``graph_module`` from the current graph and cache arg names."""
        # We need to "finalize" because GraphModule populates its own state_dict
        # based on the get_attrs observed in the graph. So we need to fully
        # construct the graph and call _sink_params before generating this
        # GraphModule.

        # need to set `graph_module` directly on the dict to avoid it getting
        # registered as a submodule.
        self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph)
        self.graph.lint()

        # Cache arg names for kwarg handling (see forward())
        self.arg_names = [
            node.target for node in self.graph.nodes if node.op == "placeholder"
        ]

    def print_readable(
        self,
        print_output=True,
        include_stride=False,
        include_device=False,
        colored=False,
    ):
        """Delegate to ``_print_readable`` with the label "InterpreterModule"."""
        return _print_readable(
            self,
            "InterpreterModule",
            print_output,
            include_stride,
            include_device,
            colored,
        )
class InterpreterModuleDispatcher(_SubmoduleBase, torch.nn.Module):
    """
    Holds one InterpreterModule per recorded call of a module and routes each
    invocation to the next one in sequence, wrapping around after the last.
    """

    def __init__(self, attrs: set[str], call_modules: list[InterpreterModule]):
        """
        Args:
            attrs: Attribute names copied from the first call module onto
                this dispatcher.
            call_modules: Non-empty list of modules, one per call.

        Raises:
            AssertionError: If ``call_modules`` is empty.
        """
        super().__init__()
        if not call_modules:
            raise AssertionError("call_modules must not be empty")
        first = call_modules[0]
        # Share the first call module's submodule dict and selected attributes.
        self._modules = first._modules
        for accessor in attrs:
            setattr(self, accessor, getattr(first, accessor))
        self._ty = first._ty
        self._call_modules = call_modules
        # Index of the call module that the next forward() will use.
        self._num_calls = 0

    def forward(self, *args, **kwargs):
        """Invoke the current call module and advance to the next one.

        If the call raises, the position is reset to the first call module
        before the exception propagates.
        """
        current = self._num_calls
        target = self._call_modules[current]
        self._num_calls = (current + 1) % len(self._call_modules)
        try:
            return target(*args, **kwargs)
        except Exception:
            self._num_calls = 0
            raise

    def call_modules(self):
        """Return the list of per-call modules."""
        return self._call_modules

    def print_readable(
        self,
        print_output=True,
        include_stride=False,
        include_device=False,
        colored=False,
    ):
        """Return the readable form of every call module, newline-joined."""
        rendered = []
        for call_module in self._call_modules:
            rendered.append(
                call_module.print_readable(
                    print_output,
                    include_stride,
                    include_device,
                    colored,
                )
            )
        return "\n".join(rendered)
- class FlatArgsAdapter(abc.ABC):
- """
- Adapts input arguments with ``input_spec`` to align ``target_spec``.
- """
- @abc.abstractmethod
- def adapt(
- self,
- target_spec: pytree.TreeSpec,
- input_spec: pytree.TreeSpec,
- input_args: list[Any],
- metadata: dict[str, Any] | None = None,
- obj: Any | None = None,
- ) -> list[Any]:
- """NOTE: This adapter may mutate given ``input_args_with_path``."""
- ...
- def get_flat_arg_paths(self) -> list[str]:
- """Returns a list of paths that are used to access the flat args."""
- return []
class UnflattenedModule(_SubmoduleBase, torch.nn.Module):
    """Module built from an ExportedProgram by outlining its flat graph.

    Construction deep-copies the exported graph, outlines it into submodules,
    re-installs parameters, buffers and constants as module attributes, and
    finally builds a ``graph_module`` for the root graph.
    """

    def __init__(
        self,
        export_module: ExportedProgram,
        flat_args_adapter: FlatArgsAdapter | None = None,
    ):
        """
        Args:
            export_module: The exported program to unflatten.
            flat_args_adapter: Optional adapter used by ``forward`` when the
                input tree spec differs from the exported one.

        Raises:
            ValueError: If the graph signature carries a backward signature.
            AssertionError: If the module call graph or the constants are not
                in the expected shape.
        """
        super().__init__()
        if export_module.graph_signature.backward_signature is not None:
            raise ValueError("Unflattening on JointExportModule NYI")

        def _id(obj):
            """Returns _TensorID dataclass for tensors, otherwise id()."""
            if isinstance(obj, torch.Tensor):
                return _TensorID(
                    untyped_storage=obj.untyped_storage(),
                    stride=obj.stride(),
                    size=obj.size(),
                    storage_offset=obj.storage_offset(),  # type: ignore[arg-type]
                )
            return id(obj)

        fqn_list = [entry.fqn for entry in export_module.module_call_graph]
        if fqn_list[0] != "":
            raise AssertionError(
                f"expected first fqn to be empty string, got {fqn_list[0]!r}"
            )
        export_graph = deepcopy(export_module.graph)
        self.graph_signature = deepcopy(export_module.graph_signature)
        self.graph = torch.fx.Graph()
        self.graph.owning_module = self  # type: ignore[assignment]
        self.module_call_graph = deepcopy(export_module.module_call_graph)
        self.flat_args_adapter = flat_args_adapter

        # NOTE: this is the exported program's own meta dict (not a copy), so
        # the back-reference below is visible on the exported program as well.
        self.meta = export_module.graph_module.meta
        self.meta["unflattened_module"] = self

        # Flag to indicate whether args have been adapted.
        self.adapted = False
        self._run_with_interpreter = RUN_WITH_INTERPRETER

        _inplace_buffer_and_input_mutations(export_graph, self.graph_signature)
        _fix_nn_module_stacks(export_graph)

        self._ty = _root_module_type(export_graph)

        self.ivals = _IVals()
        # for any intermediate value of a mutation that is read, track the mutation
        seen_modules, seen_attrs = _outline_submodules(export_graph, self)
        # for each read intermediate value of a mutation, find where it was created,
        # and perform the mutation
        self.ivals.update(seen_modules.values())
        # move attributes that correspond to graph arguments for HOPs
        # from exported program to unflattened submodules
        _copy_graph_attrs(export_module._graph_module, self, seen_attrs)

        self.range_constraints = export_module.range_constraints
        self.equality_constraints: list = []

        # aliasing/unused param or buffer issues:
        # in strict-mode export, dynamo export will deduplicate aliased tensors,
        # and ignore unused tensors. For aliasing, this causes issues when some aliases
        # are unused, and we're unable to match the placeholder node to the correct FQN.
        # This leads to the graph signature potentially having the wrong target FQN,
        # and downstream issues where parameters are assigned to the wrong target attribute,
        # mismatching the relevant placeholder node in the unflattened module.
        # To resolve this we restore (_assign_attr) all aliased/unused tensors in
        # the state_dict as module attributes, but only keep the used tensors in the
        # graph's forward pass (_sink_params).
        state_dict = export_module.state_dict
        assigned_params: set[str] = set()  # tracking unused params
        id_to_param: dict[
            int | _TensorID, torch.nn.Parameter
        ] = {}  # handling weight-sharing
        for name in self.graph_signature.parameters:  # this loop adds used params
            param = state_dict[name]
            if _id(param) not in id_to_param:
                id_to_param[_id(param)] = torch.nn.Parameter(
                    param.clone(), requires_grad=param.requires_grad
                )

            _assign_attr(
                id_to_param[_id(param)],
                self,
                name,
                attr_kind=_AttrKind.PARAMETER,
            )
            assigned_params.add(name)

        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
        assigned_buffers: set[str] = set()  # tracking unused buffers
        id_to_buffer: dict[int | _TensorID, tuple[torch.nn.Parameter, bool]] = {}
        for name in self.graph_signature.buffers:  # this loop adds used buffers
            if name in non_persistent_buffers:
                persistent = False
                buffer = export_module.constants[name]
            else:
                persistent = True
                buffer = state_dict[name]

            if _id(buffer) not in id_to_buffer:
                id_to_buffer[_id(buffer)] = (buffer.clone(), persistent)

            _assign_attr(
                id_to_buffer[_id(buffer)][0],
                self,
                name,
                attr_kind=_AttrKind.BUFFER,
                persistent=persistent,
            )
            assigned_buffers.add(name)

        # restore aliased/unused params and buffers
        # these appear in state dict but not graph signature
        for name, tensor in state_dict.items():
            if name in assigned_params or name in assigned_buffers:  # already assigned
                continue

            is_buffer = False
            if _id(tensor) in id_to_buffer or not isinstance(
                tensor, torch.nn.Parameter
            ):  # aliased buffer
                is_buffer = True

            if is_buffer:
                if (
                    _id(tensor) not in id_to_buffer
                ):  # this is completely unused (not weight-sharing)
                    id_to_buffer[_id(tensor)] = (
                        tensor,
                        True,
                    )  # assign to respect original model
                _assign_attr(
                    id_to_buffer[_id(tensor)][0],
                    self,
                    name,
                    attr_kind=_AttrKind.BUFFER,
                    persistent=True,
                )
            else:
                if _id(tensor) not in id_to_param:  # this is unused
                    id_to_param[_id(tensor)] = tensor
                _assign_attr(
                    id_to_param[_id(tensor)],
                    self,
                    name,
                    attr_kind=_AttrKind.PARAMETER,
                )

        # use id map so we don't double-clone aliased constants
        id_to_const: dict[int | _TensorID, torch.Tensor | torch._C.ScriptObject] = {}
        for fqn, constant in export_module.constants.items():
            # Compute the key from the ORIGINAL constant, before any cloning.
            # A clone has its own storage, so keying on the clone would never
            # match a later alias of the same original tensor.
            const_key = _id(constant)
            if const_key not in id_to_const:
                if isinstance(constant, torch.Tensor):
                    constant = constant.clone()
                id_to_const[const_key] = constant
            _constant = id_to_const[const_key]
            _assign_attr(
                _constant,
                self,
                fqn,
                attr_kind=_AttrKind.CONSTANT,
            )

        # This is to handle parameters/buffers that point to the same tensor
        # object id -> list of (node_name, target_name)
        consts_map: dict[int | _TensorID, list[tuple[str, str]]] = defaultdict(list)
        consts_targets: set[str] = set()

        def add_to_consts_map(obj_id, node_name, target_name):
            name_list = consts_map[obj_id]
            name_list.append((node_name, target_name))

        # track aliased/unused params, buffers
        # prefer using untyped_storage() over id() when it's available
        added_params_buffers: set[str] = set()
        for s in self.graph_signature.input_specs:
            if s.kind == InputKind.PARAMETER or (
                s.kind == InputKind.BUFFER and s.persistent
            ):
                if not hasattr(s.arg, "name"):
                    raise AssertionError(
                        f"expected s.arg to have 'name' attribute, got {type(s.arg)}"
                    )
                if not isinstance(s.target, str):
                    raise AssertionError(
                        f"expected s.target to be str, got {type(s.target)}"
                    )
                add_to_consts_map(
                    _id(export_module.state_dict[s.target]),
                    s.arg.name,
                    s.target,
                )
                consts_targets.add(s.target)
                added_params_buffers.add(s.target)
            elif (
                s.kind == InputKind.BUFFER
                and not s.persistent
                or s.kind == InputKind.CONSTANT_TENSOR
                or s.kind == InputKind.CUSTOM_OBJ
            ):
                if not hasattr(s.arg, "name"):
                    raise AssertionError(
                        f"expected s.arg to have 'name' attribute for kind {s.kind}, got {type(s.arg)}"
                    )
                if not isinstance(s.target, str):
                    raise AssertionError(
                        f"expected s.target to be str for kind {s.kind}, got {type(s.target)}"
                    )
                add_to_consts_map(
                    _id(export_module.constants[s.target]),
                    s.arg.name,
                    s.target,
                )
                consts_targets.add(s.target)

        # add constants that are aliased and don't appear in graph signature
        for const_name, const in export_module.constants.items():
            if const_name not in consts_targets:
                const_id = _id(const)
                if const_id not in consts_map:
                    raise AssertionError(
                        f"constant {const_name!r} id not found in consts_map"
                    )
                ph_name, _ = consts_map[const_id][0]
                add_to_consts_map(const_id, ph_name, const_name)
                # Record the constant handled in THIS iteration; the loop
                # variable `s` of the input-spec loop above is stale here.
                added_params_buffers.add(const_name)

        # add aliased/unused params and buffers that don't appear in graph signature
        for fqn, tensor in export_module.state_dict.items():
            if fqn not in added_params_buffers:
                tensor_id = _id(tensor)
                if tensor_id not in consts_map:
                    # completely unused (no weight-sharing), ignore.
                    # this weight doesn't appear in graph module,
                    # so won't cause FQN assignment issues
                    continue
                ph_name, _ = consts_map[tensor_id][0]
                add_to_consts_map(tensor_id, ph_name, fqn)

        # node name -> list of possible targets
        inputs_to_state: dict[str, list[str]] = {}
        for node_target in consts_map.values():
            targets = [t[1] for t in node_target]
            for n, _ in node_target:
                inputs_to_state[n] = targets

        _sink_params(self, inputs_to_state, [])

        redirected_call_indices = _deduplicate_modules(seen_modules.values())
        fqn_list = [fqn for fqn in fqn_list if fqn not in redirected_call_indices]

        self._dispatch_modules(redirected_call_indices, consts_targets)
        fqn_list = [fqn for fqn in fqn_list if "@" not in fqn]

        # Cache so we don't have to compute this every time.
        # NOTE: this needs to be kept in sync with the placeholders in
        # self.graph, but currently we have no way to guarantee that.
        self.input_placeholders = [
            node for node in self.graph.nodes if node.op == "placeholder"
        ]
        self.check_input_constraints = True
        # TODO(zhxchen17) We can register modules ahead of time instead of reorder later.
        fqn_order = {fqn: i for i, fqn in enumerate(fqn_list)}
        # In the case of legacy IR, we might be missing some modules from metadata.
        for name, _ in self.named_modules(remove_duplicate=False):
            if name not in fqn_order:
                fqn_order[name] = len(fqn_order)
        _reorder_submodules(self, fqn_order)
        self.graph.lint()
        self.finalize()

    def _print_graph(self):
        """Print the FX graph of every submodule that has one (debug helper)."""
        for fqn, mod in self.named_modules():
            print(fqn + ":")
            if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph):
                print(mod.graph)

    def _adapt_flat_args(self, flat_args, in_spec, input):
        """Return ``flat_args`` adapted to the exported root signature.

        Args:
            flat_args: Flattened inputs.
            in_spec: Tree spec that ``flat_args`` was flattened with.
            input: Passed to the adapter as ``obj``.

        Raises:
            TypeError: If the specs differ and no adapter is configured, or if
                the adapter returns the wrong number of leaves.
        """
        signature = self.module_call_graph[0].signature
        if in_spec == signature.in_spec:
            return flat_args

        if self.flat_args_adapter is None:
            raise TypeError(
                "There is no flat args adapter specified. "
                "Are you sure you are calling this with the right arguments? "
            )
        else:
            flat_args = self.flat_args_adapter.adapt(
                target_spec=signature.in_spec,
                input_spec=in_spec,
                input_args=flat_args,
                metadata=self.meta,
                obj=input,
            )

            if len(flat_args) != signature.in_spec.num_leaves:
                raise TypeError(
                    f"Flat args adaption failed, number of args mismatch "
                    f"Adapted: {len(flat_args)} \n"
                    f"Exported module: {signature.in_spec.num_leaves}"
                )
            return flat_args

    def process_forward_inputs(self, *args, **kwargs):
        """Flatten (and, if needed, adapt) the call inputs; return the flat list.

        When ``check_input_constraints`` is set, the flat inputs are also
        checked via ``_check_input_constraints_for_graph``. During FX symbolic
        tracing the flat inputs are returned right after flattening.
        """
        signature = self.module_call_graph[0].signature

        reordered_kwargs = kwargs
        if kwargs:
            reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec)

        flat_args_with_path, in_spec = pytree.tree_flatten_with_path(
            (args, reordered_kwargs)
        )
        flat_args = [x[1] for x in flat_args_with_path]

        if is_fx_symbolic_tracing():
            return flat_args

        if in_spec != signature.in_spec:
            # Only announce the mismatch the first time an adaptation happens.
            if not self.adapted:
                print(
                    "Input treespec does not match with exported module's: \n"
                    f"Input treespec: {in_spec}. ",
                    f"Exported module treespec: {signature.in_spec}",
                )
                print("Adapting flat arg to match exported module's treespec")
            flat_args = self._adapt_flat_args(flat_args, in_spec, args)
            self.adapted = True

        if self.check_input_constraints:
            # Import here to avoid an unfortunate circular dependency.
            # TODO(suo): untangle this.
            from torch._export.utils import _check_input_constraints_for_graph

            if self.adapted is True:
                flat_arg_paths = (
                    self.flat_args_adapter.get_flat_arg_paths()
                    if self.flat_args_adapter
                    else []
                )
                if flat_arg_paths and len(flat_arg_paths) != len(flat_args):
                    raise AssertionError(
                        f"flat_arg_paths length {len(flat_arg_paths)} does not match flat_args length {len(flat_args)}"
                    )
                # Rebuild (path, value) pairs for the adapted args, using the
                # adapter-provided paths when available.
                new_flat_args_with_path = [  # type: ignore[var-annotated]
                    (
                        (
                            SequenceKey(idx=idx),
                            GetAttrKey(
                                name=flat_arg_paths[idx]
                                if flat_arg_paths
                                else "<unknown location>"
                            ),
                        ),
                        arg,
                    )
                    for idx, arg in enumerate(flat_args)
                ]
            else:
                new_flat_args_with_path = flat_args_with_path  # type: ignore[assignment]

            _check_input_constraints_for_graph(
                self.input_placeholders, new_flat_args_with_path, self.range_constraints
            )

        return flat_args

    def forward(self, *args, **kwargs):
        """Run the root graph on the processed inputs.

        Outside of FX symbolic tracing, the result is unflattened with the
        root signature's ``out_spec``.
        """
        flat_args = self.process_forward_inputs(*args, **kwargs)
        signature = self.module_call_graph[0].signature

        if is_fx_symbolic_tracing():
            return_val = torch.fx.Interpreter(self, graph=self.graph).run(
                *flat_args, enable_io_processing=False
            )
            # For scalar return value, fx.Graph wraps in a tuple
            if isinstance(return_val, tuple) and len(return_val) == 1:
                return return_val[0]
            return return_val

        if torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter:
            tree_out = type(self.graph_module).forward(self, *flat_args)  # type: ignore[union-attr]
        else:
            tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
                *flat_args, enable_io_processing=False
            )
        return pytree.tree_unflatten(tree_out, signature.out_spec)

    def finalize(self):
        """Build ``graph_module`` from the current root graph and lint it."""
        # Stored directly in __dict__, bypassing nn.Module attribute handling.
        self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph)
        self.graph.lint()

    def _dispatch_modules(self, redirected_call_indices, consts_targets):
        """For a module whose call signatures are preserved, replace
        multiple modules corresponding to multiple calls to that module
        with a single dispatcher module that tracks which module to call.
        """

        # for each fqn whose module call signature is preserved,
        # map that fqn to a list of called modules
        called_modules = defaultdict(list)
        for entry in self.module_call_graph:
            if entry.fqn and entry.signature:
                # some modules were removed and their fqns redirected to other
                # fqns during deduplication
                fqn = entry.fqn
                mod = _get_attr(self, redirected_call_indices.get(fqn, fqn))
                base, idx = fqn.split("@") if "@" in fqn else [fqn, "0"]
                called_modules[base].append((int(idx), mod))

        # parent fqn -> attribute names taken from consts_targets
        attrs_map = defaultdict(set)
        for target in consts_targets:
            if "." in target:
                orig_fqn, name = target.rsplit(".", 1)
                attrs_map[orig_fqn].add(name)
            else:
                attrs_map[""].add(target)

        # replace multiple call modules with a single dispatcher module
        for orig_fqn, indexed_call_modules in called_modules.items():
            call_modules = [mod for _, mod in sorted(indexed_call_modules)]
            if len(call_modules) > 1:
                for i in range(len(call_modules)):
                    fqn = _call_name(orig_fqn, i + 1)
                    if fqn not in redirected_call_indices:
                        *prefix, name = fqn.split(".")
                        _get_attr_via_attr_list(self, prefix)._modules.pop(name)
                self.set_submodule(
                    orig_fqn,
                    InterpreterModuleDispatcher(attrs_map[orig_fqn], call_modules),
                )

        # elide call indices in call modules because they are
        # tracked automatically inside the dispatcher module
        def elide_call_indices(prefix, graph):
            for node in graph.nodes:
                if node.op == "call_module":
                    fqn = node.target.split("@")[0]
                    path = f"{prefix}.{fqn}" if prefix else fqn
                    if path in called_modules:
                        node.target = fqn

        for fqn, mod in self.named_modules(remove_duplicate=False):
            if hasattr(mod, "graph"):
                elide_call_indices(fqn, mod.graph)
            elif hasattr(mod, "_call_modules"):
                for mod_ in mod._call_modules:
                    if not hasattr(mod_, "graph"):
                        raise AssertionError(
                            f"expected mod_ to have 'graph' attribute, got {type(mod_)}"
                        )
                    elide_call_indices(fqn, mod_.graph)

    def print_readable(
        self,
        print_output=True,
        include_stride=False,
        include_device=False,
        colored=False,
    ):
        """Delegate to ``_print_readable`` with the label "UnflattenedModule"."""
        return _print_readable(
            self,
            "UnflattenedModule",
            print_output,
            include_stride,
            include_device,
            colored,
        )
def unflatten(
    module: ExportedProgram, flat_args_adapter: FlatArgsAdapter | None = None
) -> UnflattenedModule:
    """Turn an ExportedProgram back into a module with the module hierarchy
    of the original eager module. This helps when :mod:`torch.export` is used
    together with another system that expects a module hierarchy rather than
    the flat graph :mod:`torch.export` usually produces.

    .. note:: The args/kwargs of unflattened modules will not necessarily match
        the eager module, so doing a module swap (e.g. :code:`self.submod =
        new_mod`) will not necessarily work. If you need to swap a module out, you
        need to set the :code:`preserve_module_call_signature` parameter of
        :func:`torch.export.export`.

    Args:
        module (ExportedProgram): The ExportedProgram to unflatten.
        flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input
            TreeSpec does not match with exported module's.

    Returns:
        An instance of :class:`UnflattenedModule`, which has the same module
        hierarchy as the original eager module pre-export.
    """
    # Effect tokens are stripped from the program before it is unflattened.
    unflattened = UnflattenedModule(_remove_effect_tokens(module), flat_args_adapter)

    # Keep dynamo out of process_forward_inputs: the adapter has many
    # non-dynamo-traceable behaviors.
    untraced_preprocess = torch._dynamo.disable(
        unflattened.process_forward_inputs,
        reason="do not trace into preprocessing the inputs",
        recursive=True,
    )
    unflattened.process_forward_inputs = untraced_preprocess  # type: ignore[method-assign]
    return unflattened
- def _inplace_buffer_and_input_mutations(
- graph: torch.fx.Graph,
- graph_signature: ExportGraphSignature,
- ) -> None:
- """Transform buffer and input mutations from their functionalized form
- into copy_ nodes in the graph.
- Functionalization represents a buffer mutation by passing the buffer as
- an input and output. For example, consider the eager code:
- def forward(self, x):
- self.buffer += x
- return x * x
- This corresponds to a graph that looks like:
- def forward(self, buffer, x):
- mutated_buffer = aten.add(buffer, x)
- mul = aten.mul(x, x)
- return (mutated_buffer, mul)
- We want to inplace this into something that looks like the original
- eager code:
- def forward(self, buffer, x):
- mutated_buffer = aten.add(buffer, x)
- buffer.copy_(mutated_buffer)
- mul = aten.mul(x, x)
- return (mul,)
- Input mutations are handled similarly.
- """
- output_node = next(iter(reversed(graph.nodes)))
- if output_node.op != "output" or len(output_node.args) != 1:
- raise AssertionError(
- f"expected output node with op='output' and 1 arg, got op={output_node.op!r} with {len(output_node.args)} args"
- )
- return_args = output_node.args[0]
- input_name_to_node = {
- node.name: node for node in graph.nodes if node.op == "placeholder"
- }
- mutation_name_to_input_name = {}
- # Collect mutated buffers.
- buffer_fqn_to_input_name = {
- buffer_fqn: k for k, buffer_fqn in graph_signature.inputs_to_buffers.items()
- }
- mutation_name_to_input_name = {
- k: buffer_fqn_to_input_name[buffer_fqn]
- for k, buffer_fqn in graph_signature.buffers_to_mutate.items()
- }
- # Collect mutated user inputs.
- mutation_name_to_input_name.update(graph_signature.user_inputs_to_mutate)
- num_mutations = len(mutation_name_to_input_name)
- for mutation in return_args[:num_mutations]:
- input_name = mutation_name_to_input_name[mutation.name]
- input_node = input_name_to_node[input_name]
- with graph.inserting_after(mutation):
- # Create a copy_ node that inplaces the mutation.
- new_node = graph.create_node(
- "call_function", torch.ops.aten.copy_.default, (input_node, mutation)
- )
- for k, v in mutation.meta.items():
- new_node.meta[k] = v
- # Replace all uses of the previously functional mutation with
- # our copy_ node.
- mutation.replace_all_uses_with(new_node, lambda x: x is not new_node)
- # Remove the mutated buffer / input from the graph outputs, since we don't
- # need to thread it through anymore.
- user_outputs = tuple(return_args[num_mutations:])
- output_node.args = ((user_outputs),)
- def _root_module_type(graph: torch.fx.Graph) -> str | None:
- for node in graph.nodes:
- if "nn_module_stack" not in node.meta:
- continue
- for path, ty in node.meta["nn_module_stack"].values():
- if not path:
- return ty
- return None
- def _fix_nn_module_stacks(graph):
- # For each nn module stack in the graph, check if the fqns in it represent a stack:
- # 1. Each fqn must be a prefix of the next fqn.
- # 2. If not, remove the entries starting from the next fqn, emitting a warning.
- for node in graph.nodes:
- if "nn_module_stack" not in node.meta:
- continue
- nn_module_stack = node.meta["nn_module_stack"]
- fqns = [
- fqn.split("@")[0] if "@" in fqn else fqn
- for fqn, _t in nn_module_stack.values()
- ]
- # Check if each FQN is a prefix of the next one
- prev_fqn, *next_fqns = fqns
- num_valid_indices = 1 # root FQN
- for curr_fqn in next_fqns:
- # Check if the previous FQN is a prefix of the current one
- if _is_prefix(prev_fqn, curr_fqn):
- num_valid_indices += 1
- prev_fqn = curr_fqn
- else:
- # Found a non-prefix FQN, stop here
- break
- # If we need to remove entries, create a new stack with only valid entries
- if num_valid_indices < len(nn_module_stack):
- log.warning(
- "nn_module_stack fqns %s at node %s do not form a stack! dropping last %d entries",
- fqns,
- node,
- len(nn_module_stack) - num_valid_indices,
- )
- node.meta["nn_module_stack"] = dict(
- list(nn_module_stack.items())[:num_valid_indices]
- )
- def _is_prefix(candidate, target):
- """Check whether `candidate` is a prefix of `target`."""
- return len(candidate) < len(target) and target[: len(candidate)] == candidate
- def _compute_accessor(parent_fqn: str, child_fqn: str) -> str:
- if parent_fqn == "":
- # Handle the root module correctly.
- return child_fqn
- parent_split = parent_fqn.split(".")
- child_split = child_fqn.split(".")
- # TODO: support skip connection by inlining the child module.
- if child_split[: len(parent_split)] != parent_split:
- raise RuntimeError(
- f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'."
- "This is currently unsupported."
- "Please try to make child module attach to parent module directly."
- )
- return ".".join(child_split[len(parent_split) :])
- def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
- def graph_dump(graph: torch.fx.Graph) -> str:
- ret = []
- nodes_idx: dict[int, int] = {}
- def arg_dump(arg) -> str:
- if isinstance(arg, torch.fx.Node):
- return "%" + str(nodes_idx[id(arg)])
- return str(arg)
- for i, node in enumerate(graph.nodes):
- args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)]
- args_dump += [
- f"{key}={value}"
- for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
- ]
- target = node.target if node.op in ("call_function", "get_attr") else ""
- # pyrefly: ignore [bad-argument-type]
- ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
- nodes_idx[id(node)] = i
- return "\n".join(ret)
- if not isinstance(x.graph, torch.fx.Graph):
- raise AssertionError(
- f"expected x.graph to be torch.fx.Graph, got {type(x.graph)}"
- )
- if not isinstance(y.graph, torch.fx.Graph):
- raise AssertionError(
- f"expected y.graph to be torch.fx.Graph, got {type(y.graph)}"
- )
- return graph_dump(x.graph) == graph_dump(y.graph)
- def _add_spec(gm: torch.nn.Module, spec) -> str:
- i = 0
- while hasattr(gm, f"_spec_{i}"):
- i += 1
- name = f"_spec_{i}"
- setattr(gm, name, spec)
- return name
- def _generate_flatten(gm: torch.fx.GraphModule, node) -> torch.fx.Node:
- flatten = gm.graph.call_function(pytree.tree_flatten, (node,))
- getitem_0 = gm.graph.call_function(operator.getitem, (flatten, 0))
- return getitem_0
def _generate_flatten_spec(
    gm: torch.fx.GraphModule | InterpreterModule | UnflattenedModule, node, spec
) -> torch.fx.Node:
    """Append a ``tree_flatten_spec(node, spec)`` call to ``gm.graph``.

    ``spec`` is first stored on ``gm`` via ``_add_spec`` and then referenced
    from the graph through a ``get_attr`` node.

    Returns:
        The ``call_function`` node for ``fx_pytree.tree_flatten_spec``.
    """
    spec_attr = _add_spec(gm, spec)
    graph = gm.graph
    spec_node = graph.get_attr(spec_attr)
    return graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node))
def _generate_unflatten(
    gm: torch.fx.GraphModule | InterpreterModule | UnflattenedModule, nodes, spec
) -> torch.fx.Node:
    """Append a ``tree_unflatten(nodes, spec)`` call to ``gm.graph``.

    ``spec`` is first stored on ``gm`` via ``_add_spec`` and then referenced
    from the graph through a ``get_attr`` node.

    Returns:
        The ``call_function`` node for ``pytree.tree_unflatten``.
    """
    spec_attr = _add_spec(gm, spec)
    graph = gm.graph
    spec_node = graph.get_attr(spec_attr)
    return graph.call_function(pytree.tree_unflatten, (nodes, spec_node))
- def _get_submodule(mod: torch.nn.Module, target: str):
- *prefix, field = target.split(".")
- for item in prefix:
- submod = getattr(mod, item, None)
- if submod is None:
- return None
- if not isinstance(submod, torch.nn.Module):
- return None
- mod = submod
- return getattr(mod, field, None)
- def _add_submodule(
- mod: torch.nn.Module,
- target: str,
- module_to_add: torch.nn.Module,
- create_module: Callable[[str], torch.nn.Module] | None = None,
- ):
- *prefix, field = target.split(".")
- for i, item in enumerate(prefix):
- submod = getattr(mod, item, None)
- if submod is None:
- if create_module is not None:
- submod = create_module(".".join(prefix[: i + 1]))
- else:
- submod = torch.nn.Module()
- setattr(mod, item, submod)
- if not isinstance(submod, torch.nn.Module):
- return False
- mod = submod
- mod.add_module(field, module_to_add)
- def _call_name(base: str, n: int) -> str:
- # Given n >= 0, generate call names to a submodule `base` of the form
- # `base`, `base@1`, `base@2`, etc.
- return base if n == 1 else f"{base}@{n - 1}"
- def _is_call_name(call_name: str, base: str) -> bool:
- # Recognize when call_name = _call_name(base, n) for some n >= 0.
- return re.match(re.escape(base) + r"(@\d+)?$", call_name) is not None
class _ModuleFrame:
    """One frame of the recursive walk that outlines submodules from the flat graph.

    A frame owns the graph of one module call (``self.graph``). ``run_from``
    copies into it the flat-graph nodes whose ``nn_module_stack`` equals
    ``self.module_stack``; when a node belongs to a deeper stack, a child frame
    is created and run from the same node index. When a node no longer belongs
    to this frame, ``finalize_outputs`` emits this graph's output node and
    wires the results into the parent frame.
    """

    def __init__(
        self,
        flat_graph: torch.fx.Graph,
        nodes: tuple[torch.fx.Node, ...],
        seen_nodes,
        seen_modules,
        seen_attrs,
        created_modules,
        parent,
        module_stack: list[tuple[str, str | None, int]],
        module_id,
        module_call_graph: dict[str, ModuleCallSignature],
        module: torch.fx.GraphModule | UnflattenedModule | None = None,
    ):
        """Create the frame and, for non-root frames, hook it into the parent.

        Args:
            flat_graph: The flat graph whose nodes are being distributed.
            nodes: The flat graph's nodes, indexed by ``run_from``.
            seen_nodes: Shared mapping of node name to flat-graph node, filled
                by ``copy_node``.
            seen_modules: Shared mapping of ``module_id`` to the list of
                ``_SubmoduleEntry`` records created for it.
            seen_attrs: Shared mapping of call name to the ``get_attr`` targets
                copied into that call's graph.
            created_modules: Shared mapping of fqn to modules created for
                intermediate paths; this frame's entry is removed in
                ``finalize_outputs``.
            parent: The parent frame, or ``None`` for the root frame.
            module_stack: ``(fqn, type, call index)`` tuples; the last entry
                describes this frame.
            module_id: Key into ``seen_modules`` for this frame's module.
            module_call_graph: Mapping of fqn to preserved call signature.
            module: Module to populate. When omitted, the module is taken from
                ``created_modules`` or a new ``InterpreterModule`` is used.
        """
        self.flat_graph = flat_graph
        self.nodes = nodes
        self.seen_nodes = seen_nodes
        self.seen_modules = seen_modules
        self.seen_attrs = seen_attrs
        self.created_modules = created_modules
        self.parent = parent
        self.module_stack = module_stack
        self.module_id = module_id

        self.module_call_graph = module_call_graph
        # Flip to True to trace the walk on stdout (see ``print`` below).
        self.verbose = False

        self.fqn, ty, num_calls = self.module_stack[-1]
        # generate call name for self.fqn
        self.child_fqn = _call_name(self.fqn, num_calls + 1)

        self.module: torch.fx.GraphModule | UnflattenedModule | InterpreterModule
        if module is not None:
            self.module = module
            self.ivals = module.ivals if hasattr(module, "ivals") else {}  # type: ignore[var-annotated]
        else:
            self.module = self.created_modules.get(
                self.fqn,
                InterpreterModule(torch.fx.Graph(), ty=ty),
            )
            self.ivals = parent.ivals

        self.graph = self.module.graph

        # Mapping of nodes in the flat graph to nodes in this graph.
        self.node_map: dict[torch.fx.Node, torch.fx.Node] = {}
        # Mapping of flat-graph nodes to the placeholder (or flattened-input
        # getitem) node that stands for them inside this graph.
        self.node_to_placeholder = {}

        self.parent_call_module: torch.fx.Node | None = None
        if parent is not None:
            accessor = _compute_accessor(parent.fqn, self.child_fqn)

            def create_module(fqn):
                # `fqn` is relative to the parent; cache by absolute path so
                # the same intermediate module is reused.
                path = f"{parent.fqn}.{fqn}" if parent.fqn else fqn
                if path in self.created_modules:
                    return self.created_modules[path]
                submod = InterpreterModule(torch.fx.Graph(), ty=ty)
                self.created_modules[path] = submod
                return submod

            _add_submodule(parent.module, accessor, self.module, create_module)
            self.parent_call_module = parent.graph.call_module(accessor)
            if self.seen_modules[self.module_id]:
                # Later calls to an already-seen module share the submodule
                # dict of the first recorded entry.
                base_module_frame = self.seen_modules[self.module_id][0]
                self.module._modules = base_module_frame.module._modules
            self.seen_modules[self.module_id].append(
                _SubmoduleEntry(
                    parent_fqn=self.parent.fqn,
                    parent_module=self.parent.module,
                    parent_call_module=self.parent_call_module,
                    fqn=self.fqn,
                    call_idx=num_calls + 1,
                    module=self.module,
                )
            )

        signature = module_call_graph.get(self.child_fqn)
        if signature is not None and self.parent is not None:
            # A preserved signature fixes this module's calling convention:
            # inputs arrive as (args, kwargs) and are flattened inside.
            if signature.in_spec.num_children != 2:
                raise AssertionError(
                    f"expected in_spec to have 2 children, got {signature.in_spec.num_children}"
                )
            if signature.in_spec.type is not tuple:
                raise AssertionError(
                    f"expected in_spec.type to be tuple, got {signature.in_spec.type}"
                )
            args_spec, kwargs_spec = signature.in_spec.children()
            if args_spec.type is not tuple:
                raise AssertionError(
                    f"expected args_spec.type to be tuple, got {args_spec.type}"
                )
            if kwargs_spec.type is not dict:
                raise AssertionError(
                    f"expected kwargs_spec.type to be dict, got {kwargs_spec.type}"
                )

            with self.graph.inserting_after(None):
                arg_nodes = [
                    self.graph.placeholder(f"_positional_arg_{idx}")
                    for idx in range(args_spec.num_children)
                ]
                kwarg_nodes = {}
                for name in kwargs_spec.context:
                    kwarg_nodes[name] = self.graph.placeholder(name)
                flat_args = _generate_flatten_spec(
                    self.module,
                    (tuple(arg_nodes), kwarg_nodes),
                    signature.in_spec,
                )
                for idx, arg in enumerate(signature.inputs):
                    flat_arg_node = self.graph.create_node(
                        op="call_function",
                        target=operator.getitem,
                        args=(flat_args, idx),
                        name=(
                            arg.name
                            if not isinstance(arg, ConstantArgument)
                            else f"_constant_{idx}"
                        ),
                    )
                    if isinstance(arg, ConstantArgument):
                        continue

                    if arg.name in self.seen_nodes:
                        flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta)
                        self.node_to_placeholder[self.seen_nodes[arg.name]] = (
                            flat_arg_node
                        )

            with self.parent.graph.inserting_before(self.parent_call_module):
                input_nodes: list[torch.fx.Node | None] = []
                for input in signature.inputs:
                    if isinstance(input, ConstantArgument):
                        input_nodes.append(input.value)  # type: ignore[arg-type]
                    elif input.name not in self.seen_nodes:
                        input_nodes.append(None)
                    else:
                        if not isinstance(
                            input,
                            (
                                TensorArgument,
                                SymIntArgument,
                                SymBoolArgument,
                                SymFloatArgument,
                            ),
                        ):
                            raise AssertionError(
                                f"expected input to be TensorArgument, SymIntArgument, "
                                f"SymBoolArgument, or SymFloatArgument, got {type(input)}"
                            )
                        input_nodes.append(
                            self.parent.remap_input(self.seen_nodes[input.name])
                        )

                inputs_node = _generate_unflatten(
                    self.parent.module,
                    input_nodes,
                    signature.in_spec,
                )

                args_node = self.parent.graph.call_function(
                    operator.getitem, (inputs_node, 0)
                )
                kwargs_node = self.parent.graph.call_function(
                    operator.getitem, (inputs_node, 1)
                )
                arg_nodes = [
                    self.parent.graph.call_function(operator.getitem, (args_node, i))
                    for i in range(args_spec.num_children)
                ]
                kwarg_nodes = {
                    k: self.parent.graph.call_function(
                        operator.getitem, (kwargs_node, k)
                    )
                    for k in kwargs_spec.context
                }
            if self.parent_call_module is None:
                raise AssertionError("parent_call_module must not be None")
            # pyrefly: ignore [bad-assignment]
            self.parent_call_module.args = tuple(arg_nodes)
            self.parent_call_module.kwargs = kwarg_nodes  # type: ignore[assignment]

    def add_placeholder(self, x):
        """Create a placeholder in this graph standing for flat-graph node ``x``.

        Raises:
            AssertionError: If this is the root frame or ``x`` does not belong
                to the flat graph.
        """
        if self.fqn == "":
            raise AssertionError(f"Cannot add placeholder {x} to root module")
        if x.graph is not self.flat_graph:
            raise AssertionError(
                "expected x.graph to be flat_graph, got different graph"
            )  # noqa: F541
        # x is not in subgraph, create a new placeholder for subgraph
        with self.graph.inserting_before(None):
            placeholder_node = self.graph.placeholder(x.name, type_expr=x.type)
        # copy all meta fields, even if some fields might be irrelevant for
        # the placeholder node
        placeholder_node.meta = copy.copy(x.meta)
        self.node_to_placeholder[x] = placeholder_node

    def copy_sym_call_function(self, x):
        """Re-create the ``call_function`` node ``x`` inside this graph.

        Returns:
            The new node, which is also recorded in ``self.node_map``.
        """
        # This only exists because we deduplicate sym_size nodes in the flat export graph,
        # and if preserve_module_call_signature is set, we may not be able to pass sym_size
        # nodes, or their downstream users, as inputs to submodule calls.
        # To avoid this we copy these call_function nodes with sym_type results.
        # This should however only be done for sym_type nodes - call_function nodes on tensors
        # should not be deduplicated in the first place.
        args = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.args)
        kwargs = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.kwargs)
        node = self.graph.call_function(x.target, args, kwargs)
        node.meta = copy.copy(x.meta)
        self.node_map[x] = node
        return node

    def remap_input(self, x):
        """Return the node in this graph that corresponds to flat-graph node ``x``.

        Resolution order: an already-copied node, an existing placeholder, a
        newly added placeholder (also prepended to the parent's call args), a
        re-copied ``call_function`` node, and finally an intermediate-value
        read through ``self.ivals``.

        Raises:
            AssertionError: If ``x`` does not belong to the flat graph.
            RuntimeError: If none of the cases above applies.
        """
        if x.graph is not self.flat_graph:
            raise AssertionError(
                "expected x.graph to be flat_graph, got different graph"
            )  # noqa: F541
        if x in self.node_map:
            return self.node_map[x]
        if self.verbose:
            # Only format the node when tracing is actually enabled.
            self.print(f"remap_input({x})")
        if x in self.node_to_placeholder:
            return self.node_to_placeholder[x]
        elif (
            x.op == "placeholder" or self.module_call_graph.get(self.fqn) is None
            # allow placeholder creation if we are not preserving module call signature
        ):
            self.add_placeholder(x)
            if self.parent_call_module is not None:
                # Important to *prepend* the output to match how we are
                # inserting placeholder nodes.
                with self.parent.graph.inserting_before(self.parent_call_module):
                    self.parent_call_module.insert_arg(0, self.parent.remap_input(x))
            return self.node_to_placeholder[x]
        elif x.op == "call_function" and (
            x.target
            in (
                torch.ops.aten.sym_size.int,
                torch.ops.aten.item.default,
                torch.ops.aten.unbind.int,
                torch.ops.aten.sum.dim_IntList,
                torch.ops.aten.view.default,
                torch.ops.aten.diff.default,
            )
            or (hasattr(x.target, "__module__") and x.target.__module__ == "_operator")
        ):
            # export deduplicates sym_size nodes, and may need to re-copy them
            # if module call signature needs to be preserved
            self.copy_sym_call_function(x)
            return self.node_map[x]
        elif self.module_call_graph.get(self.fqn) is not None:
            # x is reading the intermediate value of a mutation, so record it;
            # later we will find where it was created and perform the update
            return self.ivals.read(self, x)  # type: ignore[operator, union-attr]
        else:
            raise RuntimeError(
                f"Could not run remap_input() on op type: {x.op} for node {x}"
            )

    def uplift_common_custom_metadata(self) -> None:
        """Move ``meta["custom"]`` to the parent call node when all nodes agree.

        If every node in ``self.node_map`` carries the same non-empty custom
        metadata, it is stored on ``self.parent_call_module`` and deleted from
        the individual nodes. Otherwise nothing changes.
        """
        # Copy custom metadata if all nodes have same custom metadata
        custom_meta = None
        for node in self.node_map.values():
            curr_meta = node.meta.get("custom", {})
            if custom_meta is None:
                # first node
                custom_meta = curr_meta
                continue
            if curr_meta != custom_meta:
                custom_meta = {}
                break
        if custom_meta:
            # Lift common custom metadata to parent node and clear children node's custom metadata
            if self.parent_call_module is None:
                raise AssertionError(
                    "parent_call_module must not be None when uplifting custom metadata"
                )
            self.parent_call_module.meta["custom"] = custom_meta
            for node in self.node_map.values():
                del node.meta["custom"]

    def finalize_outputs(self):
        """Emit this graph's output node and expose the results to the parent.

        With a preserved signature the outputs follow ``signature.outputs`` and
        are unflattened/flattened through ``signature.out_spec``. Otherwise
        every copied node that has a user not yet in ``seen_nodes`` becomes an
        output. The parent's ``node_map`` is updated so later nodes can refer
        to these results.
        """
        self.created_modules.pop(self.fqn, None)

        orig_outputs = []

        signature = self.module_call_graph.get(self.child_fqn)
        if signature is not None and self.parent is not None:
            for output in signature.outputs:
                if isinstance(
                    output,
                    (
                        TensorArgument,
                        SymIntArgument,
                        SymBoolArgument,
                        SymFloatArgument,
                        ConstantArgument,
                    ),
                ):
                    if output.name in self.seen_nodes:
                        orig_outputs.append(self.seen_nodes[output.name])
                    else:
                        orig_outputs.append(None)
                else:
                    raise RuntimeError(
                        f"Unsupported data type for output node: {output}"
                    )

            def get_actual_output_node(output):
                if output is None:
                    return None

                seen_node = self.seen_nodes[output.name]
                if seen_node in self.node_map:
                    return self.node_map[seen_node]
                elif seen_node in self.node_to_placeholder:
                    return self.node_to_placeholder[seen_node]
                else:
                    raise RuntimeError(
                        f"Could not find output node {output}. Graph: {self.graph}"
                    )

            tree_out_node = _generate_unflatten(
                self.module,
                tuple(get_actual_output_node(output) for output in orig_outputs),
                signature.out_spec,
            )
            parent_out: torch.fx.Node | None = _generate_flatten_spec(
                self.parent.module, self.parent_call_module, signature.out_spec
            )
            graph_outputs: torch.fx.Node | list[torch.fx.Node] = tree_out_node
        else:
            graph_outputs = []
            # Iterate through nodes we have copied into self.graph.
            for orig_node in self.node_map:
                for user_node in orig_node.users:
                    if user_node.name not in self.seen_nodes:
                        # external user node, need to expose as an output
                        orig_outputs.append(orig_node)
                        graph_outputs.append(self.node_map[orig_node])
                        break

            parent_out = self.parent_call_module
            if len(graph_outputs) == 1:
                graph_outputs = graph_outputs[0]

        if not isinstance(graph_outputs, (list, torch.fx.Node)):
            raise AssertionError(
                f"expected graph_outputs to be list or torch.fx.Node, got {type(graph_outputs)}"
            )

        self.graph.output(graph_outputs)

        # Rewrite outputs in parent module
        if parent_out is None:
            return

        parent_out.meta["val"] = (
            graph_outputs.meta.get("val")
            if isinstance(graph_outputs, torch.fx.Node)
            else [o.meta.get("val") for o in graph_outputs]
        )
        self.uplift_common_custom_metadata()

        if len(orig_outputs) == 1 and signature is None:
            self.parent.node_map[orig_outputs[0]] = parent_out
        else:
            for i, orig_output in enumerate(orig_outputs):
                if orig_output is None:
                    continue
                # Use Proxy to record getitem access.
                proxy_out = torch.fx.Proxy(parent_out)[i].node  # type: ignore[index]
                proxy_out.meta["val"] = orig_output.meta.get("val")
                self.parent.node_map[orig_output] = proxy_out

    def copy_node(self, node):
        """Copy flat-graph ``node`` into this graph and record it as seen."""
        if self.verbose:
            # format_node() builds a string; skip it unless tracing is on.
            self.print("copying", node.format_node())
        self.node_map[node] = self.graph.node_copy(node, self.remap_input)
        self.seen_nodes[node.name] = node

    def run_outer(self):
        """Drive the whole walk from the root frame.

        Copies the leading placeholder nodes, runs ``run_from`` on the rest,
        and finally copies the flat graph's output node(s).
        """
        if self.verbose:
            for i, node in enumerate(self.flat_graph.nodes):
                self.print(i, node.meta.get("nn_module_stack"), node.format_node())

        # Copy all graph inputs
        node_idx: int = 0
        node = self.nodes[node_idx]
        while node.op == "placeholder":
            self.copy_node(node)
            node_idx += 1
            node = self.nodes[node_idx]

        self.run_from(node_idx)

        # Copy graph outputs
        for node in self.flat_graph.nodes:
            if node.op == "output":
                self.copy_node(node)

    def print(self, *args, **kwargs):
        """Forward to the builtin ``print`` only when ``self.verbose`` is set."""
        if self.verbose:
            # pyrefly: ignore [not-iterable]
            print(*args, **kwargs)

    def run_from(self, node_idx):
        """Process flat-graph nodes starting at ``node_idx`` for this frame.

        Returns:
            The index of the first node that does not belong to this frame
            (or of the output node), so the caller can resume from there.

        Raises:
            AssertionError: On a placeholder node, or if a node's module stack
                is inconsistent with this frame's stack.
            RuntimeError: If a node has no ``nn_module_stack`` metadata.
        """
        # Function-scope import kept from the original code; executed once per
        # call instead of once per processed node.
        from torch._export.passes._node_metadata_hook import (
            _EMPTY_NN_MODULE_STACK_KEY,
        )

        # Walk through the graph, building up a new graph with the right submodules
        while node_idx < len(self.nodes):
            node = self.nodes[node_idx]
            if node.op == "placeholder":
                raise AssertionError(f"unexpected placeholder node at index {node_idx}")

            if self.verbose:
                self.print()
                self.print("STEP", node_idx, node.format_node())
                self.print(self.module_stack)
            depth = len(self.module_stack)
            if node.op == "output":
                if depth == 1:
                    # We want the output node of the original graph to be handled
                    # specially by the outermost stack frame (in run_outer). So
                    # skip finalization here.
                    return node_idx

                # We've reached the end of the graph. Wrap up all the existing stack frames.
                self.finalize_outputs()
                return node_idx

            if len(node.meta.get("nn_module_stack", {})) == 0:
                raise RuntimeError(f"Unable to find nn_module_stack for node {node}")

            nn_module_stack = node.meta["nn_module_stack"]

            if (
                len(nn_module_stack) == 1
                and _EMPTY_NN_MODULE_STACK_KEY in nn_module_stack
            ):
                # Empty case from the node_metadata_hook
                node_module_stack = self.module_stack
            else:
                node_module_stack = [
                    (
                        path,
                        ty if path else None,
                        int(k.split("@")[-1]) if "@" in k else 0,
                    )
                    for k, (path, ty) in node.meta["nn_module_stack"].items()
                ]

            if node_module_stack[:depth] != self.module_stack:
                # This means that the current module is done executing and the
                # current node is the beginning of a new module.
                #
                # In this case, we should finalize this module and return without
                # incrementing the node counter.
                self.finalize_outputs()
                self.print("outlining", self.fqn)
                self.print(self.graph)
                return node_idx

            if node_module_stack is None:
                raise AssertionError("node_module_stack must not be None")

            if _is_prefix(self.module_stack, node_module_stack):
                # This means that the current node represents the execution of a new
                # module.
                next_module = node_module_stack[depth]
                self.print("Creating new stack frame for", next_module)
                # Run a nested version of module outliner from the current node
                # counter. Once it is complete, continue from that point.
                next_module_key = list(node.meta["nn_module_stack"].keys())[depth]
                node_idx = _ModuleFrame(
                    self.flat_graph,
                    self.nodes,
                    self.seen_nodes,
                    self.seen_modules,
                    self.seen_attrs,
                    self.created_modules,
                    self,
                    self.module_stack + [next_module],
                    next_module_key.split("@")[0],
                    self.module_call_graph,
                ).run_from(node_idx)
                continue

            # The only remaining possibility is that we are in the right stack
            # frame. Copy the node into this frame's graph and increment the node counter.
            if node_module_stack != self.module_stack:
                raise AssertionError(
                    f"expected node_module_stack {node_module_stack} to equal module_stack {self.module_stack}"
                )

            if node.op == "get_attr":
                # this must be a graph argument for a HOP
                self.seen_attrs[self.child_fqn].add(node.target)

            self.copy_node(node)
            # pyrefly: ignore [unsupported-operation]
            node_idx += 1
@dataclass
class _SubmoduleEntry:
    """Record of one submodule call created while outlining the flat graph.

    Instances are appended to ``seen_modules[module_id]`` by ``_ModuleFrame``
    and later consumed by ``_deduplicate_modules`` and ``_IVals.update``.
    """

    # fqn of the frame (module) that issues the call
    parent_fqn: str
    # module object of that parent frame
    parent_module: torch.nn.Module
    # the call_module node inserted in the parent's graph for this call
    parent_call_module: torch.fx.Node
    # fqn of the called module, without any "@<n>" call suffix
    fqn: str
    # 1-based call index; _call_name(fqn, call_idx) gives the call name
    call_idx: int
    # module object holding the graph built for this call
    module: torch.nn.Module
def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
    """Split ``orig_graph`` into per-module graphs rooted at ``root_module``.

    Builds the root ``_ModuleFrame`` with fresh bookkeeping containers and runs
    it over every node of ``orig_graph``. Only module-call-graph entries that
    have a signature are passed along.

    Returns:
        A ``(seen_modules, seen_attrs)`` tuple: the submodule entries grouped
        by module id, and the ``get_attr`` targets grouped by call name.
    """
    signatures_by_fqn = {
        entry.fqn: entry.signature
        for entry in root_module.module_call_graph
        if entry.signature
    }

    seen_nodes: dict[str, torch.fx.Node] = {}
    seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list)
    seen_attrs: dict[str, set[str]] = defaultdict(set)
    created_modules: dict[str, torch.nn.Module] = {}

    root_frame = _ModuleFrame(
        flat_graph=orig_graph,
        nodes=tuple(orig_graph.nodes),
        seen_nodes=seen_nodes,
        seen_modules=seen_modules,
        seen_attrs=seen_attrs,
        created_modules=created_modules,
        parent=None,
        module_stack=[("", None, 0)],
        module_id="",
        module_call_graph=signatures_by_fqn,
        module=root_module,
    )
    root_frame.run_outer()
    return seen_modules, seen_attrs
- def _reorder_submodules(
- parent: torch.nn.Module, fqn_order: dict[str, int], prefix: str = ""
- ):
- # TODO Can be optimized by adding submodules ahead of time.
- if prefix == "":
- for fqn in list(fqn_order.keys())[1:]:
- if _get_submodule(parent, fqn) is None:
- _add_submodule(parent, fqn, torch.nn.Module())
- children = []
- for name, child in list(parent._modules.items()):
- if child is None:
- continue
- fqn = prefix + name
- _reorder_submodules(child, fqn_order, prefix=fqn.split("@")[0] + ".")
- delattr(parent, name)
- children.append((fqn_order[fqn], name, child))
- children.sort(key=operator.itemgetter(0))
- for _, name, child in children:
- parent.register_module(name, child)
- class _IVals:
- """
- Collect the intermediate values of mutations in a graph.
- Example: in the following graph, suppose that buf_in and buf_out
- are the input and output values of a buffer.
- buf_in = placeholder()
- ...
- ival1 = f0(buf_in, ...) # inside self.n0(...)
- ...
- ival2 = f1(ival1, ...) # inside self.n1(...)
- ...
- buf_out = f2(ival2, ...) # inside self.n2(...)
- return buf_out, ...
- Here ival1 and ival2 are intermediate values created inside
- calls to n0 and n1 respectively, and used inside calls to
- n1 and n2 respectively.
- """
- def __init__(self):
- # for each fqn, set of node names corresponding to intermediate values
- self.node_names_by_fqn = defaultdict(set)
- def _is_mutable(self, target):
- if isinstance(target, torch._ops.OpOverload):
- return target._schema.is_mutable
- return False
- def read(self, mf, node):
- """
- Read state corresponding to a given intermediate value.
- """
- # we can assume that the node must be from a mutation
- if node.op != "call_function":
- raise AssertionError(
- f"expected node.op to be 'call_function', got {node.op!r}"
- )
- b = self._is_mutable(node.target)
- print("Checking mutability", node.target, b)
- if not b:
- # so the mutation was functionalized;
- # we will apply the original mutation later (see below)
- fqn, _ = next(reversed(node.meta["nn_module_stack"].values()))
- self.node_names_by_fqn[fqn].add(node.name)
- return mf.remap_input(node.args[0])
- def update(self, partitions):
- """
- Update states corresponding to intermediate values that were read.
- """
- for shared_submodules in partitions:
- for entry in shared_submodules:
- graph = entry.module.graph
- node_names = self.node_names_by_fqn[entry.fqn]
- nodes = [n for n in graph.nodes if n.name in node_names]
- for node in nodes:
- # so node must be from a functionalized mutation;
- # we perform the original mutation now
- with graph.inserting_after(node):
- new_node = graph.create_node(
- "call_function",
- torch.ops.aten.copy_.default,
- (node.args[0], node),
- )
- new_node.meta = copy.copy(node.meta)
def _copy_graph_attrs(
    gm: torch.fx.GraphModule,
    root_module: UnflattenedModule,
    seen_attrs: dict[str, set[str]],
):
    """Copy the recorded attributes from ``gm`` onto the modules that use them.

    Args:
        gm: Module holding the attribute values, looked up by name.
        root_module: Root of the unflattened hierarchy.
        seen_attrs: Maps a module path under ``root_module`` (empty for the
            root itself) to the attribute names to copy onto that module.
    """
    for child_fqn, attr_names in seen_attrs.items():
        if child_fqn:
            owner = _get_attr(root_module, child_fqn)
        else:
            owner = root_module
        for attr_name in attr_names:
            setattr(owner, attr_name, getattr(gm, attr_name))
def _deduplicate_modules(partitions):
    """Collapse submodule entries whose graphs are equivalent.

    Each entry in a partition is compared against the entries before it in
    the same partition (using ``_check_graph_equivalence``).

    Returns:
        A dict mapping each call name that was optimized away to the call
        name it was redirected to.
    """
    redirects = {}
    for group in partitions:
        for position, current in enumerate(group):
            current_call = _call_name(current.fqn, current.call_idx)
            accessor = _compute_accessor(current.parent_fqn, current_call)
            replaced = False
            # Compare against every entry seen earlier in this group.
            for earlier in group[:position]:
                if not _check_graph_equivalence(earlier.module, current.module):
                    continue
                owner = current.parent_module

                if earlier.fqn != current.fqn:
                    # Different fqn: point the current slot at the earlier
                    # module object (only once). The call name in the parent
                    # graph is deliberately kept, since rewriting it would
                    # lose which fqn was actually called. Keep scanning: a
                    # later earlier-entry with the same fqn may still let the
                    # call name be optimized away.
                    if not replaced:
                        owner.set_submodule(accessor, earlier.module)
                        replaced = True
                    continue

                # Same fqn: the generated call name is redundant. Remove the
                # current module from the hierarchy and retarget the parent's
                # call_module node at the earlier call name.
                path = accessor.split(".")
                _get_attr_via_attr_list(owner, path[:-1])._modules.pop(path[-1])
                earlier_call = _call_name(earlier.fqn, earlier.call_idx)
                current.parent_call_module.target = _compute_accessor(
                    current.parent_fqn, earlier_call
                )
                redirects[current_call] = earlier_call
                break

    return redirects
- def _sink_params(
- module: torch.nn.Module,
- inputs_to_state: dict[str, list[str]],
- scope: list[str],
- module_id_to_inputs_removed: dict[int, set[str]] | None = None,
- ):
- """Sink params, buffers, and constants from graph inputs into get_attr nodes.
- Exported modules are purely functional, so they pass their parameters and
- buffers in as inputs to the graph.
- To replicate eager's semantics, we need to get them from the module state
- via get_attr instead.
- module: GraphModule, potentially containing nested submodules.
- inputs_to_state: mapping graph input names to the corresponding key in the state_dict.
- scope: tracks where we are in the module hierarchy, so that we can emit the
- right `getattr(self, "foo.bar")` calls, etc.
- module_id_to_inputs_removed: records inputs removed by child modules, mapping
- the module object id to the list of placeholder node names in the child module
- that were removed.
- """
- if module_id_to_inputs_removed is None:
- module_id_to_inputs_removed = defaultdict(set)
- if id(module) in module_id_to_inputs_removed:
- return {id(module): module_id_to_inputs_removed[id(module)]}
- # We need to use _modules here instead of named_children(), because we
- # explicitly want duplicate modules to show up in the traversal.
- for name, submodule in module._modules.items():
- submod_id_to_inputs_removed = _sink_params(
- cast("torch.nn.Module", submodule),
- inputs_to_state,
- scope + [name],
- module_id_to_inputs_removed,
- )
- for k, v in submod_id_to_inputs_removed.items():
- module_id_to_inputs_removed[k].update(v)
- graph = getattr(module, "graph", None)
- if graph is None or len(graph.nodes) == 0:
- # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList)
- return module_id_to_inputs_removed
- if not isinstance(graph, torch.fx.Graph):
- raise AssertionError(f"expected graph to be torch.fx.Graph, got {type(graph)}")
- inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes))
- the_last_input = None if len(inputs) == 0 else inputs[-1]
- # Also remove from call_module nodes
- call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes)
- for node in call_module_nodes:
- submodule = _get_attr(module, node.target)
- # remove placeholder from call_module node arguments, only if we've
- # erased the placeholder node in the corresponding _sink_params() call
- if submodule is not None and id(submodule) in module_id_to_inputs_removed:
- node.args = tuple(
- filter(
- lambda n: n.name not in module_id_to_inputs_removed[id(submodule)],
- node.args,
- )
- )
- # Filter out inputs_to_state corresponding to current scope.
- inputs_to_state_of_scope: dict[torch.fx.Node, list[str]] = {}
- for node in inputs:
- if node.name not in inputs_to_state:
- continue
- state_name = None
- for sn in inputs_to_state[node.name]:
- sn_split = sn.split(".")
- if sn_split[: len(scope)] == [x.split("@")[0] for x in scope]:
- state_name = sn_split
- break
- # If there's a mismatch between scope name and state name, then
- # there must be multiple scopes pointing to the same state name,
- # meaning some modules are shared. In such case, we can simply skip
- # updating the current node because another later iteration will
- # take care of this input node when the unique match between scope
- # and state name occurs. To make sure this always happen, we should
- # enforce the invariant that no placeholder node in the unflattened
- # graph appears in inputs_to_state dict, which means all the extra
- # input nodes have been handled.
- if state_name is None:
- continue
- inputs_to_state_of_scope[node] = state_name
- # Record name of remove inputs for return purpose.
- inputs_removed: set[str] = set()
- for node, state_name in inputs_to_state_of_scope.items():
- if len(node.users) > 0:
- attr_path = state_name[len(scope) :]
- state_attr = _get_attr_via_attr_list(module, attr_path)
- if not isinstance(state_attr, (torch.Tensor, torch.ScriptObject)):
- raise AssertionError(
- f"expected state_attr to be torch.Tensor or torch.ScriptObject, got {type(state_attr)}"
- )
- # Make sure the newly created get_attr node is placed after the last placeholder node
- with graph.inserting_after(the_last_input):
- new_node = graph.create_node("get_attr", ".".join(attr_path))
- node.replace_all_uses_with(new_node, propagate_meta=True)
- graph.erase_node(node)
- inputs_removed.add(node.name)
- if isinstance(module, InterpreterModule):
- module.finalize()
- return {id(module): inputs_removed}
|