_swap.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. import logging
  2. import operator
  3. import types
  4. from collections import defaultdict
  5. import torch
  6. import torch.fx._pytree as fx_pytree
  7. import torch.utils._pytree as pytree
  8. from torch.export.exported_program import (
  9. ConstantArgument,
  10. ExportedProgram,
  11. ModuleCallSignature,
  12. )
  13. from torch.fx.passes.tools_common import legalize_graph, NodeList
  14. from torch.fx.passes.utils.fuser_utils import erase_nodes, fuse_as_graphmodule
  15. log = logging.getLogger(__name__)
  16. def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]:
  17. node_users = list(node.users.keys())
  18. getitem_users = set()
  19. for user in node_users:
  20. if user.op == "output":
  21. continue
  22. if not (user.op == "call_function" and user.target is operator.getitem):
  23. raise AssertionError(
  24. f"Expected getitem node as user for {node}, instead got {user}"
  25. )
  26. getitem_users.update(list(user.users.keys()))
  27. return getitem_users
  28. def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
  29. """
  30. We want to try to remove extraneous pytree flatten/unflatten calls between modules
  31. calls. Instead of having the following:
  32. graph():
  33. ...
  34. %foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
  35. %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1), kwargs = {})
  36. %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
  37. %tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
  38. %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {})
  39. %getitem_7 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 1), kwargs = {})
  40. %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {})
  41. %bar : [num_users=1] = call_module[target=bar](args = (%getitem_6,), kwargs = {})
  42. ...
  43. We could do the following, if we know that all the outputs of `foo` feed into `bar`:
  44. graph():
  45. ...
  46. %foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
  47. %bar : [num_users=1] = call_module[target=bar](args = (%getitem_6,), kwargs = {})
  48. ...
  49. Currently this optimization only works for the case where all of the outputs
  50. of `foo` go directly into `bar`, and `bar` has no other inputs.
  51. """ # noqa: B950
  52. log.debug("Trying to remove pytrees for module call %s", curr_module_node)
  53. curr_module_users = list(curr_module_node.users.keys())
  54. if len(curr_module_users) != 1:
  55. raise AssertionError(
  56. f"Expected only one user for module node, instead got {list(curr_module_users)}"
  57. )
  58. flatten_node = curr_module_users[0]
  59. if not (
  60. flatten_node.op == "call_function"
  61. and flatten_node.target is fx_pytree.tree_flatten_spec
  62. ):
  63. raise AssertionError(
  64. f"Expected flatten_node to be a call_function with target tree_flatten_spec, "
  65. f"but got op={flatten_node.op}, target={flatten_node.target}"
  66. )
  67. flatten_getitem_users = _get_getitem_users(flatten_node)
  68. if len(flatten_getitem_users) != 1:
  69. log.debug(
  70. "More than one user found for flatten node, %s: %s. "
  71. "Unable to fuse it with another unflatten call.",
  72. flatten_node,
  73. flatten_getitem_users,
  74. )
  75. return
  76. unflatten_node = next(iter(flatten_getitem_users))
  77. if not (
  78. unflatten_node.op == "call_function"
  79. and unflatten_node.target is pytree.tree_unflatten
  80. ):
  81. log.debug(
  82. "Flatten node %s's user is not a pytree.tree_unflatten. "
  83. "Instead it is: %s. Passing...",
  84. flatten_node,
  85. unflatten_node,
  86. )
  87. return
  88. for i, arg in enumerate(unflatten_node.args[0]): # type: ignore[union-attr,arg-type]
  89. if arg not in flatten_node.users:
  90. log.debug(
  91. "Module %s's outputs are not all directly used as inputs to "
  92. "the subsequent module. Unable to fuse the connecting "
  93. "flatten/unflatten. The inputs to the subsequent module are: %s. ",
  94. curr_module_node,
  95. unflatten_node.args[0],
  96. )
  97. return
  98. if not (
  99. # pyrefly: ignore [missing-attribute]
  100. arg.op == "call_function"
  101. # pyrefly: ignore [missing-attribute]
  102. and arg.target is operator.getitem
  103. # pyrefly: ignore [missing-attribute]
  104. and arg.args[1] == i
  105. ):
  106. log.debug(
  107. "Module %s's outputs are not all directly used in the same "
  108. "order as outputted. Unable to fuse the connecting "
  109. "flatten/unflatten. The inputs to the "
  110. "subsequent module are: %s. ",
  111. curr_module_node,
  112. unflatten_node.args[0],
  113. )
  114. return
  115. # Unflatten has two levels of getitem, because it gets the args and kwargs
  116. unflatten_getitem_getitem_users = set()
  117. unflatten_getitem_users = _get_getitem_users(unflatten_node)
  118. for unflatten_getitem_user in unflatten_getitem_users:
  119. unflatten_getitem_getitem_users.update(
  120. list(unflatten_getitem_user.users.keys())
  121. )
  122. if len(unflatten_getitem_getitem_users) != 1:
  123. log.debug(
  124. "More than one user found for unflatten node, %s: %s. "
  125. "Unable to fuse it with another flatten call.",
  126. unflatten_node,
  127. unflatten_getitem_getitem_users,
  128. )
  129. return
  130. next_module_node = next(iter(unflatten_getitem_getitem_users))
  131. if next_module_node.op != "call_module":
  132. log.debug(
  133. "Unflatten node %s's user is not a call_module. "
  134. "Instead it is: %s. Passing...",
  135. unflatten_node,
  136. next_module_node,
  137. )
  138. return
  139. # Directly put the outputs of the current module into the next module
  140. next_module_node.args = (curr_module_node,)
  141. def _remove_extraneous_pytrees(gm: torch.fx.GraphModule) -> None:
  142. """
  143. Remove extraneous pytree flatten/unflatten calls.
  144. We try a couple of optimizations here:
  145. 1. Remove pytree flatten/unflatten calls between modules
  146. 2. TODO: Remove module's in_spec + initial unflatten call
  147. 3. TODO: Remove module's out_spec + final flatten call
  148. """
  149. for node in gm.graph.nodes:
  150. if node.op == "call_module" and node.target != "_guards_fn":
  151. _try_remove_connecting_pytrees(node)
  152. gm.graph.eliminate_dead_code()
  153. def _construct_inputs(
  154. gm: torch.fx.GraphModule,
  155. signature: ModuleCallSignature,
  156. node_name_map: dict[str, torch.fx.Node],
  157. ) -> tuple[list[torch.fx.Node], dict[str, torch.fx.Node]]:
  158. tree_unflatten_args: list[torch.fx.Node | None] = []
  159. for input_ in signature.inputs:
  160. if isinstance(input_, ConstantArgument) and input_.value is None:
  161. # Constants should be directly embedded into the graph and not used
  162. # as inputs
  163. tree_unflatten_args.append(None)
  164. elif input_.name not in node_name_map:
  165. # For unused inputs
  166. tree_unflatten_args.append(None)
  167. else:
  168. tree_unflatten_args.append(node_name_map[input_.name])
  169. # Insert unflatten call
  170. from .unflatten import _generate_unflatten
  171. unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec)
  172. if signature.in_spec.num_children != 2:
  173. raise AssertionError(
  174. f"Expected in_spec to have 2 children, but got {signature.in_spec.num_children}"
  175. )
  176. if signature.in_spec.type is not tuple:
  177. raise AssertionError(
  178. f"Expected in_spec type to be tuple, but got {signature.in_spec.type}"
  179. )
  180. args_spec, kwargs_spec = signature.in_spec.children()
  181. if args_spec.type is not tuple:
  182. raise AssertionError(
  183. f"Expected args_spec type to be tuple, but got {args_spec.type}"
  184. )
  185. if kwargs_spec.type is not dict:
  186. raise AssertionError(
  187. f"Expected kwargs_spec type to be dict, but got {kwargs_spec.type}"
  188. )
  189. args_node = gm.graph.call_function(operator.getitem, (unflatten_node, 0))
  190. args_nodes = [
  191. gm.graph.call_function(operator.getitem, (args_node, i))
  192. for i in range(args_spec.num_children)
  193. ]
  194. kwargs_node = gm.graph.call_function(operator.getitem, (unflatten_node, 1))
  195. kwargs_nodes = {
  196. k: gm.graph.call_function(operator.getitem, (kwargs_node, k))
  197. for k in kwargs_spec.context
  198. }
  199. return args_nodes, kwargs_nodes
  200. def _insert_call_module(
  201. gm: torch.fx.GraphModule,
  202. args_nodes: list[torch.fx.Node],
  203. kwargs_nodes: dict[str, torch.fx.Node],
  204. module_to_swap: torch.nn.Module,
  205. name: str,
  206. ) -> torch.fx.Node:
  207. from .unflatten import _assign_attr, _AttrKind
  208. _assign_attr(module_to_swap, gm, name, _AttrKind.MODULE)
  209. module_node = gm.graph.call_module(name, tuple(args_nodes), kwargs_nodes) # type: ignore[arg-type]
  210. return module_node
  211. def _deconstruct_outputs(
  212. gm: torch.fx.GraphModule,
  213. signature: ModuleCallSignature,
  214. module_node: torch.fx.Node,
  215. node_name_map: dict[str, torch.fx.Node],
  216. orig_outputs: tuple[torch.fx.Node, ...],
  217. ) -> None:
  218. from .unflatten import _generate_flatten_spec
  219. flatten_node = _generate_flatten_spec(gm, module_node, signature.out_spec)
  220. for i, orig_output in enumerate(orig_outputs):
  221. # Use Proxy to record getitem access.
  222. proxy_out = torch.fx.Proxy(flatten_node)[i].node # type: ignore[index]
  223. orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)
  224. node_name_map[orig_output.name] = proxy_out
  225. def _swap_module_helper(
  226. gm: torch.fx.GraphModule,
  227. modules_to_swap: dict[str, torch.nn.Module],
  228. module_call_graph: dict[str, ModuleCallSignature],
  229. ) -> torch.fx.GraphModule:
  230. log.debug("Starting graph:")
  231. log.debug(gm.graph)
  232. legalize_graph(gm)
  233. partitions: dict[str, NodeList] = defaultdict(list)
  234. node_name_map: dict[str, torch.fx.Node] = {
  235. node.name: node for node in gm.graph.nodes
  236. }
  237. # TODO: Handle the duplicate module case
  238. for node in gm.graph.nodes:
  239. if nn_module_stack := node.meta.get("nn_module_stack"):
  240. for path, _ in nn_module_stack.values():
  241. if path in modules_to_swap:
  242. partitions[path].append(node)
  243. break
  244. for name, nodes in partitions.items():
  245. """
  246. Given a graph like the following, and we want to swap out the submodule "foo":
  247. graph():
  248. %x : [num_users=1] = placeholder[target=x]
  249. %y : [num_users=2] = placeholder[target=y]
  250. %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%y, %x), kwargs = {}), nn_module_stack = {"foo": ("foo", torch.nn.Module)}
  251. %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %add), kwargs = {}), nn_module_stack = {"bar": ("bar", torch.nn.Module)}
  252. return (sub,)
  253. We will first partition out foo's subgraph:
  254. graph():
  255. %x : [num_users=1] = placeholder[target=x]
  256. %y : [num_users=2] = placeholder[target=y]
  257. %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%y, %x), kwargs = {})
  258. return add
  259. And then insert an unflatten + call_module + flatten to replace the subgraph:
  260. graph():
  261. %x : [num_users=1] = placeholder[target=x]
  262. %y : [num_users=1] = placeholder[target=y]
  263. %_spec_0 : [num_users=1] = get_attr[target=_spec_0]
  264. %tree_unflatten : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
  265. %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%tree_unflatten, 0), kwargs = {})
  266. %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {})
  267. %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 1), kwargs = {})
  268. %getitem_3 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten, 1), kwargs = {})
  269. %foo : [num_users=0] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
  270. %_spec_1 : [num_users=1] = get_attr[target=_spec_1]
  271. %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (None, %_spec_1), kwargs = {})
  272. %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
  273. %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %getitem_4), kwargs = {})
  274. return (%sub,)
  275. The `tree_unflatten` call will construct tensor inputs into the input
  276. format needed by the swapped eager module.
  277. The `call_module` node should now reference the swapped torch.nn.Module.
  278. The `tree_flatten_spec` call will deconstruct the eager outputs of the
  279. swapped module into tensors.
  280. """ # noqa: B950
  281. submod_name = name.replace(".", "_")
  282. sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
  283. gm, nodes, f"fused_{submod_name}"
  284. )
  285. log.debug("Fused subgraph nodes:")
  286. log.debug(sub_gm.graph)
  287. signature: ModuleCallSignature = module_call_graph[name]
  288. args_nodes, kwargs_nodes = _construct_inputs(gm, signature, node_name_map)
  289. module_node = _insert_call_module(
  290. gm, args_nodes, kwargs_nodes, modules_to_swap[name], name
  291. )
  292. _deconstruct_outputs(gm, signature, module_node, node_name_map, orig_outputs)
  293. erase_nodes(gm, nodes)
  294. log.debug("Swapped graph:")
  295. log.debug(gm.graph)
  296. legalize_graph(gm)
  297. log.debug("Before removing extraneous pytrees:")
  298. log.debug(gm.graph)
  299. _remove_extraneous_pytrees(gm)
  300. log.debug("After removing extraneous pytrees:")
  301. log.debug(gm.graph)
  302. gm.recompile()
  303. return gm
  304. def _fix_input_output_signature(
  305. gm: torch.fx.GraphModule, signature: ModuleCallSignature
  306. ) -> None:
  307. """
  308. Given the unlifted module from calling ep.module(), we want to remove the
  309. pytree processing from the graph module's PyTreeCodeGen and instead make it
  310. nodes inside of the graph. This allows us to do some optimizations, like
  311. remove these pytree calls if it is unnecessary, and makes the PyTree part
  312. more obvious to graph passes.
  313. """
  314. from torch.export.unflatten import _generate_flatten, _generate_unflatten
  315. # Remove the registered pytree codegen because we will take care of it
  316. # through inserting pytree nodes into the graph
  317. gm.graph._codegen = torch.fx.graph.CodeGen()
  318. old_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
  319. new_placeholders = []
  320. forward_arg_names = signature.forward_arg_names
  321. if forward_arg_names is None:
  322. forward_arg_names = []
  323. if signature.in_spec.num_children != 2:
  324. raise AssertionError(
  325. f"Expected in_spec to have 2 children, but got {signature.in_spec.num_children}"
  326. )
  327. arg_spec = signature.in_spec.child(0)
  328. kwarg_spec = signature.in_spec.child(1)
  329. if arg_spec.type is not tuple:
  330. raise AssertionError(
  331. f"Expected arg_spec type to be tuple, but got {arg_spec.type}"
  332. )
  333. if kwarg_spec.type is not dict:
  334. raise AssertionError(
  335. f"Expected kwarg_spec type to be dict, but got {kwarg_spec.type}"
  336. )
  337. for i in range(arg_spec.num_children):
  338. forward_arg_names.append(f"arg_{i}")
  339. forward_arg_names.extend(kwarg_spec.context)
  340. for arg in forward_arg_names:
  341. with gm.graph.inserting_before(old_placeholders[0]):
  342. new_placeholders.append(gm.graph.placeholder(arg))
  343. # Insert flatten call for the inputs
  344. with gm.graph.inserting_before(old_placeholders[0]):
  345. flat_node = _generate_flatten(gm, tuple(new_placeholders))
  346. for i, old_placeholder in enumerate(old_placeholders):
  347. old_placeholder.op = "call_function"
  348. old_placeholder.target = operator.getitem
  349. old_placeholder.args = (flat_node, i)
  350. # Insert unflatten call for the outputs
  351. output_node = next(node for node in gm.graph.nodes if node.op == "output")
  352. with gm.graph.inserting_before(output_node):
  353. unflat = _generate_unflatten(gm, output_node.args[0], signature.out_spec)
  354. output_node.args = (unflat,)
  355. gm.recompile()
  356. def _swap_modules(
  357. ep: ExportedProgram, modules_to_swap: dict[str, torch.nn.Module]
  358. ) -> torch.fx.GraphModule:
  359. """
  360. Unlifts the given ExportedProgram into a fx.GraphModule, and then swaps
  361. previously traced modules with new eager modules specified. Returns a
  362. fx.GraphModule with a custom forward function.
  363. Args:
  364. ep (ExportedProgram): Exported program to modify
  365. modules_to_swap (Dict[str, torch.nn.Module]): Mapping from module fqn to
  366. eager module to swap with. The specified module fqn should have also
  367. been specified in the `preserve_module_call_signature` argument to
  368. torch.export so that we know how to restore the calling convention
  369. to this argument.
  370. run_with_interpreter: Whether or not to run the graph using
  371. fx.Interpreter. Setting to true will help result in better error
  372. messages and easier debugging, but it has found to result in a QPS
  373. drop.
  374. """
  375. module_call_graph = {
  376. entry.fqn: entry.signature for entry in ep.module_call_graph if entry.signature
  377. }
  378. gm = ep.module()
  379. gm.validate_inputs = False # type: ignore[assignment]
  380. gm.graph.eliminate_dead_code() # type: ignore[operator, union-attr]
  381. if not isinstance(gm, torch.fx.GraphModule):
  382. raise AssertionError(
  383. f"Expected gm to be a torch.fx.GraphModule, but got {type(gm)}"
  384. )
  385. _fix_input_output_signature(gm, ep.module_call_graph[0].signature)
  386. gm.module_call_graph = ep.module_call_graph
  387. gm.train = types.MethodType(type(gm).train, gm) # type: ignore[assignment]
  388. gm.eval = types.MethodType(type(gm).eval, gm) # type: ignore[assignment]
  389. if not isinstance(gm, torch.fx.GraphModule):
  390. raise AssertionError(
  391. f"Expected gm to be a torch.fx.GraphModule, but got {type(gm)}"
  392. )
  393. gm = _swap_module_helper(gm, modules_to_swap, module_call_graph)
  394. return gm