unflatten.py 74 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899
  1. # mypy: allow-untyped-defs
  2. import abc
  3. import copy
  4. import logging
  5. import operator
  6. import re
  7. from collections import defaultdict
  8. from collections.abc import Callable
  9. from contextlib import contextmanager
  10. from copy import deepcopy
  11. from dataclasses import dataclass
  12. from enum import Enum
  13. from typing import Any, cast
  14. import torch
  15. import torch.fx._pytree as fx_pytree
  16. import torch.utils._pytree as pytree
  17. from torch._library.fake_class_registry import FakeScriptObject
  18. from torch.export import ExportedProgram
  19. from torch.export._tree_utils import reorder_kwargs
  20. from torch.export.exported_program import (
  21. ConstantArgument,
  22. ExportGraphSignature,
  23. InputKind,
  24. ModuleCallSignature,
  25. SymBoolArgument,
  26. SymFloatArgument,
  27. SymIntArgument,
  28. TensorArgument,
  29. )
  30. from torch.fx._symbolic_trace import is_fx_symbolic_tracing
  31. from torch.fx.graph_module import _get_attr, _get_attr_via_attr_list, _print_readable
  32. from torch.utils._pytree import GetAttrKey, SequenceKey
  33. from ._remove_effect_tokens_pass import _remove_effect_tokens
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Names exported by ``from <this module> import *``: the public API of this module.
__all__ = [
    "FlatArgsAdapter",
    "InterpreterModule",
    "InterpreterModuleDispatcher",
    "UnflattenedModule",
    "unflatten",
]
  42. class _AttrKind(Enum):
  43. PARAMETER = "parameter"
  44. BUFFER = "buffer"
  45. CONSTANT = "constant"
  46. MODULE = "module"
  47. @dataclass(frozen=True)
  48. class _TensorID:
  49. """Custom tensor identifier containing storage, stride, and size information."""
  50. untyped_storage: torch.UntypedStorage
  51. stride: tuple
  52. size: tuple
  53. storage_offset: int
  54. RUN_WITH_INTERPRETER = True
  55. @contextmanager
  56. def _disable_interpreter():
  57. global RUN_WITH_INTERPRETER
  58. old_flag = RUN_WITH_INTERPRETER
  59. RUN_WITH_INTERPRETER = False
  60. try:
  61. yield
  62. finally:
  63. RUN_WITH_INTERPRETER = old_flag
  64. # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module
  65. # This installs empty Modules where none exist yet if they are subpaths of target
  66. def _assign_attr(
  67. from_obj: torch.Tensor | torch.ScriptObject | torch.nn.Module,
  68. to_module: torch.nn.Module,
  69. target: str,
  70. attr_kind: _AttrKind,
  71. persistent: bool = True,
  72. ):
  73. *prefix, field = target.split(".")
  74. # We need to generate all submodules of `to_module` that are at `prefix` and
  75. # variants of `prefix` that differ only by call name. All of these submodules
  76. # will then be assigned `from_obj` at `field` so that they can share this attribute.
  77. # For example, if target is foo.bar.f, foo has another call name foo@1,
  78. # and bar has other call names bar@1, bar@2, then we will assign f to
  79. # foo.bar, foo.bar@1, foo.bar@2, foo@1.bar, foo@1.bar@1, foo@1.bar@2.
  80. to_modules = {to_module}
  81. for item in prefix:
  82. ts: set[torch.nn.Module] = set()
  83. for to_module in to_modules:
  84. if not hasattr(to_module, item):
  85. setattr(to_module, item, torch.nn.Module())
  86. ts.update(
  87. t_call # type: ignore[misc]
  88. for k, t_call in to_module._modules.items()
  89. if _is_call_name(k, item)
  90. )
  91. to_modules = ts
  92. for to_module in to_modules:
  93. if attr_kind == _AttrKind.PARAMETER:
  94. if not isinstance(from_obj, torch.nn.Parameter):
  95. raise AssertionError(
  96. f"expected torch.nn.Parameter for PARAMETER attr_kind, got {type(from_obj)}"
  97. )
  98. to_module.register_parameter(field, from_obj)
  99. elif attr_kind == _AttrKind.BUFFER:
  100. if not isinstance(from_obj, torch.Tensor):
  101. raise AssertionError(
  102. f"expected torch.Tensor for BUFFER attr_kind, got {type(from_obj)}"
  103. )
  104. to_module.register_buffer(field, from_obj, persistent=persistent)
  105. elif attr_kind == _AttrKind.CONSTANT:
  106. if isinstance(from_obj, FakeScriptObject):
  107. raise AssertionError(
  108. "FakeScriptObject should only exist during tracing."
  109. )
  110. if not isinstance(
  111. from_obj,
  112. (
  113. torch.Tensor,
  114. torch.ScriptObject,
  115. ),
  116. ):
  117. raise AssertionError(
  118. f"expected torch.Tensor or torch.ScriptObject for CONSTANT attr_kind, got {type(from_obj)}"
  119. )
  120. setattr(to_module, field, from_obj)
  121. elif attr_kind == _AttrKind.MODULE:
  122. if not isinstance(from_obj, torch.nn.Module):
  123. raise AssertionError(
  124. f"expected torch.nn.Module for MODULE attr_kind, got {type(from_obj)}"
  125. )
  126. setattr(to_module, field, from_obj)
  127. class _SubmoduleBase:
  128. _ty: str | None
  129. def type_name(self) -> str | None:
  130. """
  131. Subclass of this class - InterpreterModule, InterpreterModuleDispatcher, represents
  132. corresponding model in eager model. To get this type information for those modules
  133. in eager model we need to use this method.
  134. """
  135. return self._ty
  136. class InterpreterModule(_SubmoduleBase, torch.nn.Module):
  137. """A module that uses torch.fx.Interpreter to execute instead of the usual
  138. codegen that GraphModule uses. This provides better stack trace information
  139. and makes it easier to debug execution.
  140. """
  141. graph_module: torch.fx.GraphModule | None
  142. def __init__(
  143. self,
  144. graph: torch.fx.Graph,
  145. ty: str | None = None,
  146. ):
  147. super().__init__()
  148. self.graph = graph
  149. self._ty = ty
  150. self.graph.owning_module = self # type: ignore[assignment]
  151. self._run_with_interpreter = RUN_WITH_INTERPRETER
  152. def forward(self, *args, **kwargs):
  153. if self.graph_module is None:
  154. raise AssertionError("Didn't finalize this InterpreterModule")
  155. if not is_fx_symbolic_tracing() and (
  156. torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter
  157. ):
  158. # Dynamo cannot trace through torch.fx.Interpreter, so fall back to
  159. # GraphModule codegen in this instance.
  160. # Patch the codegened forward to run with this InterpreterModule,
  161. # so attribute accesses, etc. are on this module instead.
  162. return type(self.graph_module).forward(self, *args, **kwargs)
  163. else:
  164. if kwargs:
  165. # Handle **kwargs. FX only natively supports positional
  166. # arguments (through placeholders). So in order to pass in
  167. # kwargs, we must correspond the names of the placeholders with
  168. # the keys in the kwarg dict.
  169. arg_list = list(args)
  170. kwarg_names = self.arg_names[len(arg_list) :]
  171. arg_list.extend(
  172. kwargs[kwarg_name]
  173. for kwarg_name in kwarg_names
  174. if kwarg_name in kwargs
  175. )
  176. # Assert that the kwargs passed in exactly match the positional
  177. # arguments specified by the GraphModule. This should be
  178. # guaranteed by the unflattening process.
  179. if len(kwarg_names) != len(kwargs):
  180. raise AssertionError(
  181. f"kwarg_names length {len(kwarg_names)} does not match kwargs length {len(kwargs)}"
  182. )
  183. if len(arg_list) != len(self.arg_names):
  184. raise AssertionError(
  185. f"arg_list length {len(arg_list)} does not match arg_names length {len(self.arg_names)}"
  186. )
  187. args = tuple(arg_list)
  188. return torch.fx.Interpreter(self, graph=self.graph).run(
  189. *args, enable_io_processing=False
  190. )
  191. def finalize(self):
  192. # We need to "finalize" because GraphModule populates its own state_dict
  193. # based on the get_attrs observed in the graph. So we need to fully
  194. # construct the graph and call _sink_params before generating this
  195. # GraphModule.
  196. # need to set `graph_module` directly on the dict to avoid it getting
  197. # registered as a submodule.
  198. self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph)
  199. self.graph.lint()
  200. # Cache arg names for kwarg handling (see forward())
  201. self.arg_names = []
  202. for node in self.graph.nodes:
  203. if node.op == "placeholder":
  204. self.arg_names.append(node.target)
  205. def print_readable(
  206. self,
  207. print_output=True,
  208. include_stride=False,
  209. include_device=False,
  210. colored=False,
  211. ):
  212. return _print_readable(
  213. self,
  214. "InterpreterModule",
  215. print_output,
  216. include_stride,
  217. include_device,
  218. colored,
  219. )
  220. class InterpreterModuleDispatcher(_SubmoduleBase, torch.nn.Module):
  221. """
  222. A module that carries a sequence of InterpreterModules corresponding to
  223. a sequence of calls of that module. Each call to the module dispatches
  224. to the next InterpreterModule, and wraps back around after the last.
  225. """
  226. def __init__(self, attrs: set[str], call_modules: list[InterpreterModule]):
  227. super().__init__()
  228. if not call_modules:
  229. raise AssertionError("call_modules must not be empty")
  230. self._modules = call_modules[0]._modules
  231. for accessor in attrs:
  232. setattr(self, accessor, getattr(call_modules[0], accessor))
  233. self._ty = call_modules[0]._ty
  234. self._call_modules = call_modules
  235. self._num_calls = 0
  236. def forward(self, *args, **kwargs):
  237. call_module = self._call_modules[self._num_calls]
  238. self._num_calls = (self._num_calls + 1) % len(self._call_modules)
  239. try:
  240. return call_module(*args, **kwargs)
  241. except Exception:
  242. self._num_calls = 0
  243. raise
  244. def call_modules(self):
  245. return self._call_modules
  246. def print_readable(
  247. self,
  248. print_output=True,
  249. include_stride=False,
  250. include_device=False,
  251. colored=False,
  252. ):
  253. outputs = [
  254. mod.print_readable(
  255. print_output,
  256. include_stride,
  257. include_device,
  258. colored,
  259. )
  260. for mod in self._call_modules
  261. ]
  262. return "\n".join(outputs)
  263. class FlatArgsAdapter(abc.ABC):
  264. """
  265. Adapts input arguments with ``input_spec`` to align ``target_spec``.
  266. """
  267. @abc.abstractmethod
  268. def adapt(
  269. self,
  270. target_spec: pytree.TreeSpec,
  271. input_spec: pytree.TreeSpec,
  272. input_args: list[Any],
  273. metadata: dict[str, Any] | None = None,
  274. obj: Any | None = None,
  275. ) -> list[Any]:
  276. """NOTE: This adapter may mutate given ``input_args_with_path``."""
  277. ...
  278. def get_flat_arg_paths(self) -> list[str]:
  279. """Returns a list of paths that are used to access the flat args."""
  280. return []
  281. class UnflattenedModule(_SubmoduleBase, torch.nn.Module):
  282. def __init__(
  283. self,
  284. export_module: ExportedProgram,
  285. flat_args_adapter: FlatArgsAdapter | None = None,
  286. ):
  287. super().__init__()
  288. if export_module.graph_signature.backward_signature is not None:
  289. raise ValueError("Unflattening on JointExportModule NYI")
  290. def _id(obj):
  291. """Returns _TensorID dataclass for tensors, otherwise id()."""
  292. if isinstance(obj, torch.Tensor):
  293. return _TensorID(
  294. untyped_storage=obj.untyped_storage(),
  295. stride=obj.stride(),
  296. size=obj.size(),
  297. storage_offset=obj.storage_offset(), # type: ignore[arg-type]
  298. )
  299. return id(obj)
  300. fqn_list = [entry.fqn for entry in export_module.module_call_graph]
  301. if fqn_list[0] != "":
  302. raise AssertionError(
  303. f"expected first fqn to be empty string, got {fqn_list[0]!r}"
  304. )
  305. export_graph = deepcopy(export_module.graph)
  306. self.graph_signature = deepcopy(export_module.graph_signature)
  307. self.graph = torch.fx.Graph()
  308. self.graph.owning_module = self # type: ignore[assignment]
  309. self.module_call_graph = deepcopy(export_module.module_call_graph)
  310. self.flat_args_adapter = flat_args_adapter
  311. self.meta = export_module.graph_module.meta
  312. self.meta["unflattened_module"] = self
  313. # Flag to indicate whether args have been adapted.
  314. self.adapted = False
  315. self._run_with_interpreter = RUN_WITH_INTERPRETER
  316. _inplace_buffer_and_input_mutations(export_graph, self.graph_signature)
  317. _fix_nn_module_stacks(export_graph)
  318. self._ty = _root_module_type(export_graph)
  319. self.ivals = _IVals()
  320. # for any intermediate value of a mutation that is read, track the mutation
  321. seen_modules, seen_attrs = _outline_submodules(export_graph, self)
  322. # for each read intermediate value of a mutation, find where it was created,
  323. # and perform the mutation
  324. self.ivals.update(seen_modules.values())
  325. # move attributes that correspond to graph arguments for HOPs
  326. # from exported program to unflattened submodules
  327. _copy_graph_attrs(export_module._graph_module, self, seen_attrs)
  328. self.range_constraints = export_module.range_constraints
  329. self.equality_constraints: list = []
  330. # aliasing/unused param or buffer issues:
  331. # in strict-mode export, dynamo export will deduplicate aliased tensors,
  332. # and ignore unused tensors. For aliasing, this causes issues when some aliases
  333. # are unused, and we're unable to match the placeholder node to the correct FQN.
  334. # This leads to the graph signature potentially having the wrong target FQN,
  335. # and downstream issues where parameters are assigned to the wrong target attribute,
  336. # mismatching the relevant placeholder node in the unflattened module.
  337. # To resolve this we restore (_assign_attr) all aliased/unused tensors in
  338. # the state_dict as module attributes, but only keep the used tensors in the
  339. # graph's forward pass (_sink_params).
  340. state_dict = export_module.state_dict
  341. assigned_params: set[str] = set() # tracking unused params
  342. id_to_param: dict[
  343. int | _TensorID, torch.nn.Parameter
  344. ] = {} # handling weight-sharing
  345. for name in self.graph_signature.parameters: # this loop adds used params
  346. param = state_dict[name]
  347. if _id(param) not in id_to_param:
  348. id_to_param[_id(param)] = torch.nn.Parameter(
  349. param.clone(), requires_grad=param.requires_grad
  350. )
  351. _assign_attr(
  352. id_to_param[_id(param)],
  353. self,
  354. name,
  355. attr_kind=_AttrKind.PARAMETER,
  356. )
  357. assigned_params.add(name)
  358. non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
  359. assigned_buffers: set[str] = set() # tracking unused buffers
  360. id_to_buffer: dict[int | _TensorID, tuple[torch.nn.Parameter, bool]] = {}
  361. for name in self.graph_signature.buffers: # this loop adds used buffers
  362. if name in non_persistent_buffers:
  363. persistent = False
  364. buffer = export_module.constants[name]
  365. else:
  366. persistent = True
  367. buffer = state_dict[name]
  368. if _id(buffer) not in id_to_buffer:
  369. id_to_buffer[_id(buffer)] = (buffer.clone(), persistent)
  370. _assign_attr(
  371. id_to_buffer[_id(buffer)][0],
  372. self,
  373. name,
  374. attr_kind=_AttrKind.BUFFER,
  375. persistent=persistent,
  376. )
  377. assigned_buffers.add(name)
  378. # restore aliased/unused params and buffers
  379. # these appear in state dict but not graph signature
  380. for name, tensor in state_dict.items():
  381. if name in assigned_params or name in assigned_buffers: # already assigned
  382. continue
  383. is_buffer = False
  384. if _id(tensor) in id_to_buffer or not isinstance(
  385. tensor, torch.nn.Parameter
  386. ): # aliased buffer
  387. is_buffer = True
  388. if is_buffer:
  389. if (
  390. _id(tensor) not in id_to_buffer
  391. ): # this is completely unused (not weight-sharing)
  392. id_to_buffer[_id(tensor)] = (
  393. tensor,
  394. True,
  395. ) # assign to respect original model
  396. _assign_attr(
  397. id_to_buffer[_id(tensor)][0],
  398. self,
  399. name,
  400. attr_kind=_AttrKind.BUFFER,
  401. persistent=True,
  402. )
  403. else:
  404. if _id(tensor) not in id_to_param: # this is unused
  405. id_to_param[_id(tensor)] = tensor
  406. _assign_attr(
  407. id_to_param[_id(tensor)],
  408. self,
  409. name,
  410. attr_kind=_AttrKind.PARAMETER,
  411. )
  412. # use id map so we don't double-clone aliased constants
  413. id_to_const: dict[int | _TensorID, torch.Tensor | torch._C.ScriptObject] = {}
  414. for fqn, constant in export_module.constants.items():
  415. if _id(constant) not in id_to_const:
  416. if isinstance(constant, torch.Tensor):
  417. constant = constant.clone()
  418. id_to_const[_id(constant)] = constant
  419. _constant = id_to_const[_id(constant)]
  420. _assign_attr(
  421. _constant,
  422. self,
  423. fqn,
  424. attr_kind=_AttrKind.CONSTANT,
  425. )
  426. # This is to handle parameters/buffers that point to the same tensor
  427. # object id -> list of (node_name, target_name)
  428. consts_map: dict[int | _TensorID, list[tuple[str, str]]] = defaultdict(list)
  429. consts_targets: set[str] = set()
  430. def add_to_consts_map(obj_id, node_name, target_name):
  431. name_list = consts_map[obj_id]
  432. name_list.append((node_name, target_name))
  433. # track aliased/unused params, buffers
  434. # prefer using untyped_storage() over id() when it's available
  435. added_params_buffers: set[str] = set()
  436. for s in self.graph_signature.input_specs:
  437. if s.kind == InputKind.PARAMETER or (
  438. s.kind == InputKind.BUFFER and s.persistent
  439. ):
  440. if not hasattr(s.arg, "name"):
  441. raise AssertionError(
  442. f"expected s.arg to have 'name' attribute, got {type(s.arg)}"
  443. )
  444. if not isinstance(s.target, str):
  445. raise AssertionError(
  446. f"expected s.target to be str, got {type(s.target)}"
  447. )
  448. add_to_consts_map(
  449. _id(export_module.state_dict[s.target]),
  450. s.arg.name,
  451. s.target,
  452. )
  453. consts_targets.add(s.target)
  454. added_params_buffers.add(s.target)
  455. elif (
  456. s.kind == InputKind.BUFFER
  457. and not s.persistent
  458. or s.kind == InputKind.CONSTANT_TENSOR
  459. or s.kind == InputKind.CUSTOM_OBJ
  460. ):
  461. if not hasattr(s.arg, "name"):
  462. raise AssertionError(
  463. f"expected s.arg to have 'name' attribute for kind {s.kind}, got {type(s.arg)}"
  464. )
  465. if not isinstance(s.target, str):
  466. raise AssertionError(
  467. f"expected s.target to be str for kind {s.kind}, got {type(s.target)}"
  468. )
  469. add_to_consts_map(
  470. _id(export_module.constants[s.target]),
  471. s.arg.name,
  472. s.target,
  473. )
  474. consts_targets.add(s.target)
  475. # add constants that are aliased and don't appear in graph signature
  476. for const_name, const in export_module.constants.items():
  477. if const_name not in consts_targets:
  478. const_id = _id(const)
  479. if const_id not in consts_map:
  480. raise AssertionError(
  481. f"constant {const_name!r} id not found in consts_map"
  482. )
  483. ph_name, _ = consts_map[const_id][0]
  484. add_to_consts_map(const_id, ph_name, const_name)
  485. added_params_buffers.add(s.target)
  486. # add aliased/unused params and buffers that don't appear in graph signature
  487. for fqn, tensor in export_module.state_dict.items():
  488. if fqn not in added_params_buffers:
  489. tensor_id = _id(tensor)
  490. if tensor_id not in consts_map:
  491. # completely unused (no weight-sharing), ignore.
  492. # this weight doesn't appear in graph module,
  493. # so won't cause FQN assignment issues
  494. continue
  495. ph_name, _ = consts_map[tensor_id][0]
  496. add_to_consts_map(tensor_id, ph_name, fqn)
  497. # node name -> list of possible targets
  498. inputs_to_state: dict[str, list[str]] = {}
  499. for node_target in consts_map.values():
  500. targets = [t[1] for t in node_target]
  501. for n, _ in node_target:
  502. inputs_to_state[n] = targets
  503. _sink_params(self, inputs_to_state, [])
  504. redirected_call_indices = _deduplicate_modules(seen_modules.values())
  505. fqn_list = [fqn for fqn in fqn_list if fqn not in redirected_call_indices]
  506. self._dispatch_modules(redirected_call_indices, consts_targets)
  507. fqn_list = [fqn for fqn in fqn_list if "@" not in fqn]
  508. # Cache so we don't have to compute this every time.
  509. # NOTE: this needs to be kept in sync with the placeholders in
  510. # self.graph, but currently we have no way to guarantee that.
  511. self.input_placeholders = [
  512. node for node in self.graph.nodes if node.op == "placeholder"
  513. ]
  514. self.check_input_constraints = True
  515. # TODO(zhxchen17) We can register modules ahead of time instead of reorder later.
  516. fqn_order = {fqn: i for i, fqn in enumerate(fqn_list)}
  517. # In the case of legacy IR, we might be missing some modules from metadata.
  518. for name, _ in self.named_modules(remove_duplicate=False):
  519. if name not in fqn_order:
  520. fqn_order[name] = len(fqn_order)
  521. _reorder_submodules(self, fqn_order)
  522. self.graph.lint()
  523. self.finalize()
  524. def _print_graph(self):
  525. for fqn, mod in self.named_modules():
  526. print(fqn + ":")
  527. if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph):
  528. print(mod.graph)
  529. def _adapt_flat_args(self, flat_args, in_spec, input):
  530. signature = self.module_call_graph[0].signature
  531. if in_spec == signature.in_spec:
  532. return flat_args
  533. if self.flat_args_adapter is None:
  534. raise TypeError(
  535. "There is no flat args adapter specified. "
  536. "Are you sure you are calling this with the right arguments? "
  537. )
  538. else:
  539. flat_args = self.flat_args_adapter.adapt(
  540. target_spec=signature.in_spec,
  541. input_spec=in_spec,
  542. input_args=flat_args,
  543. metadata=self.meta,
  544. obj=input,
  545. )
  546. if len(flat_args) != signature.in_spec.num_leaves:
  547. raise TypeError(
  548. f"Flat args adaption failed, number of args mismatch "
  549. f"Adatped: {len(flat_args)} \n"
  550. f"Exported module: {signature.in_spec.num_leaves}"
  551. )
  552. return flat_args
  553. def process_forward_inputs(self, *args, **kwargs):
  554. signature = self.module_call_graph[0].signature
  555. reordered_kwargs = kwargs
  556. if kwargs:
  557. reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec)
  558. flat_args_with_path, in_spec = pytree.tree_flatten_with_path(
  559. (args, reordered_kwargs)
  560. )
  561. flat_args = [x[1] for x in flat_args_with_path]
  562. if is_fx_symbolic_tracing():
  563. return flat_args
  564. if in_spec != signature.in_spec:
  565. if not self.adapted:
  566. print(
  567. "Input treespec does not match with exported module's: \n"
  568. f"Input treespec: {in_spec}. ",
  569. f"Exported module treespec: {signature.in_spec}",
  570. )
  571. print("Adapting flat arg to match exported module's treespec")
  572. flat_args = self._adapt_flat_args(flat_args, in_spec, args)
  573. self.adapted = True
  574. if self.check_input_constraints:
  575. # Import here to avoid an unfortunate circular dependency.
  576. # TODO(suo): untangle this.
  577. from torch._export.utils import _check_input_constraints_for_graph
  578. if self.adapted is True:
  579. flat_arg_paths = (
  580. self.flat_args_adapter.get_flat_arg_paths()
  581. if self.flat_args_adapter
  582. else []
  583. )
  584. if flat_arg_paths and len(flat_arg_paths) != len(flat_args):
  585. raise AssertionError(
  586. f"flat_arg_paths length {len(flat_arg_paths)} does not match flat_args length {len(flat_args)}"
  587. )
  588. new_flat_args_with_path = [ # type: ignore[var-annotated]
  589. (
  590. (
  591. SequenceKey(idx=idx),
  592. GetAttrKey(
  593. name=flat_arg_paths[idx]
  594. if flat_arg_paths
  595. else "<unknown location>"
  596. ),
  597. ),
  598. arg,
  599. )
  600. for idx, arg in enumerate(flat_args)
  601. ]
  602. else:
  603. new_flat_args_with_path = flat_args_with_path # type: ignore[assignment]
  604. _check_input_constraints_for_graph(
  605. self.input_placeholders, new_flat_args_with_path, self.range_constraints
  606. )
  607. return flat_args
  608. def forward(self, *args, **kwargs):
  609. flat_args = self.process_forward_inputs(*args, **kwargs)
  610. signature = self.module_call_graph[0].signature
  611. if is_fx_symbolic_tracing():
  612. return_val = torch.fx.Interpreter(self, graph=self.graph).run(
  613. *flat_args, enable_io_processing=False
  614. )
  615. # For scalar return value, fx.Graph wraps in a tuple
  616. if isinstance(return_val, tuple) and len(return_val) == 1:
  617. return return_val[0]
  618. return return_val
  619. if torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter:
  620. tree_out = type(self.graph_module).forward(self, *flat_args) # type: ignore[union-attr]
  621. else:
  622. tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
  623. *flat_args, enable_io_processing=False
  624. )
  625. return pytree.tree_unflatten(tree_out, signature.out_spec)
  626. def finalize(self):
  627. self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph)
  628. self.graph.lint()
    def _dispatch_modules(self, redirected_call_indices, consts_targets):
        """For a module whose call signatures are preserved, replace
        multiple modules corresponding to multiple calls to that module
        with a single dispatcher module that tracks which module to call.

        Args:
            redirected_call_indices: mapping from a call name (``fqn`` or
                ``fqn@<n>``) to the name it was redirected to during
                deduplication. Redirected calls are looked up under the new
                name and are not detached from their parent module below.
            consts_targets: dotted attribute targets. Each is grouped under the
                FQN of its owning module (everything before the last ".", or
                "" for the root) and the attribute names are handed to that
                module's dispatcher.
        """
        # for each fqn whose module call signature is preserved,
        # map that fqn to a list of called modules
        called_modules = defaultdict(list)
        for entry in self.module_call_graph:
            if entry.fqn and entry.signature:
                # some modules were removed and their fqns redirected to other
                # fqns during deduplication
                fqn = entry.fqn
                mod = _get_attr(self, redirected_call_indices.get(fqn, fqn))
                # "a.b@2" -> ("a.b", "2"); a name without "@" is call index 0.
                base, idx = fqn.split("@") if "@" in fqn else [fqn, "0"]
                called_modules[base].append((int(idx), mod))

        # Group attribute names by the FQN of the module that owns them.
        attrs_map = defaultdict(set)
        for target in consts_targets:
            if "." in target:
                orig_fqn, name = target.rsplit(".", 1)
                attrs_map[orig_fqn].add(name)
            else:
                attrs_map[""].add(target)

        # replace multiple call modules with a single dispatcher module
        for orig_fqn, indexed_call_modules in called_modules.items():
            # Order the per-call modules by call index.
            # NOTE(review): sorting (idx, module) tuples would compare modules
            # if two entries shared an index; presumably call names are unique.
            call_modules = [mod for _, mod in sorted(indexed_call_modules)]
            if len(call_modules) > 1:
                # Detach the per-call modules ("fqn", "fqn@1", ...) from their
                # parents, except those that were redirected elsewhere.
                for i in range(len(call_modules)):
                    fqn = _call_name(orig_fqn, i + 1)
                    if fqn not in redirected_call_indices:
                        *prefix, name = fqn.split(".")
                        _get_attr_via_attr_list(self, prefix)._modules.pop(name)
                # Install one dispatcher at the base FQN holding every call
                # module in call order.
                self.set_submodule(
                    orig_fqn,
                    InterpreterModuleDispatcher(attrs_map[orig_fqn], call_modules),
                )

        # elide call indices in call modules because they are
        # tracked automatically inside the dispatcher module
        def elide_call_indices(prefix, graph):
            # Rewrite call_module targets "x@n" to "x" when the module path
            # (prefix-qualified "x") is one of the preserved-signature modules.
            for node in graph.nodes:
                if node.op == "call_module":
                    fqn = node.target.split("@")[0]
                    path = f"{prefix}.{fqn}" if prefix else fqn
                    if path in called_modules:
                        node.target = fqn

        # Modules without a graph but with ``_call_modules`` (presumably the
        # dispatchers installed above) are handled via the graphs of the call
        # modules they hold.
        for fqn, mod in self.named_modules(remove_duplicate=False):
            if hasattr(mod, "graph"):
                elide_call_indices(fqn, mod.graph)
            elif hasattr(mod, "_call_modules"):
                for mod_ in mod._call_modules:
                    if not hasattr(mod_, "graph"):
                        raise AssertionError(
                            f"expected mod_ to have 'graph' attribute, got {type(mod_)}"
                        )
                    elide_call_indices(fqn, mod_.graph)
  684. def print_readable(
  685. self,
  686. print_output=True,
  687. include_stride=False,
  688. include_device=False,
  689. colored=False,
  690. ):
  691. return _print_readable(
  692. self,
  693. "UnflattenedModule",
  694. print_output,
  695. include_stride,
  696. include_device,
  697. colored,
  698. )
  699. def unflatten(
  700. module: ExportedProgram, flat_args_adapter: FlatArgsAdapter | None = None
  701. ) -> UnflattenedModule:
  702. """Unflatten an ExportedProgram, producing a module with the same module
  703. hierarchy as the original eager module. This can be useful if you are trying
  704. to use :mod:`torch.export` with another system that expects a module
  705. hierarchy instead of the flat graph that :mod:`torch.export` usually produces.
  706. .. note:: The args/kwargs of unflattened modules will not necessarily match
  707. the eager module, so doing a module swap (e.g. :code:`self.submod =
  708. new_mod`) will not necessarily work. If you need to swap a module out, you
  709. need to set the :code:`preserve_module_call_signature` parameter of
  710. :func:`torch.export.export`.
  711. Args:
  712. module (ExportedProgram): The ExportedProgram to unflatten.
  713. flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input TreeSpec does not match with exported module's.
  714. Returns:
  715. An instance of :class:`UnflattenedModule`, which has the same module
  716. hierarchy as the original eager module pre-export.
  717. """
  718. module = _remove_effect_tokens(module)
  719. m = UnflattenedModule(module, flat_args_adapter)
  720. # Disable process_forward_inputs as the adapter has many
  721. # non-dynamo-traceable behavior.
  722. m.process_forward_inputs = torch._dynamo.disable( # type: ignore[method-assign]
  723. m.process_forward_inputs,
  724. reason="do not trace into preprocessing the inputs",
  725. recursive=True,
  726. )
  727. return m
  728. def _inplace_buffer_and_input_mutations(
  729. graph: torch.fx.Graph,
  730. graph_signature: ExportGraphSignature,
  731. ) -> None:
  732. """Transform buffer and input mutations from their functionalized form
  733. into copy_ nodes in the graph.
  734. Functionalization represents a buffer mutation by passing the buffer as
  735. an input and output. For example, consider the eager code:
  736. def forward(self, x):
  737. self.buffer += x
  738. return x * x
  739. This corresponds to a graph that looks like:
  740. def forward(self, buffer, x):
  741. mutated_buffer = aten.add(buffer, x)
  742. mul = aten.mul(x, x)
  743. return (mutated_buffer, mul)
  744. We want to inplace this into something that looks like the original
  745. eager code:
  746. def forward(self, buffer, x):
  747. mutated_buffer = aten.add(buffer, x)
  748. buffer.copy_(mutated_buffer)
  749. mul = aten.mul(x, x)
  750. return (mul,)
  751. Input mutations are handled similarly.
  752. """
  753. output_node = next(iter(reversed(graph.nodes)))
  754. if output_node.op != "output" or len(output_node.args) != 1:
  755. raise AssertionError(
  756. f"expected output node with op='output' and 1 arg, got op={output_node.op!r} with {len(output_node.args)} args"
  757. )
  758. return_args = output_node.args[0]
  759. input_name_to_node = {
  760. node.name: node for node in graph.nodes if node.op == "placeholder"
  761. }
  762. mutation_name_to_input_name = {}
  763. # Collect mutated buffers.
  764. buffer_fqn_to_input_name = {
  765. buffer_fqn: k for k, buffer_fqn in graph_signature.inputs_to_buffers.items()
  766. }
  767. mutation_name_to_input_name = {
  768. k: buffer_fqn_to_input_name[buffer_fqn]
  769. for k, buffer_fqn in graph_signature.buffers_to_mutate.items()
  770. }
  771. # Collect mutated user inputs.
  772. mutation_name_to_input_name.update(graph_signature.user_inputs_to_mutate)
  773. num_mutations = len(mutation_name_to_input_name)
  774. for mutation in return_args[:num_mutations]:
  775. input_name = mutation_name_to_input_name[mutation.name]
  776. input_node = input_name_to_node[input_name]
  777. with graph.inserting_after(mutation):
  778. # Create a copy_ node that inplaces the mutation.
  779. new_node = graph.create_node(
  780. "call_function", torch.ops.aten.copy_.default, (input_node, mutation)
  781. )
  782. for k, v in mutation.meta.items():
  783. new_node.meta[k] = v
  784. # Replace all uses of the previously functional mutation with
  785. # our copy_ node.
  786. mutation.replace_all_uses_with(new_node, lambda x: x is not new_node)
  787. # Remove the mutated buffer / input from the graph outputs, since we don't
  788. # need to thread it through anymore.
  789. user_outputs = tuple(return_args[num_mutations:])
  790. output_node.args = ((user_outputs),)
  791. def _root_module_type(graph: torch.fx.Graph) -> str | None:
  792. for node in graph.nodes:
  793. if "nn_module_stack" not in node.meta:
  794. continue
  795. for path, ty in node.meta["nn_module_stack"].values():
  796. if not path:
  797. return ty
  798. return None
  799. def _fix_nn_module_stacks(graph):
  800. # For each nn module stack in the graph, check if the fqns in it represent a stack:
  801. # 1. Each fqn must be a prefix of the next fqn.
  802. # 2. If not, remove the entries starting from the next fqn, emitting a warning.
  803. for node in graph.nodes:
  804. if "nn_module_stack" not in node.meta:
  805. continue
  806. nn_module_stack = node.meta["nn_module_stack"]
  807. fqns = [
  808. fqn.split("@")[0] if "@" in fqn else fqn
  809. for fqn, _t in nn_module_stack.values()
  810. ]
  811. # Check if each FQN is a prefix of the next one
  812. prev_fqn, *next_fqns = fqns
  813. num_valid_indices = 1 # root FQN
  814. for curr_fqn in next_fqns:
  815. # Check if the previous FQN is a prefix of the current one
  816. if _is_prefix(prev_fqn, curr_fqn):
  817. num_valid_indices += 1
  818. prev_fqn = curr_fqn
  819. else:
  820. # Found a non-prefix FQN, stop here
  821. break
  822. # If we need to remove entries, create a new stack with only valid entries
  823. if num_valid_indices < len(nn_module_stack):
  824. log.warning(
  825. "nn_module_stack fqns %s at node %s do not form a stack! dropping last %d entries",
  826. fqns,
  827. node,
  828. len(nn_module_stack) - num_valid_indices,
  829. )
  830. node.meta["nn_module_stack"] = dict(
  831. list(nn_module_stack.items())[:num_valid_indices]
  832. )
  833. def _is_prefix(candidate, target):
  834. """Check whether `candidate` is a prefix of `target`."""
  835. return len(candidate) < len(target) and target[: len(candidate)] == candidate
  836. def _compute_accessor(parent_fqn: str, child_fqn: str) -> str:
  837. if parent_fqn == "":
  838. # Handle the root module correctly.
  839. return child_fqn
  840. parent_split = parent_fqn.split(".")
  841. child_split = child_fqn.split(".")
  842. # TODO: support skip connection by inlining the child module.
  843. if child_split[: len(parent_split)] != parent_split:
  844. raise RuntimeError(
  845. f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'."
  846. "This is currently unsupported."
  847. "Please try to make child module attach to parent module directly."
  848. )
  849. return ".".join(child_split[len(parent_split) :])
  850. def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
  851. def graph_dump(graph: torch.fx.Graph) -> str:
  852. ret = []
  853. nodes_idx: dict[int, int] = {}
  854. def arg_dump(arg) -> str:
  855. if isinstance(arg, torch.fx.Node):
  856. return "%" + str(nodes_idx[id(arg)])
  857. return str(arg)
  858. for i, node in enumerate(graph.nodes):
  859. args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)]
  860. args_dump += [
  861. f"{key}={value}"
  862. for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
  863. ]
  864. target = node.target if node.op in ("call_function", "get_attr") else ""
  865. # pyrefly: ignore [bad-argument-type]
  866. ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
  867. nodes_idx[id(node)] = i
  868. return "\n".join(ret)
  869. if not isinstance(x.graph, torch.fx.Graph):
  870. raise AssertionError(
  871. f"expected x.graph to be torch.fx.Graph, got {type(x.graph)}"
  872. )
  873. if not isinstance(y.graph, torch.fx.Graph):
  874. raise AssertionError(
  875. f"expected y.graph to be torch.fx.Graph, got {type(y.graph)}"
  876. )
  877. return graph_dump(x.graph) == graph_dump(y.graph)
  878. def _add_spec(gm: torch.nn.Module, spec) -> str:
  879. i = 0
  880. while hasattr(gm, f"_spec_{i}"):
  881. i += 1
  882. name = f"_spec_{i}"
  883. setattr(gm, name, spec)
  884. return name
  885. def _generate_flatten(gm: torch.fx.GraphModule, node) -> torch.fx.Node:
  886. flatten = gm.graph.call_function(pytree.tree_flatten, (node,))
  887. getitem_0 = gm.graph.call_function(operator.getitem, (flatten, 0))
  888. return getitem_0
  889. def _generate_flatten_spec(
  890. gm: torch.fx.GraphModule | InterpreterModule | UnflattenedModule, node, spec
  891. ) -> torch.fx.Node:
  892. name = _add_spec(gm, spec)
  893. spec_node = gm.graph.get_attr(name)
  894. return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node))
  895. def _generate_unflatten(
  896. gm: torch.fx.GraphModule | InterpreterModule | UnflattenedModule, nodes, spec
  897. ) -> torch.fx.Node:
  898. name = _add_spec(gm, spec)
  899. spec_node = gm.graph.get_attr(name)
  900. return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node))
  901. def _get_submodule(mod: torch.nn.Module, target: str):
  902. *prefix, field = target.split(".")
  903. for item in prefix:
  904. submod = getattr(mod, item, None)
  905. if submod is None:
  906. return None
  907. if not isinstance(submod, torch.nn.Module):
  908. return None
  909. mod = submod
  910. return getattr(mod, field, None)
  911. def _add_submodule(
  912. mod: torch.nn.Module,
  913. target: str,
  914. module_to_add: torch.nn.Module,
  915. create_module: Callable[[str], torch.nn.Module] | None = None,
  916. ):
  917. *prefix, field = target.split(".")
  918. for i, item in enumerate(prefix):
  919. submod = getattr(mod, item, None)
  920. if submod is None:
  921. if create_module is not None:
  922. submod = create_module(".".join(prefix[: i + 1]))
  923. else:
  924. submod = torch.nn.Module()
  925. setattr(mod, item, submod)
  926. if not isinstance(submod, torch.nn.Module):
  927. return False
  928. mod = submod
  929. mod.add_module(field, module_to_add)
  930. def _call_name(base: str, n: int) -> str:
  931. # Given n >= 0, generate call names to a submodule `base` of the form
  932. # `base`, `base@1`, `base@2`, etc.
  933. return base if n == 1 else f"{base}@{n - 1}"
  934. def _is_call_name(call_name: str, base: str) -> bool:
  935. # Recognize when call_name = _call_name(base, n) for some n >= 0.
  936. return re.match(re.escape(base) + r"(@\d+)?$", call_name) is not None
