interpreter.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. import logging
  4. from contextlib import contextmanager
  5. from typing import Any, Optional, TYPE_CHECKING, Union
  6. import torch
  7. import torch.fx.traceback as fx_traceback
  8. from torch._logging import LazyString, trace_structured
  9. from torch.hub import tqdm
  10. from . import config
  11. from ._compatibility import compatibility
  12. from ._lazy_graph_module import _make_graph_module
  13. from ._symbolic_trace import Tracer
  14. from .graph import Graph
  15. from .graph_module import GraphModule
  16. from .node import Argument, map_aggregate, map_arg, Node, Target
  17. from .proxy import Proxy
  18. if TYPE_CHECKING:
  19. from collections.abc import Iterator
  # Module-level logger; ``Interpreter.run_node`` writes its per-node debug
  # trace ("run_node %s") through this logger.
  log = logging.getLogger(name=__name__)
  # Names exported by ``from torch.fx.interpreter import *``.
  __all__ = [
      "Interpreter",
      "Transformer",
  ]
  22. def _format_fx_node(n):
  23. """
  24. Format a torch.fx.Node into a human-readable string for debug logging.
  25. Args:
  26. n (torch.fx.Node): The FX node being executed.
  27. Returns:
  28. str: A formatted string describing the node operation, including its
  29. name, target, positional arguments, and keyword arguments.
  30. """
  31. module_prefix = getattr(n.target, "__module__", "")
  32. module_prefix = f"{module_prefix}." if module_prefix else ""
  33. # Handle positional and keyword arguments
  34. args = ", ".join(map(str, n.args))
  35. kwargs = ", ".join(f"{k}={v}" for k, v in n.kwargs.items())
  36. joined = ", ".join(filter(None, [args, kwargs]))
  37. return (
  38. f"{n.name} = {module_prefix}{getattr(n.target, '__name__', n.target)}({joined})"
  39. )
  @compatibility(is_backward_compatible=True)
  class Interpreter:
      """
      An Interpreter executes an FX graph Node-by-Node. This pattern
      can be useful for many things, including writing code
      transformations as well as analysis passes.

      Methods in the Interpreter class can be overridden to customize
      the behavior of execution. The map of overridable methods
      in terms of call hierarchy::

          run()
              +-- run_node
                  +-- placeholder()
                  +-- get_attr()
                  +-- call_function()
                  +-- call_method()
                  +-- call_module()
                  +-- output()

      Example:

          Suppose we want to swap all instances of ``torch.neg`` with
          ``torch.sigmoid`` and vice versa (including their ``Tensor``
          method equivalents). We could subclass Interpreter like so::

              class NegSigmSwapInterpreter(Interpreter):
                  def call_function(
                      self, target: Target, args: Tuple, kwargs: Dict
                  ) -> Any:
                      if target is torch.sigmoid:
                          return torch.neg(*args, **kwargs)
                      return super().call_function(target, args, kwargs)

                  def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
                      if target == "neg":
                          call_self, *args_tail = args
                          return call_self.sigmoid(*args_tail, **kwargs)
                      return super().call_method(target, args, kwargs)


              def fn(x):
                  return torch.sigmoid(x).neg()


              gm = torch.fx.symbolic_trace(fn)
              input = torch.randn(3, 4)
              result = NegSigmSwapInterpreter(gm).run(input)
              torch.testing.assert_close(result, torch.neg(input).sigmoid())

      Args:
          module (torch.nn.Module): The module to be executed
          garbage_collect_values (bool): Whether to delete values after their last
              use within the Module's execution. This ensures optimal memory usage during
              execution. This can be disabled to, for example, examine all of the intermediate
              values in the execution by looking at the ``Interpreter.env`` attribute.
          graph (Optional[Graph]): If passed, the interpreter will execute this
              graph instead of `module.graph`, using the provided `module`
              argument to satisfy any requests for state.
      """

      @compatibility(is_backward_compatible=True)
      def __init__(
          self,
          module: torch.nn.Module,
          garbage_collect_values: bool = True,
          graph: Optional[Graph] = None,
      ):
          self.module = module
          # Fully-qualified name -> submodule, for every module under ``module``.
          self.submodules = dict(self.module.named_modules())
          if graph is not None:
              self.graph = graph
          else:
              self.graph = self.module.graph  # type: ignore[assignment]
          # Node -> computed value; (re)initialised on every ``run``.
          self.env: dict[Node, Any] = {}
          # Label shown on the progress bar in ``run``.
          self.name = "Interpreter"
          self.garbage_collect_values = garbage_collect_values
          # When True, ``run`` annotates node exceptions with node/stack-trace info.
          self.extra_traceback = True

          if self.garbage_collect_values:
              # Run through reverse nodes and record the first instance of a use
              # of a given node. This represents the *last* use of the node in the
              # execution order of the program, which we will use to free unused
              # values
              node_to_last_use: dict[Node, Node] = {}
              # user node -> input nodes whose last use is that user node
              self.user_to_last_uses: dict[Node, list[Node]] = {}

              def register_last_uses(n: Node, user: Node):
                  if n not in node_to_last_use:
                      node_to_last_use[n] = user
                      self.user_to_last_uses.setdefault(user, []).append(n)

              for node in reversed(self.graph.nodes):
                  for n in node._input_nodes:
                      register_last_uses(n, node)

      @compatibility(is_backward_compatible=True)
      def run(
          self,
          *args,
          initial_env: Optional[dict[Node, Any]] = None,
          enable_io_processing: bool = True,
      ) -> Any:
          """
          Run `module` via interpretation and return the result.

          Args:
              *args: The arguments to the Module to run, in positional order
              initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
                  This is a dict mapping `Node` to any value. This can be used, for example, to
                  pre-populate results for certain `Nodes` so as to do only partial evaluation within
                  the interpreter. The dict is used as-is (not copied), so entries may be
                  deleted from it when ``garbage_collect_values`` is enabled.
              enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
                  process_outputs function first before using them.

          Returns:
              Any: The value returned from executing the Module, or ``None`` if the
              graph has no ``output`` node.

          Raises:
              Exception: Any exception raised while executing a node propagates. When
                  ``self.extra_traceback`` is true its message is prefixed to a
                  description of the failing node and its original stack trace, and a
                  ``KeyError`` is re-raised as ``RuntimeError``.
          """
          self.env = initial_env if initial_env is not None else {}

          # Positional function args are consumed left-to-right by
          # `placeholder` nodes. Use an iterator to keep track of
          # position and extract those values.
          if enable_io_processing:
              args = self.graph.process_inputs(*args)
          self.args_iter: Iterator[Any] = iter(args)
          pbar = tqdm(
              total=len(self.graph.nodes),
              desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}",
              initial=0,
              position=0,
              leave=True,
              disable=config.disable_progress,
              delay=0,
          )
          # Always release the progress bar: on the ``output`` return, when a node
          # raises, or when the graph has no ``output`` node at all. Previously the
          # bar was left open for the garbage collector to finalise.
          try:
              for node in self.graph.nodes:
                  pbar.update(1)
                  if node in self.env:
                      # Short circuit if we have this value. This could
                      # be used, for example, for partial evaluation
                      # where the caller has pre-populated `env` with
                      # values for a subset of the program.
                      continue

                  try:
                      self.env[node] = self.run_node(node)
                  except Exception as e:
                      if self.extra_traceback:
                          msg = f"While executing {node.format_node()}"
                          if e.args:
                              msg = f"{e.args[0]}\n\n{msg}"
                          msg += f"\nOriginal traceback:\n{node.stack_trace}"
                          if (
                              isinstance(self.module, GraphModule)
                              and self.module.graph is not None
                              and isinstance(self.module.graph, torch.fx.Graph)
                          ):
                              # Bind the message explicitly: ``msg`` is extended
                              # below, and the lambda must not depend on when
                              # ``trace_structured`` chooses to evaluate it.
                              logged_msg = msg
                              trace_structured(
                                  "artifact",
                                  metadata_fn=lambda: {
                                      "name": "fx_interpreter_error",
                                      "encoding": "string",
                                  },
                                  payload_fn=lambda: (
                                      f"{logged_msg}\nGraphModule: "
                                      f"{self.module.print_readable(print_output=False, include_stride=True)}"  # type: ignore[operator]
                                  ),
                              )

                              msg += "\nUse tlparse to see full graph. "
                              msg += "(https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)"
                          e.args = (msg,) + e.args[1:]
                          if isinstance(e, KeyError):
                              raise RuntimeError(*e.args) from e
                      raise

                  if self.garbage_collect_values:
                      # Free every value whose last consumer was ``node``.
                      for to_delete in self.user_to_last_uses.get(node, []):
                          del self.env[to_delete]

                  if node.op == "output":
                      output_val = self.env[node]
                      return (
                          self.graph.process_outputs(output_val)
                          if enable_io_processing
                          else output_val
                      )
          finally:
              pbar.close()

      @compatibility(is_backward_compatible=True)
      def boxed_run(self, args_list):
          """
          Run `module` via interpretation and return the result. This uses the "boxed"
          calling convention, where you pass a list of arguments, which will be cleared
          by the interpreter. This ensures that input tensors are promptly deallocated.

          Args:
              args_list (list): One value per ``placeholder`` node, in graph order.
                  Cleared in place once the values have been bound to placeholders.

          Returns:
              Any: The value returned from ``run``.

          Raises:
              RuntimeError: If ``len(args_list)`` differs from the number of
                  ``placeholder`` nodes (``args_list`` is left untouched in that case).
          """
          # Collect placeholder nodes first
          placeholder_nodes = [n for n in self.graph.nodes if n.op == "placeholder"]

          # Check argument count
          if len(args_list) != len(placeholder_nodes):
              detail = (
                  "extra arguments"
                  if len(args_list) > len(placeholder_nodes)
                  else "missing arguments"
              )
              raise RuntimeError(
                  f"Interpreter.boxed_run expected {len(placeholder_nodes)} arguments for placeholders "
                  f"but received {len(args_list)} ({detail})"
              )

          # Assign arguments to placeholders
          env = dict(zip(placeholder_nodes, args_list))

          # Drop the caller's references so inputs can be freed as soon as the
          # interpreter is done with them.
          args_list.clear()
          return self.run(initial_env=env)

      @contextmanager
      def _set_current_node(self, node):
          """
          Context manager that sets ``node`` as the current node in
          ``fx_traceback`` metadata for the duration of the ``with`` body, tagged
          ``Interpreter_<class name>``.
          """
          with fx_traceback.set_current_meta(
              node, f"Interpreter_{self.__class__.__name__}"
          ):
              yield

      @compatibility(is_backward_compatible=True)
      def run_node(self, n: Node) -> Any:
          """
          Run a specific node ``n`` and return the result.
          Calls into placeholder, get_attr, call_function,
          call_method, call_module, or output depending
          on ``node.op``

          Args:
              n (Node): The Node to execute

          Returns:
              Any: The result of executing ``n``
          """
          log.debug("run_node %s", LazyString(lambda: _format_fx_node(n)))
          with self._set_current_node(n):
              args, kwargs = self.fetch_args_kwargs_from_env(n)
              if not isinstance(args, tuple):
                  raise AssertionError(f"Expected args to be tuple, got {type(args)}")
              if not isinstance(kwargs, dict):
                  raise AssertionError(f"Expected kwargs to be dict, got {type(kwargs)}")
              # Dispatch on the opcode, e.g. "call_function" -> self.call_function.
              return getattr(self, n.op)(n.target, args, kwargs)

      # Main Node running APIs
      @compatibility(is_backward_compatible=True)
      def placeholder(
          self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
      ) -> Any:
          """
          Execute a ``placeholder`` node. Note that this is stateful:
          ``Interpreter`` maintains an internal iterator over
          arguments passed to ``run`` and this method returns
          next() on that iterator.

          Args:
              target (Target): The call target for this node. See
                  `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                  details on semantics
              args (Tuple): Tuple of positional args for this invocation. When the
                  argument iterator is exhausted, ``args[0]`` (if present) is used
                  as the parameter's default value.
              kwargs (Dict): Dict of keyword arguments for this invocation

          Returns:
              Any: The argument value that was retrieved.

          Raises:
              RuntimeError: If no value was passed and the placeholder has no default.
          """
          if not isinstance(target, str):
              raise AssertionError(f"Expected target to be str, got {type(target)}")
          if target.startswith("*"):
              # For a starred parameter e.g. `*args`, retrieve all
              # remaining values from the args list.
              return list(self.args_iter)
          else:
              try:
                  return next(self.args_iter)
              except StopIteration as si:
                  if len(args) > 0:
                      return args[0]
                  else:
                      raise RuntimeError(
                          f"Expected positional argument for parameter {target}, but one was not passed in!"
                      ) from si

      @compatibility(is_backward_compatible=True)
      def get_attr(
          self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
      ) -> Any:
          """
          Execute a ``get_attr`` node. Will retrieve an attribute
          value from the ``Module`` hierarchy of ``self.module``.

          Args:
              target (Target): The call target for this node. See
                  `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                  details on semantics
              args (Tuple): Tuple of positional args for this invocation
              kwargs (Dict): Dict of keyword arguments for this invocation

          Returns:
              Any: The value of the attribute that was retrieved
          """
          if not isinstance(target, str):
              raise AssertionError(f"Expected target to be str, got {type(target)}")
          return self.fetch_attr(target)

      @compatibility(is_backward_compatible=True)
      def call_function(
          self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
      ) -> Any:
          """
          Execute a ``call_function`` node and return the result.

          Args:
              target (Target): The call target for this node. See
                  `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                  details on semantics
              args (Tuple): Tuple of positional args for this invocation
              kwargs (Dict): Dict of keyword arguments for this invocation

          Returns:
              Any: The value returned by the function invocation
          """
          if isinstance(target, str):
              raise AssertionError("target should not be a string for call_function")
          # Execute the function and return the result
          return target(*args, **kwargs)

      @compatibility(is_backward_compatible=True)
      def call_method(
          self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
      ) -> Any:
          """
          Execute a ``call_method`` node and return the result.

          Args:
              target (Target): The call target for this node. See
                  `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                  details on semantics
              args (Tuple): Tuple of positional args for this invocation; ``args[0]``
                  is the object the method is looked up on.
              kwargs (Dict): Dict of keyword arguments for this invocation

          Returns:
              Any: The value returned by the method invocation
          """
          # args[0] is the `self` object for this method call
          self_obj, *args_tail = args

          # Execute the method and return the result
          if not isinstance(target, str):
              raise AssertionError(f"Expected target to be str, got {type(target)}")
          return getattr(self_obj, target)(*args_tail, **kwargs)

      @compatibility(is_backward_compatible=True)
      def call_module(
          self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
      ) -> Any:
          """
          Execute a ``call_module`` node and return the result.

          Args:
              target (Target): The call target for this node. See
                  `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                  details on semantics
              args (Tuple): Tuple of positional args for this invocation
              kwargs (Dict): Dict of keyword arguments for this invocation

          Returns:
              Any: The value returned by the module invocation
          """
          # Resolve the submodule by its qualified name, then call it with the
          # already-materialised args/kwargs.
          if not isinstance(target, str):
              raise AssertionError(f"Expected target to be str, got {type(target)}")
          submod = self.fetch_attr(target)
          return submod(*args, **kwargs)

      @compatibility(is_backward_compatible=True)
      def output(
          self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
      ) -> Any:
          """
          Execute an ``output`` node. This really just retrieves
          the value referenced by the ``output`` node and returns it.

          Args:
              target (Target): The call target for this node. See
                  `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                  details on semantics
              args (Tuple): Tuple of positional args for this invocation
              kwargs (Dict): Dict of keyword arguments for this invocation

          Returns:
              Any: The return value referenced by the output node
          """
          return args[0]

      # Helper methods
      @compatibility(is_backward_compatible=True)
      def fetch_attr(self, target: str):
          """
          Fetch an attribute from the ``Module`` hierarchy of ``self.module``.

          Args:
              target (str): The fully-qualified name of the attribute to fetch

          Returns:
              Any: The value of the attribute.

          Raises:
              RuntimeError: If any dotted component of ``target`` does not exist.
          """
          target_atoms = target.split(".")
          attr_itr = self.module
          # Sentinel lets each attribute be looked up exactly once; like
          # ``hasattr``, ``getattr(..., default)`` only swallows AttributeError.
          missing = object()
          for i, atom in enumerate(target_atoms):
              next_attr = getattr(attr_itr, atom, missing)
              if next_attr is missing:
                  raise RuntimeError(
                      f"Node referenced nonexistent target {'.'.join(target_atoms[: i + 1])}"
                  )
              attr_itr = next_attr
          return attr_itr

      @compatibility(is_backward_compatible=True)
      def fetch_args_kwargs_from_env(self, n: Node) -> tuple[tuple, dict]:
          """
          Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
          from the current execution environment.

          Args:
              n (Node): The node for which ``args`` and ``kwargs`` should be fetched.

          Returns:
              Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
          """
          args = self.map_nodes_to_values(n.args, n)
          if not isinstance(args, tuple):
              raise AssertionError(f"Expected args to be tuple, got {type(args)}")
          kwargs = self.map_nodes_to_values(n.kwargs, n)
          if not isinstance(kwargs, dict):
              raise AssertionError(f"Expected kwargs to be dict, got {type(kwargs)}")
          return args, kwargs

      @compatibility(is_backward_compatible=True)
      def map_nodes_to_values(self, args: Argument, n: Node) -> Argument:
          """
          Recursively descend through ``args`` and look up the concrete value
          for each ``Node`` in the current execution environment.

          Args:
              args (Argument): Data structure within which to look up concrete values
              n (Node): Node to which ``args`` belongs. This is only used for error reporting.

          Raises:
              RuntimeError: If a referenced ``Node`` has no value in ``self.env``.
          """

          def load_arg(n_arg: Node) -> Any:
              if n_arg not in self.env:
                  raise RuntimeError(
                      f"Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() "
                      f"to diagnose such issues"
                  )
              return self.env[n_arg]

          return map_arg(args, load_arg)
  @compatibility(is_backward_compatible=True)
  class Transformer(Interpreter):
      """
      ``Transformer`` is a special type of interpreter that produces a
      new ``Module``. It exposes a ``transform()`` method that returns
      the transformed ``Module``. ``Transformer`` does not require
      arguments to run, as ``Interpreter`` does. ``Transformer`` works
      entirely symbolically.

      Example:

          Suppose we want to swap all instances of ``torch.neg`` with
          ``torch.sigmoid`` and vice versa (including their ``Tensor``
          method equivalents). We could subclass ``Transformer`` like so::

              class NegSigmSwapXformer(Transformer):
                  def call_function(
                      self,
                      target: "Target",
                      args: Tuple[Argument, ...],
                      kwargs: Dict[str, Any],
                  ) -> Any:
                      if target is torch.sigmoid:
                          return torch.neg(*args, **kwargs)
                      return super().call_function(target, args, kwargs)

                  def call_method(
                      self,
                      target: "Target",
                      args: Tuple[Argument, ...],
                      kwargs: Dict[str, Any],
                  ) -> Any:
                      if target == "neg":
                          call_self, *args_tail = args
                          return call_self.sigmoid(*args_tail, **kwargs)
                      return super().call_method(target, args, kwargs)


              def fn(x):
                  return torch.sigmoid(x).neg()


              gm = torch.fx.symbolic_trace(fn)

              transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform()
              input = torch.randn(3, 4)
              torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())

      Args:
          module (GraphModule): The ``Module`` to be transformed.
      """

      @compatibility(is_backward_compatible=True)
      def __init__(self, module):
          """
          Set up an empty output graph (sharing ``module.graph``'s codegen) and a
          tracer that records every operation into it.

          Args:
              module (GraphModule): The ``Module`` to be transformed.
          """
          super().__init__(module)
          self.new_graph = Graph()
          self.new_graph.set_codegen(module.graph._codegen)

          class TransformerTracer(Tracer):
              """Tracer that writes into ``graph`` and treats every module as a leaf."""

              def __init__(self, graph: Graph):
                  super().__init__()
                  self.graph = graph
                  self.tensor_attrs: dict[torch.Tensor, str] = {}  # type: ignore[assignment]

              def is_leaf_module(self, _, __) -> bool:
                  # Never trace into submodules; each one becomes a single
                  # ``call_module`` node in the output graph.
                  return True

          self.tracer = TransformerTracer(self.new_graph)
          self.tracer.root = module

      @compatibility(is_backward_compatible=True)
      def placeholder(
          self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
      ) -> Proxy:
          """
          Execute a ``placeholder`` node. In ``Transformer``, this is
          overridden to insert a new ``placeholder`` into the output
          graph.

          Args:
              target (Target): The call target for this node. See
                  `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                  details on semantics
              args (Tuple): Tuple of positional args for this invocation; ``args[0]``,
                  if present, becomes the new placeholder's default value.
              kwargs (Dict): Dict of keyword arguments for this invocation

          Returns:
              Proxy: A proxy wrapping the newly created placeholder node.
          """
          if not isinstance(target, str):
              raise AssertionError(f"Expected target to be str, got {type(target)}")
          default_value = args[0] if args else inspect.Signature.empty
          return Proxy(
              self.new_graph.placeholder(target, default_value=default_value), self.tracer
          )

      @compatibility(is_backward_compatible=True)
      def get_attr(
          self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
      ) -> Proxy:
          """
          Execute a ``get_attr`` node. In ``Transformer``, this is
          overridden to insert a new ``get_attr`` node into the output
          graph.

          Args:
              target (Target): The call target for this node. See
                  `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                  details on semantics
              args (Tuple): Tuple of positional args for this invocation
              kwargs (Dict): Dict of keyword arguments for this invocation

          Returns:
              Proxy: A proxy for the new ``get_attr`` node.
          """
          if not isinstance(target, str):
              raise AssertionError(f"Expected target to be str, got {type(target)}")
          return self.tracer.create_proxy("get_attr", target, args, kwargs)

      @compatibility(is_backward_compatible=True)
      def call_module(
          self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
      ) -> Any:
          """
          Execute a ``call_module`` node symbolically by routing it through
          ``self.tracer.call_module``.

          Args:
              target (Target): Qualified name of the submodule to call.
              args (Tuple): Tuple of positional args for this invocation
              kwargs (Dict): Dict of keyword arguments for this invocation

          Returns:
              Any: Whatever ``self.tracer.call_module`` returns for the submodule.
          """
          # Override so that the leaf module policy from `self.tracer` is respected.
          if not isinstance(target, str):
              raise AssertionError(f"Expected target to be str, got {type(target)}")
          submod = self.fetch_attr(target)
          return self.tracer.call_module(submod, submod.forward, args, kwargs)

      @compatibility(is_backward_compatible=True)
      def call_function(
          self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
      ) -> Any:
          """
          Execute a ``call_function`` node symbolically by emitting a
          ``call_function`` node into the output graph.

          Args:
              target (Target): The callable for this node.
              args (Tuple): Tuple of positional args for this invocation
              kwargs (Dict): Dict of keyword arguments for this invocation

          Returns:
              Any: A proxy for the new ``call_function`` node.
          """
          # Override so that functions that were wrapped are still wrapped.
          return self.tracer.create_proxy("call_function", target, args, kwargs)

      @compatibility(is_backward_compatible=True)
      def transform(self) -> GraphModule:
          """
          Transform ``self.module`` and return the transformed
          ``GraphModule``.

          Node metadata is preserved while re-tracing. If the interpreted graph
          produced a result, a new ``output`` node is emitted and the old output
          node's ``meta`` entries are copied onto it.

          Returns:
              GraphModule: A module built from ``self.module`` and ``self.new_graph``.

          Raises:
              AssertionError: If the last node of the source graph is not ``output``.
          """
          with fx_traceback.preserve_node_meta():
              result = super().run(enable_io_processing=False)
          if result is not None:

              def strip_proxy(a: Union[Argument, Proxy]) -> Any:
                  return a.node if isinstance(a, Proxy) else a

              new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy))
              # also preserve the metadata from the old output node, if it exists.
              # Read the tail node directly instead of materialising the whole
              # node list just to index [-1].
              old_output_node = next(iter(reversed(self.graph.nodes)))
              if old_output_node.op != "output":
                  raise AssertionError(
                      f"Expected output node, got op={old_output_node.op}"
                  )
              new_output_node.meta.update(old_output_node.meta)

          return _make_graph_module(self.module, self.new_graph)