class _ModuleFrame:
    """Outline the part of a flat export graph that belongs to one module call.

    A frame owns ``self.module`` / ``self.graph``. It walks the flat graph's
    nodes (``run_from``) and copies every node whose ``nn_module_stack`` matches
    ``self.module_stack`` into ``self.graph``. A node that belongs to a deeper
    module starts a nested ``_ModuleFrame`` at that node. A node whose stack
    does not extend this frame's stack ends the frame: it emits its outputs
    (``finalize_outputs``) and returns control to the parent frame.

    ``seen_nodes``, ``seen_modules``, ``seen_attrs`` and ``created_modules``
    are shared by every frame created during one outlining pass.
    """

    def __init__(
        self,
        flat_graph: torch.fx.Graph,
        nodes: tuple[torch.fx.Node, ...],
        seen_nodes,
        seen_modules,
        seen_attrs,
        created_modules,
        parent,
        module_stack: list[tuple[str, str | None, int]],
        module_id,
        module_call_graph: dict[str, ModuleCallSignature],
        module: torch.fx.GraphModule | UnflattenedModule | None = None,
    ):
        """Set up the frame and, for a non-root frame, its call site in the parent.

        Args:
            flat_graph: the flat graph being outlined.
            nodes: the flat graph's nodes in order; ``run_from`` indexes into it.
            seen_nodes: flat-graph node name -> node, for nodes already copied
                by some frame (filled in by ``copy_node``).
            seen_modules: module id -> list of ``_SubmoduleEntry``, one per
                call of that module.
            seen_attrs: call name -> set of ``get_attr`` targets used inside it.
            created_modules: FQN -> module already created for that FQN, reused
                instead of building a new ``InterpreterModule``.
            parent: the enclosing frame, or None for the root frame.
            module_stack: ``(fqn, type, call index)`` entries from the root to
                this module (call index 0 for a first call); the last entry
                describes this frame.
            module_id: key into ``seen_modules`` identifying this module.
            module_call_graph: call name -> ``ModuleCallSignature`` for calls
                whose signature must be preserved.
            module: module to populate (the root frame passes one); when None
                an ``InterpreterModule`` is taken from ``created_modules`` or
                created.
        """
        self.flat_graph = flat_graph
        self.nodes = nodes
        self.seen_nodes = seen_nodes
        self.seen_modules = seen_modules
        self.seen_attrs = seen_attrs
        self.created_modules = created_modules
        self.parent = parent
        self.module_stack = module_stack
        self.module_id = module_id

        self.module_call_graph = module_call_graph
        self.verbose = False

        # The innermost stack entry describes the module this frame outlines.
        self.fqn, ty, num_calls = self.module_stack[-1]
        # generate call name for self.fqn
        self.child_fqn = _call_name(self.fqn, num_calls + 1)

        self.module: torch.fx.GraphModule | UnflattenedModule | InterpreterModule
        if module is not None:
            self.module = module
            self.ivals = module.ivals if hasattr(module, "ivals") else {}  # type: ignore[var-annotated]
        else:
            # Reuse a module already created for this FQN, if any.
            # NOTE(review): the default InterpreterModule argument is built
            # even when a cached module exists.
            self.module = self.created_modules.get(
                self.fqn,
                InterpreterModule(torch.fx.Graph(), ty=ty),
            )
            # Nested frames share the parent's intermediate-value tracker.
            self.ivals = parent.ivals

        self.graph = self.module.graph

        # Mapping of nodes in the flat graph to nodes in this graph.
        self.node_map: dict[torch.fx.Node, torch.fx.Node] = {}
        # Flat-graph nodes consumed as inputs by this frame -> the node that
        # stands in for them in this graph (a placeholder or a getitem on the
        # flattened arguments).
        self.node_to_placeholder = {}

        # The call_module node that invokes this module from the parent graph.
        self.parent_call_module: torch.fx.Node | None = None
        if parent is not None:
            accessor = _compute_accessor(parent.fqn, self.child_fqn)

            # Factory for intermediate container modules that _add_submodule
            # may need; containers are cached in created_modules by full path.
            def create_module(fqn):
                path = f"{parent.fqn}.{fqn}" if parent.fqn else fqn
                if path in self.created_modules:
                    return self.created_modules[path]
                submod = InterpreterModule(torch.fx.Graph(), ty=ty)
                self.created_modules[path] = submod
                return submod

            _add_submodule(parent.module, accessor, self.module, create_module)
            self.parent_call_module = parent.graph.call_module(accessor)
            if self.seen_modules[self.module_id]:
                # This module was called before: share the submodule dict of
                # the module recorded for its first call.
                base_module_frame = self.seen_modules[self.module_id][0]
                self.module._modules = base_module_frame.module._modules
            self.seen_modules[self.module_id].append(
                _SubmoduleEntry(
                    parent_fqn=self.parent.fqn,
                    parent_module=self.parent.module,
                    parent_call_module=self.parent_call_module,
                    fqn=self.fqn,
                    call_idx=num_calls + 1,
                    module=self.module,
                )
            )

        # Preserved call signature: give this graph placeholders mirroring the
        # original (args, kwargs), flatten them with in_spec, and make the
        # parent rebuild (args, kwargs) from the flat inputs for the call.
        signature = module_call_graph.get(self.child_fqn)
        if signature is not None and self.parent is not None:
            if signature.in_spec.num_children != 2:
                raise AssertionError(
                    f"expected in_spec to have 2 children, got {signature.in_spec.num_children}"
                )
            if signature.in_spec.type is not tuple:
                raise AssertionError(
                    f"expected in_spec.type to be tuple, got {signature.in_spec.type}"
                )
            args_spec, kwargs_spec = signature.in_spec.children()
            if args_spec.type is not tuple:
                raise AssertionError(
                    f"expected args_spec.type to be tuple, got {args_spec.type}"
                )
            if kwargs_spec.type is not dict:
                raise AssertionError(
                    f"expected kwargs_spec.type to be dict, got {kwargs_spec.type}"
                )

            with self.graph.inserting_after(None):
                # One placeholder per positional arg and per kwarg name.
                arg_nodes = [
                    self.graph.placeholder(f"_positional_arg_{idx}")
                    for idx in range(args_spec.num_children)
                ]
                kwarg_nodes = {}
                for name in kwargs_spec.context:
                    kwarg_nodes[name] = self.graph.placeholder(name)
                flat_args = _generate_flatten_spec(
                    self.module,
                    (tuple(arg_nodes), kwarg_nodes),
                    signature.in_spec,
                )
                # Expose each flat input via getitem; non-constant inputs that
                # were already seen become this frame's stand-in for them.
                for idx, arg in enumerate(signature.inputs):
                    flat_arg_node = self.graph.create_node(
                        op="call_function",
                        target=operator.getitem,
                        args=(flat_args, idx),
                        name=(
                            arg.name
                            if not isinstance(arg, ConstantArgument)
                            else f"_constant_{idx}"
                        ),
                    )
                    if isinstance(arg, ConstantArgument):
                        continue

                    if arg.name in self.seen_nodes:
                        flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta)
                        self.node_to_placeholder[self.seen_nodes[arg.name]] = (
                            flat_arg_node
                        )

            # In the parent, gather the flat inputs (constants inline, None for
            # values not seen yet), unflatten them with in_spec and pass the
            # resulting (args, kwargs) to the call_module node.
            with self.parent.graph.inserting_before(self.parent_call_module):
                input_nodes: list[torch.fx.Node | None] = []
                for input in signature.inputs:
                    if isinstance(input, ConstantArgument):
                        input_nodes.append(input.value)  # type: ignore[arg-type]
                    elif input.name not in self.seen_nodes:
                        input_nodes.append(None)
                    else:
                        if not isinstance(
                            input,
                            (
                                TensorArgument,
                                SymIntArgument,
                                SymBoolArgument,
                                SymFloatArgument,
                            ),
                        ):
                            raise AssertionError(
                                f"expected input to be TensorArgument, SymIntArgument, "
                                f"SymBoolArgument, or SymFloatArgument, got {type(input)}"
                            )
                        input_nodes.append(
                            self.parent.remap_input(self.seen_nodes[input.name])
                        )

                inputs_node = _generate_unflatten(
                    self.parent.module,
                    input_nodes,
                    signature.in_spec,
                )

                args_node = self.parent.graph.call_function(
                    operator.getitem, (inputs_node, 0)
                )
                kwargs_node = self.parent.graph.call_function(
                    operator.getitem, (inputs_node, 1)
                )
                arg_nodes = [
                    self.parent.graph.call_function(operator.getitem, (args_node, i))
                    for i in range(args_spec.num_children)
                ]
                kwarg_nodes = {
                    k: self.parent.graph.call_function(
                        operator.getitem, (kwargs_node, k)
                    )
                    for k in kwargs_spec.context
                }
            if self.parent_call_module is None:
                raise AssertionError("parent_call_module must not be None")
            # pyrefly: ignore [bad-assignment]
            self.parent_call_module.args = tuple(arg_nodes)
            self.parent_call_module.kwargs = kwarg_nodes  # type: ignore[assignment]

    def add_placeholder(self, x):
        """Create a placeholder in this frame's graph standing in for flat node ``x``.

        The placeholder copies ``x``'s name, type and meta and is recorded in
        ``node_to_placeholder``. Not allowed for the root frame.
        """
        if self.fqn == "":
            raise AssertionError(f"Cannot add placeholder {x} to root module")
        if x.graph is not self.flat_graph:
            raise AssertionError(
                "expected x.graph to be flat_graph, got different graph"
            )  # noqa: F541
        # x is not in subgraph, create a new placeholder for subgraph
        with self.graph.inserting_before(None):
            placeholder_node = self.graph.placeholder(x.name, type_expr=x.type)
        # copy all meta fields, even if some fields might be irrelevant for
        # the placeholder node
        placeholder_node.meta = copy.copy(x.meta)
        self.node_to_placeholder[x] = placeholder_node

    def copy_sym_call_function(self, x):
        """Copy flat call_function node ``x`` into this graph (inputs remapped)."""
        # This only exists because we deduplicate sym_size nodes in the flat export graph,
        # and if preserve_module_call_signature is set, we may not be able to pass sym_size
        # nodes, or their downstream users, as inputs to submodule calls.
        # To avoid this we copy these call_function nodes with sym_type results.
        # This should however only be done for sym_type nodes - call_function nodes on tensors
        # should not be deduplicated in the first place.
        args = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.args)
        kwargs = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.kwargs)
        node = self.graph.call_function(x.target, args, kwargs)
        node.meta = copy.copy(x.meta)
        self.node_map[x] = node
        return node

    def remap_input(self, x):
        """Return the node in this frame's graph that stands for flat node ``x``.

        Resolution order:

        1. an already-copied node;
        2. an existing stand-in;
        3. a new placeholder, when ``x`` is a placeholder or this call's
           signature is not preserved; the value is also threaded in as a new
           leading argument of the parent's call;
        4. a local copy of certain call_function nodes;
        5. an intermediate-value read via ``ivals`` when the signature is
           preserved.

        Raises:
            RuntimeError: if none of the above applies.
        """
        if x.graph is not self.flat_graph:
            raise AssertionError(
                "expected x.graph to be flat_graph, got different graph"
            )  # noqa: F541
        if x in self.node_map:
            return self.node_map[x]
        self.print(f"remap_input({x})")
        if x in self.node_to_placeholder:
            return self.node_to_placeholder[x]
        elif (
            x.op == "placeholder" or self.module_call_graph.get(self.fqn) is None
            # allow placeholder creation if we are not preserving module call signature
        ):
            self.add_placeholder(x)
            if self.parent_call_module is not None:
                # Important to *prepend* the output to match how we are
                # inserting placeholder nodes.
                with self.parent.graph.inserting_before(self.parent_call_module):
                    self.parent_call_module.insert_arg(0, self.parent.remap_input(x))
            return self.node_to_placeholder[x]
        elif x.op == "call_function" and (
            x.target
            in (
                torch.ops.aten.sym_size.int,
                torch.ops.aten.item.default,
                torch.ops.aten.unbind.int,
                torch.ops.aten.sum.dim_IntList,
                torch.ops.aten.view.default,
                torch.ops.aten.diff.default,
            )
            or (hasattr(x.target, "__module__") and x.target.__module__ == "_operator")
        ):
            # export deduplicates sym_size nodes, and may need to re-copy them
            # if module call signature needs to be preserved
            self.copy_sym_call_function(x)
            return self.node_map[x]
        elif self.module_call_graph.get(self.fqn) is not None:
            # x is reading the intermediate value of a mutation, so record it;
            # later we will find where it was created and perform the update
            return self.ivals.read(self, x)  # type: ignore[operator, union-attr]
        else:
            raise RuntimeError(
                f"Could not run remap_input() on op type: {x.op} for node {x}"
            )

    def uplift_common_custom_metadata(self) -> None:
        """Move shared ``meta["custom"]`` from the copied nodes to the call site.

        Applies only when every node in ``node_map`` has the same non-empty
        ``meta["custom"]``. The value is then set on the parent's call_module
        node and deleted from each copied node.
        """
        # Copy custom metadata if all nodes have same custom metadata
        custom_meta = None
        for node in self.node_map.values():
            curr_meta = node.meta.get("custom", {})
            if custom_meta is None:
                # first node
                custom_meta = curr_meta
                continue
            if curr_meta != custom_meta:
                custom_meta = {}
                break

        if custom_meta:
            # Lift common custom metadata to parent node and clear children node's custom metadata
            if self.parent_call_module is None:
                raise AssertionError(
                    "parent_call_module must not be None when uplifting custom metadata"
                )
            self.parent_call_module.meta["custom"] = custom_meta
            for node in self.node_map.values():
                del node.meta["custom"]

    def finalize_outputs(self):
        """Emit this frame's output node and expose its results to the parent.

        With a preserved call signature (and a parent), outputs follow
        ``signature.outputs``: they are packed with ``out_spec`` here and
        flattened again in the parent after the call. Otherwise, every copied
        node with a user that has not been copied yet becomes an output. The
        parent's ``node_map`` is then pointed at the call's results.
        """
        # This FQN's module is finished; stop handing it out for reuse.
        self.created_modules.pop(self.fqn, None)

        orig_outputs = []

        signature = self.module_call_graph.get(self.child_fqn)
        if signature is not None and self.parent is not None:
            for output in signature.outputs:
                if isinstance(
                    output,
                    (
                        TensorArgument,
                        SymIntArgument,
                        SymBoolArgument,
                        SymFloatArgument,
                        ConstantArgument,
                    ),
                ):
                    if output.name in self.seen_nodes:
                        orig_outputs.append(self.seen_nodes[output.name])
                    else:
                        orig_outputs.append(None)
                else:
                    raise RuntimeError(
                        f"Unsupported data type for output node: {output}"
                    )

            # Map a flat-graph output to its counterpart in this graph.
            def get_actual_output_node(output):
                if output is None:
                    return None

                seen_node = self.seen_nodes[output.name]
                if seen_node in self.node_map:
                    return self.node_map[seen_node]
                elif seen_node in self.node_to_placeholder:
                    return self.node_to_placeholder[seen_node]
                else:
                    raise RuntimeError(
                        f"Could not find output node {output}. Graph: {self.graph}"
                    )

            tree_out_node = _generate_unflatten(
                self.module,
                tuple(get_actual_output_node(output) for output in orig_outputs),
                signature.out_spec,
            )
            parent_out: torch.fx.Node | None = _generate_flatten_spec(
                self.parent.module, self.parent_call_module, signature.out_spec
            )
            graph_outputs: torch.fx.Node | list[torch.fx.Node] = tree_out_node
        else:
            graph_outputs = []
            # Iterate through nodes we have copied into self.graph.
            for orig_node in self.node_map:
                for user_node in orig_node.users:
                    if user_node.name not in self.seen_nodes:
                        # external user node, need to expose as an output
                        orig_outputs.append(orig_node)
                        graph_outputs.append(self.node_map[orig_node])
                        break

            parent_out = self.parent_call_module
            # A single output is returned bare rather than in a list.
            if len(graph_outputs) == 1:
                graph_outputs = graph_outputs[0]

        if not isinstance(graph_outputs, (list, torch.fx.Node)):
            raise AssertionError(
                f"expected graph_outputs to be list or torch.fx.Node, got {type(graph_outputs)}"
            )

        self.graph.output(graph_outputs)

        # Rewrite outputs in parent module
        if parent_out is None:
            return

        parent_out.meta["val"] = (
            graph_outputs.meta.get("val")
            if isinstance(graph_outputs, torch.fx.Node)
            else [o.meta.get("val") for o in graph_outputs]
        )
        self.uplift_common_custom_metadata()

        if len(orig_outputs) == 1 and signature is None:
            self.parent.node_map[orig_outputs[0]] = parent_out
        else:
            for i, orig_output in enumerate(orig_outputs):
                if orig_output is None:
                    continue
                # Use Proxy to record getitem access.
                proxy_out = torch.fx.Proxy(parent_out)[i].node  # type: ignore[index]
                proxy_out.meta["val"] = orig_output.meta.get("val")
                self.parent.node_map[orig_output] = proxy_out

    def copy_node(self, node):
        """Copy flat node ``node`` into this graph (inputs remapped) and mark it seen."""
        self.print("copying", node.format_node())
        self.node_map[node] = self.graph.node_copy(node, self.remap_input)
        self.seen_nodes[node.name] = node

    def run_outer(self):
        """Root-frame entry point.

        Copies the flat graph's leading placeholders, outlines the remaining
        nodes via ``run_from``, then copies the ``output`` node.
        """
        for i, node in enumerate(self.flat_graph.nodes):
            self.print(i, node.meta.get("nn_module_stack"), node.format_node())

        # Copy all graph inputs
        node_idx: int = 0
        node = self.nodes[node_idx]
        while node.op == "placeholder":
            self.copy_node(node)
            node_idx += 1
            node = self.nodes[node_idx]

        self.run_from(node_idx)

        # Copy graph outputs
        for node in self.flat_graph.nodes:
            if node.op == "output":
                self.copy_node(node)

    def print(self, *args, **kwargs):
        """Debug print, active only when ``self.verbose`` is set."""
        if self.verbose:
            # pyrefly: ignore [not-iterable]
            print(*args, **kwargs)

    def run_from(self, node_idx):
        """Process flat nodes starting at ``node_idx`` until this frame is done.

        Returns the index of the first node this frame did not consume, which
        is where the caller (the parent frame, or ``run_outer``) continues.
        """
        module_idx = 0
        # Walk through the graph, building up a new graph with the right submodules
        while node_idx < len(self.nodes):
            node = self.nodes[node_idx]
            if node.op == "placeholder":
                raise AssertionError(f"unexpected placeholder node at index {node_idx}")

            self.print()
            self.print("STEP", node_idx, node.format_node())
            self.print(self.module_stack)
            depth = len(self.module_stack)
            if node.op == "output":
                if depth == 1:
                    # We want the output node of the original graph to be handled
                    # specially by the outermost stack frame (in run_outer). So
                    # skip finalization here.
                    return node_idx

                # We've reached the end of the graph. Wrap up all the existing stack frames.
                self.finalize_outputs()
                return node_idx

            if len(node.meta.get("nn_module_stack", {})) == 0:
                raise RuntimeError(f"Unable to find nn_module_stack for node {node}")

            nn_module_stack = node.meta["nn_module_stack"]
            # Local import (presumably to avoid an import cycle).
            from torch._export.passes._node_metadata_hook import (
                _EMPTY_NN_MODULE_STACK_KEY,
            )

            if (
                len(nn_module_stack) == 1
                and _EMPTY_NN_MODULE_STACK_KEY in nn_module_stack
            ):
                # Empty case from the node_metadata_hook
                node_module_stack = self.module_stack
            else:
                # (path, type or None for the root, call index from "@<n>" in the key)
                node_module_stack = [
                    (
                        path,
                        ty if path else None,
                        int(k.split("@")[-1]) if "@" in k else 0,
                    )
                    for k, (path, ty) in node.meta["nn_module_stack"].items()
                ]

            if node_module_stack[:depth] != self.module_stack:
                # This means that the current module is done executing and the
                # current node is the beginning of a new module.
                #
                # In this case, we should finalize this module and return without
                # incrementing the node counter.
                self.finalize_outputs()
                self.print("outlining", self.fqn)
                self.print(self.graph)
                return node_idx

            if node_module_stack is None:
                raise AssertionError("node_module_stack must not be None")

            if _is_prefix(self.module_stack, node_module_stack):
                # This means that the current node represents the execution of a new
                # module.
                next_module = node_module_stack[depth]
                self.print("Creating new stack frame for", next_module)
                # Run a nested version of module outliner from the current node
                # counter. Once it is complete, continue from that point.
                next_module_key = list(node.meta["nn_module_stack"].keys())[depth]
                node_idx = _ModuleFrame(
                    self.flat_graph,
                    self.nodes,
                    self.seen_nodes,
                    self.seen_modules,
                    self.seen_attrs,
                    self.created_modules,
                    self,
                    self.module_stack + [next_module],
                    next_module_key.split("@")[0],
                    self.module_call_graph,
                ).run_from(node_idx)
                module_idx += 1
                continue

            # The only remaining possibility is that we are in the right stack
            # frame. Copy the node into this frame's graph and increment the node counter.
            if node_module_stack != self.module_stack:
                raise AssertionError(
                    f"expected node_module_stack {node_module_stack} to equal module_stack {self.module_stack}"
                )

            if node.op == "get_attr":
                # this must be a graph argument for a HOP
                self.seen_attrs[self.child_fqn].add(node.target)

            self.copy_node(node)
            # pyrefly: ignore [unsupported-operation]
            node_idx += 1
@dataclass
class _SubmoduleEntry:
    """Record of one call to a submodule found while outlining the flat graph."""

    # FQN of the module whose graph contains the call.
    parent_fqn: str
    # The module that owns the calling graph.
    parent_module: torch.nn.Module
    # The call_module node in the parent's graph that invokes the submodule.
    parent_call_module: torch.fx.Node
    # FQN of the called submodule, as recorded in its module stack.
    fqn: str
    # 1-based index of this call among calls to the same module.
    call_idx: int
    # The module object that holds the outlined graph for this call.
    module: torch.nn.Module
  1398. def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
  1399. seen_nodes: dict[str, torch.fx.Node] = {}
  1400. seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list)
  1401. seen_attrs: dict[str, set[str]] = defaultdict(set)
  1402. created_modules: dict[str, torch.nn.Module] = {}
  1403. _ModuleFrame(
  1404. orig_graph,
  1405. tuple(orig_graph.nodes),
  1406. seen_nodes,
  1407. seen_modules,
  1408. seen_attrs,
  1409. created_modules,
  1410. None,
  1411. [("", None, 0)],
  1412. "",
  1413. {
  1414. entry.fqn: entry.signature
  1415. for entry in root_module.module_call_graph
  1416. if entry.signature
  1417. },
  1418. module=root_module,
  1419. ).run_outer()
  1420. return seen_modules, seen_attrs
  1421. def _reorder_submodules(
  1422. parent: torch.nn.Module, fqn_order: dict[str, int], prefix: str = ""
  1423. ):
  1424. # TODO Can be optimized by adding submodules ahead of time.
  1425. if prefix == "":
  1426. for fqn in list(fqn_order.keys())[1:]:
  1427. if _get_submodule(parent, fqn) is None:
  1428. _add_submodule(parent, fqn, torch.nn.Module())
  1429. children = []
  1430. for name, child in list(parent._modules.items()):
  1431. if child is None:
  1432. continue
  1433. fqn = prefix + name
  1434. _reorder_submodules(child, fqn_order, prefix=fqn.split("@")[0] + ".")
  1435. delattr(parent, name)
  1436. children.append((fqn_order[fqn], name, child))
  1437. children.sort(key=operator.itemgetter(0))
  1438. for _, name, child in children:
  1439. parent.register_module(name, child)
  1440. class _IVals:
  1441. """
  1442. Collect the intermediate values of mutations in a graph.
  1443. Example: in the following graph, suppose that buf_in and buf_out
  1444. are the input and output values of a buffer.
  1445. buf_in = placeholder()
  1446. ...
  1447. ival1 = f0(buf_in, ...) # inside self.n0(...)
  1448. ...
  1449. ival2 = f1(ival1, ...) # inside self.n1(...)
  1450. ...
  1451. buf_out = f2(ival2, ...) # inside self.n2(...)
  1452. return buf_out, ...
  1453. Here ival1 and ival2 are intermediate values created inside
  1454. calls to n0 and n1 respectively, and used inside calls to
  1455. n1 and n2 respectively.
  1456. """
  1457. def __init__(self):
  1458. # for each fqn, set of node names corresponding to intermediate values
  1459. self.node_names_by_fqn = defaultdict(set)
  1460. def _is_mutable(self, target):
  1461. if isinstance(target, torch._ops.OpOverload):
  1462. return target._schema.is_mutable
  1463. return False
  1464. def read(self, mf, node):
  1465. """
  1466. Read state corresponding to a given intermediate value.
  1467. """
  1468. # we can assume that the node must be from a mutation
  1469. if node.op != "call_function":
  1470. raise AssertionError(
  1471. f"expected node.op to be 'call_function', got {node.op!r}"
  1472. )
  1473. b = self._is_mutable(node.target)
  1474. print("Checking mutability", node.target, b)
  1475. if not b:
  1476. # so the mutation was functionalized;
  1477. # we will apply the original mutation later (see below)
  1478. fqn, _ = next(reversed(node.meta["nn_module_stack"].values()))
  1479. self.node_names_by_fqn[fqn].add(node.name)
  1480. return mf.remap_input(node.args[0])
  1481. def update(self, partitions):
  1482. """
  1483. Update states corresponding to intermediate values that were read.
  1484. """
  1485. for shared_submodules in partitions:
  1486. for entry in shared_submodules:
  1487. graph = entry.module.graph
  1488. node_names = self.node_names_by_fqn[entry.fqn]
  1489. nodes = [n for n in graph.nodes if n.name in node_names]
  1490. for node in nodes:
  1491. # so node must be from a functionalized mutation;
  1492. # we perform the original mutation now
  1493. with graph.inserting_after(node):
  1494. new_node = graph.create_node(
  1495. "call_function",
  1496. torch.ops.aten.copy_.default,
  1497. (node.args[0], node),
  1498. )
  1499. new_node.meta = copy.copy(node.meta)
  1500. def _copy_graph_attrs(
  1501. gm: torch.fx.GraphModule,
  1502. root_module: UnflattenedModule,
  1503. seen_attrs: dict[str, set[str]],
  1504. ):
  1505. for child_fqn, names in seen_attrs.items():
  1506. module = _get_attr(root_module, child_fqn) if child_fqn else root_module
  1507. for name in names:
  1508. val = getattr(gm, name)
  1509. setattr(module, name, val)
  1510. def _deduplicate_modules(partitions):
  1511. redirected_call_indices = {}
  1512. for shared_submodules in partitions:
  1513. for i, entry in enumerate(shared_submodules):
  1514. child_fqn = _call_name(entry.fqn, entry.call_idx)
  1515. target = _compute_accessor(entry.parent_fqn, child_fqn)
  1516. deduplicated = False
  1517. # Iterate over all previously seen modules, and deduplicate if possible
  1518. for seen in shared_submodules[:i]:
  1519. if _check_graph_equivalence(seen.module, entry.module):
  1520. parent = entry.parent_module
  1521. # Since graphs are equivalent, we can deduplicate.
  1522. # There are two cases.
  1523. if seen.fqn == entry.fqn:
  1524. # Case 1: The current module has the same fqn as the seen module.
  1525. # In this case we have generated a call name that can be optimized away.
  1526. # So we remove the current module from the hierarchy and replace
  1527. # the current call name with the seen call name in the parent graph.
  1528. *prefix, name = target.split(".")
  1529. _get_attr_via_attr_list(parent, prefix)._modules.pop(name)
  1530. seen_child_fqn = _call_name(seen.fqn, seen.call_idx)
  1531. seen_target = _compute_accessor(
  1532. entry.parent_fqn, seen_child_fqn
  1533. )
  1534. entry.parent_call_module.target = seen_target
  1535. redirected_call_indices[child_fqn] = seen_child_fqn
  1536. break
  1537. elif not deduplicated:
  1538. # Case 2: The current module has a different fqn than the seen module.
  1539. # In this case we replace the current module with the seen module.
  1540. # There should be nothing pointing to the current module any more,
  1541. # so it can be garbage collected.
  1542. # NOTE: We *do not* replace the current call name with the seen call name
  1543. # in the parent graph, because this will lose information on which fqn
  1544. # was actually called. However, it is possible that the current call name
  1545. # will be optimized away when we find another seen module with the same fqn,
  1546. # so we do not break out of the loop yet.
  1547. parent.set_submodule(target, seen.module)
  1548. deduplicated = True
  1549. return redirected_call_indices
  1550. def _sink_params(
  1551. module: torch.nn.Module,
  1552. inputs_to_state: dict[str, list[str]],
  1553. scope: list[str],
  1554. module_id_to_inputs_removed: dict[int, set[str]] | None = None,
  1555. ):
  1556. """Sink params, buffers, and constants from graph inputs into get_attr nodes.
  1557. Exported modules are purely functional, so they pass their parameters and
  1558. buffers in as inputs to the graph.
  1559. To replicate eager's semantics, we need to get them from the module state
  1560. via get_attr instead.
  1561. module: GraphModule, potentially containing nested submodules.
  1562. inputs_to_state: mapping graph input names to the corresponding key in the state_dict.
  1563. scope: tracks where we are in the module hierarchy, so that we can emit the
  1564. right `getattr(self, "foo.bar")` calls, etc.
  1565. module_id_to_inputs_removed: records inputs removed by child modules, mapping
  1566. the module object id to the list of placeholder node names in the child module
  1567. that were removed.
  1568. """
  1569. if module_id_to_inputs_removed is None:
  1570. module_id_to_inputs_removed = defaultdict(set)
  1571. if id(module) in module_id_to_inputs_removed:
  1572. return {id(module): module_id_to_inputs_removed[id(module)]}
  1573. # We need to use _modules here instead of named_children(), because we
  1574. # explicitly want duplicate modules to show up in the traversal.
  1575. for name, submodule in module._modules.items():
  1576. submod_id_to_inputs_removed = _sink_params(
  1577. cast("torch.nn.Module", submodule),
  1578. inputs_to_state,
  1579. scope + [name],
  1580. module_id_to_inputs_removed,
  1581. )
  1582. for k, v in submod_id_to_inputs_removed.items():
  1583. module_id_to_inputs_removed[k].update(v)
  1584. graph = getattr(module, "graph", None)
  1585. if graph is None or len(graph.nodes) == 0:
  1586. # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList)
  1587. return module_id_to_inputs_removed
  1588. if not isinstance(graph, torch.fx.Graph):
  1589. raise AssertionError(f"expected graph to be torch.fx.Graph, got {type(graph)}")
  1590. inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes))
  1591. the_last_input = None if len(inputs) == 0 else inputs[-1]
  1592. # Also remove from call_module nodes
  1593. call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes)
  1594. for node in call_module_nodes:
  1595. submodule = _get_attr(module, node.target)
  1596. # remove placeholder from call_module node arguments, only if we've
  1597. # erased the placeholder node in the corresponding _sink_params() call
  1598. if submodule is not None and id(submodule) in module_id_to_inputs_removed:
  1599. node.args = tuple(
  1600. filter(
  1601. lambda n: n.name not in module_id_to_inputs_removed[id(submodule)],
  1602. node.args,
  1603. )
  1604. )
  1605. # Filter out inputs_to_state corresponding to current scope.
  1606. inputs_to_state_of_scope: dict[torch.fx.Node, list[str]] = {}
  1607. for node in inputs:
  1608. if node.name not in inputs_to_state:
  1609. continue
  1610. state_name = None
  1611. for sn in inputs_to_state[node.name]:
  1612. sn_split = sn.split(".")
  1613. if sn_split[: len(scope)] == [x.split("@")[0] for x in scope]:
  1614. state_name = sn_split
  1615. break
  1616. # If there's a mismatch between scope name and state name, then
  1617. # there must be multiple scopes pointing to the same state name,
  1618. # meaning some modules are shared. In such case, we can simply skip
  1619. # updating the current node because another later iteration will
  1620. # take care of this input node when the unique match between scope
  1621. # and state name occurs. To make sure this always happen, we should
  1622. # enforce the invariant that no placeholder node in the unflattened
  1623. # graph appears in inputs_to_state dict, which means all the extra
  1624. # input nodes have been handled.
  1625. if state_name is None:
  1626. continue
  1627. inputs_to_state_of_scope[node] = state_name
  1628. # Record name of remove inputs for return purpose.
  1629. inputs_removed: set[str] = set()
  1630. for node, state_name in inputs_to_state_of_scope.items():
  1631. if len(node.users) > 0:
  1632. attr_path = state_name[len(scope) :]
  1633. state_attr = _get_attr_via_attr_list(module, attr_path)
  1634. if not isinstance(state_attr, (torch.Tensor, torch.ScriptObject)):
  1635. raise AssertionError(
  1636. f"expected state_attr to be torch.Tensor or torch.ScriptObject, got {type(state_attr)}"
  1637. )
  1638. # Make sure the newly created get_attr node is placed after the last placeholder node
  1639. with graph.inserting_after(the_last_input):
  1640. new_node = graph.create_node("get_attr", ".".join(attr_path))
  1641. node.replace_all_uses_with(new_node, propagate_meta=True)
  1642. graph.erase_node(node)
  1643. inputs_removed.add(node.name)
  1644. if isinstance(module, InterpreterModule):
  1645. module.finalize()
  1646. return {id(module): inputs_removed}