const_fold.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
# mypy: allow-untyped-defs
import keyword
import re
from collections.abc import Callable
from typing import Optional, Union

import torch.fx
from torch.fx.node import map_arg
from torch.fx.passes.split_module import split_module
  8. __all__ = [
  9. "FoldedGraphModule",
  10. "get_unique_attr_name_in_module",
  11. "split_const_subgraphs",
  12. ]
  13. class FoldedGraphModule(torch.fx.GraphModule):
  14. """
  15. FoldedGraphModule is a GraphModule which also contains another
  16. `const_subgraph_module` representing a subgraph which has all const attr
  17. inputs and which can be run once before running the main standard
  18. `graph`. The `const_output_names` are the ordered list names of attrs which
  19. represent what each respective output from the const_subgraph should be set
  20. on which attrs.
  21. """
  22. def __init__(
  23. self,
  24. root: torch.nn.Module,
  25. graph: torch.fx.Graph,
  26. const_subgraph: Optional[torch.fx.Graph] = None,
  27. fx_const_folded_attrs_name: Optional[str] = None,
  28. device_for_folded_attrs: str = "cuda",
  29. ):
  30. super().__init__(root, graph)
  31. self.const_subgraph_module = (
  32. None
  33. if const_subgraph is None
  34. else torch.fx.GraphModule(root, const_subgraph)
  35. )
  36. self.has_folding_been_run = False
  37. self.fx_const_folded_attrs_name = fx_const_folded_attrs_name
  38. self.device_for_folded_attrs = device_for_folded_attrs
  39. def __call__(self, *args, **kwargs):
  40. if not self.has_folding_been_run:
  41. self.run_folding()
  42. return super().__call__(*args)
  43. def run_folding(self):
  44. # If there's no const subgraph module or attr output names to use, return
  45. # early as there is no const folding to perform.
  46. if (
  47. self.const_subgraph_module is None
  48. or self.fx_const_folded_attrs_name is None
  49. ):
  50. return
  51. if self.has_folding_been_run:
  52. raise AssertionError("Folding has already been run")
  53. self.has_folding_been_run = True
  54. # Actually run const folding subgraph. Note that single attr const fold
  55. # subgraphs output a single Tensor while multiple outputs are returned as
  56. # Tuple[Tensor,].
  57. folded_attrs = self.const_subgraph_module()
  58. def _create_param(i):
  59. return torch.nn.Parameter(
  60. i.detach().clone()
  61. if not isinstance(i, int)
  62. else torch.Tensor([i]).to(device=self.device_for_folded_attrs),
  63. requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
  64. )
  65. params = (
  66. torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
  67. if isinstance(folded_attrs, tuple)
  68. else _create_param(folded_attrs)
  69. )
  70. setattr(self, self.fx_const_folded_attrs_name, params)
def _inline_module(
    gm: torch.fx.GraphModule, inline_mod_name: str, run_dce: bool = True
) -> dict[torch.fx.Node, torch.fx.Node]:
    """
    Given `gm` and some graph module which is called with target name `inline_mod_name`,
    this helper will inline all of the nodes from that called graph module into `gm`.
    Returns a mapping from subgraph nodes to the newly created/mapped nodes in gm.

    Args:
        gm: Outer GraphModule containing a `call_module` node whose target is
            `inline_mod_name`. Its graph is mutated in place.
        inline_mod_name: Qualified name of the GraphModule submodule to inline.
            Only the first matching `call_module` node is inlined.
        run_dce: If True, run dead-code elimination on `gm.graph` afterwards.

    Returns:
        Mapping from subgraph nodes (placeholders and copied nodes) to their
        counterparts in `gm`. When the subgraph output is a list/tuple, the
        int-indexed `getitem` users of the removed call node are also added as
        keys, mapped to the output value they selected.

    Raises:
        KeyError: if `gm` has no submodule named `inline_mod_name`.
        AssertionError: if that submodule is not a GraphModule, or no
            `call_module` node targets it.

    Only the call node is erased; the submodule attribute itself is not removed
    from `gm` here.
    """
    # Fetch the inner graph module that we want to inline inside `gm`.
    inline_mod = dict(gm.named_modules())[inline_mod_name]
    if not isinstance(inline_mod, torch.fx.GraphModule):
        raise AssertionError(f"Expected GraphModule, got {type(inline_mod)}")
    # Locate the first call site of the submodule in the outer graph.
    call_mod_node_to_replace = None
    for node in gm.graph.nodes:
        if node.op == "call_module" and node.target == inline_mod_name:
            call_mod_node_to_replace = node
            break
    if call_mod_node_to_replace is None:
        raise AssertionError(f"Could not find call_module node for {inline_mod_name}")

    # Now actually do the swap. Note that we have to keep track of new nodes that are
    # copied into `gm` -- we do this via replacement_mapping.
    call_mod_args = call_mod_node_to_replace.args
    call_mod_kwargs = call_mod_node_to_replace.kwargs

    replacement_mapping: dict[torch.fx.Node, torch.fx.Node] = {}
    # Index of the next positional argument of the call site to bind.
    ph_count = 0

    def replacement_fn(node):
        # Translate a subgraph node into its `gm` counterpart. Side effect: the
        # counterpart's meta is overwritten with a copy of the subgraph node's
        # meta (this also applies to outer nodes bound to placeholders).
        new_node = replacement_mapping[node]
        new_node.meta = node.meta.copy()
        return new_node

    for inline_node in inline_mod.graph.nodes:
        if inline_node.op == "placeholder":
            # Placeholders are not copied: bind each one to the call site's
            # kwarg keyed by the placeholder's node name, else to the next
            # positional arg.
            # NOTE(review): a placeholder whose value the call site omits (e.g.
            # relying on a default) would raise IndexError here.
            replacement_mapping[inline_node] = (
                call_mod_kwargs[inline_node.name]
                if inline_node.name in call_mod_kwargs
                else call_mod_args[ph_count]
            )
            ph_count += 1
            continue

        if inline_node.op == "output":
            outputs = inline_node.args[0]
            output_replacements = map_arg(outputs, replacement_fn)
            # If output is a tuple, we need to handle getitem users specially.
            # Capture users before replace_all_uses_with modifies them.
            getitem_users: list[torch.fx.Node] = []
            if isinstance(output_replacements, (list, tuple)):
                import operator

                getitem_users = [
                    user
                    for user in call_mod_node_to_replace.users
                    if user.op == "call_function"
                    and user.target is operator.getitem
                    and isinstance(user.args[1], int)
                ]
            # Every user of the call node now consumes the subgraph's outputs
            # directly (a node, or a list/tuple literal of nodes).
            call_mod_node_to_replace.replace_all_uses_with(output_replacements)
            # Inline getitem nodes that now index into the tuple literal
            for user in getitem_users:
                idx = user.args[1]
                if not isinstance(idx, int):
                    raise AssertionError(f"Expected int index, got {type(idx)}")
                user.replace_all_uses_with(output_replacements[idx])
                gm.graph.erase_node(user)
                replacement_mapping[user] = output_replacements[idx]
            continue

        # Any other node is copied into `gm` just before the call site, with its
        # inputs remapped through replacement_fn.
        with gm.graph.inserting_before(call_mod_node_to_replace):
            new_node = gm.graph.node_copy(inline_node, replacement_fn)
        replacement_mapping[inline_node] = new_node

    # Explicitly remove the module that was just inlined,
    # this module may contain impure ops so cannot be dead code eliminated,
    # this module is unneeded as it's just inlined back to main graph.
    gm.graph.erase_node(call_mod_node_to_replace)
    if run_dce:
        gm.graph.eliminate_dead_code()

    return replacement_mapping
  144. def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str:
  145. """
  146. Make sure the name is unique (in a module) and can represents an attr.
  147. """
  148. # Delete all characters that are illegal in a Python identifier.
  149. name = re.sub("[^0-9a-zA-Z_]+", "_", name)
  150. if name[0].isdigit():
  151. name = f"_{name}"
  152. # Now make sure it is in fact unique to the module by incrementing suffix value.
  153. while hasattr(mod_traced, name):
  154. match = re.match(r"(.*)_(\d+)$", name)
  155. if match is None:
  156. name = name + "_1"
  157. else:
  158. base, num = match.group(1, 2)
  159. name = f"{base}_{int(num) + 1}"
  160. return name
  161. def split_const_subgraphs(
  162. module: Union[torch.nn.Module, torch.fx.GraphModule],
  163. skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
  164. device_for_folded_attrs: str = "cpu",
  165. ) -> FoldedGraphModule:
  166. """
  167. Looks through `module` for any nodes that have all constant attribute inputs
  168. and separates them out into their own constant subgraph, and returns a
  169. FoldedGraphModule which runs that constant subgraph on the first run to set
  170. attributes on the module prior to running the non-constant portion of the
  171. graph.
  172. """
  173. import sympy
  174. if not isinstance(module, torch.fx.GraphModule):
  175. mod_traced = torch.fx.symbolic_trace(module)
  176. else:
  177. mod_traced = module
  178. def _subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool:
  179. """
  180. Return True if a GraphModule type subgraph contains any impure op, else False.
  181. """
  182. if not isinstance(module, torch.fx.GraphModule):
  183. raise AssertionError(
  184. "caller should only pass GraphModule to subgraph_has_impure_ops check"
  185. )
  186. for node in module.graph.nodes:
  187. if node.op == "call_function" and node.is_impure():
  188. return True
  189. if (
  190. node.op == "call_module"
  191. # pyrefly: ignore [not-callable]
  192. and (submodule := module.get_submodule(node.target))
  193. and isinstance(submodule, torch.fx.GraphModule)
  194. ):
  195. return _subgraph_has_impure_ops(submodule)
  196. return False
  197. # Build up a list of const_nodes, defined as nodes that are themselves
  198. # get_attrs, or have all get_attr or other constant node inputs.
  199. const_nodes: set[torch.fx.Node] = set()
  200. found_const_folding = False
  201. for node in mod_traced.graph.nodes:
  202. # Skip over placeholders/outputs because they can't be const folded and
  203. # we don't want to add tags to them.
  204. if node.op in {"placeholder", "output"}:
  205. continue
  206. # If the node itself is constant, or all of its inputs are constant,
  207. # then tag it as constant.
  208. if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
  209. const_nodes
  210. ):
  211. continue
  212. # If provided skip folding function says to skip, then skip.
  213. if skip_folding_node_fn and skip_folding_node_fn(node):
  214. continue
  215. # Skip folding side-effectful functions
  216. if node.is_impure():
  217. continue
  218. # Skip folding nodes that have symbolic fill_value
  219. if isinstance(node.kwargs.get("fill_value", None), sympy.Expr):
  220. continue
  221. # Skip folding submodules that have impure ops
  222. if (
  223. node.op == "call_module"
  224. # pyrefly: ignore [not-callable]
  225. and (target_mod := mod_traced.get_submodule(node.target))
  226. and isinstance(target_mod, torch.fx.GraphModule)
  227. and _subgraph_has_impure_ops(target_mod)
  228. ):
  229. continue
  230. # Must be a constant foldable node at this point.
  231. const_nodes.add(node)
  232. if node.op != "get_attr":
  233. found_const_folding = True
  234. # If we did not find any const folding then return early without a const fold subgraph.
  235. if not found_const_folding:
  236. return FoldedGraphModule(mod_traced, mod_traced.graph)
  237. # Partition the module into two: submod_0 for constant folding subgraph, and
  238. # submod_1 for the rest.
  239. def mod_partition(node: torch.fx.Node):
  240. return 0 if node in const_nodes else 1
  241. split = split_module(mod_traced, module, mod_partition)
  242. const_mod_name, non_const_mod_name = "submod_0", "submod_1"
  243. # Safely get submod_1 in case there are no non-const nodes
  244. const_gm, non_const_gm = split.submod_0, getattr(split, non_const_mod_name, None)
  245. # The module that a call_module node refers to gets copied to submodules during split.
  246. # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to
  247. # attach inlined modules to `split` as it's the owning module now.
  248. for node in non_const_gm.graph.nodes if non_const_gm else []:
  249. if node.op == "call_module":
  250. setattr(split, node.target, getattr(non_const_gm, node.target))
  251. for node in const_gm.graph.nodes:
  252. if node.op == "call_module":
  253. setattr(split, node.target, getattr(const_gm, node.target))
  254. # split_module currently does not use get_attrs for attrs. Instead it passes
  255. # them in as args from the parent module, which used get_attrs. Here we set
  256. # them as get_attrs inside const_gm, allowing for running folding without
  257. # somehow a priori knowing the attrs that should be passed as args. We can
  258. # unconditionally do this for all placeholders because we know all
  259. # placeholders to const_gm must be constants accessible via get_attr.
  260. call_const_gm_args = None
  261. for node in split.graph.nodes:
  262. if node.op == "call_module":
  263. if node.target == const_mod_name:
  264. call_const_gm_args = node.args
  265. break
  266. if call_const_gm_args is None:
  267. raise AssertionError("Could not find call_module node for const_gm")
  268. # Here we do the actual replacement of placeholders to get_attrs. Note that here we
  269. # set the const_gm.graph into a new root_const_gm with split as the root module,
  270. # because we are fetching attributes directly from the root module, instead of
  271. # fetching them from const_gm. Example: The const_gm must have some format like:
  272. # graph():
  273. # %inp : [num_users=1] = placeholder[target=const_inp]
  274. # %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {})
  275. # return add
  276. # We replace that with the following, which does not have any placeholders:
  277. # graph():
  278. # %inp_1 : [num_users=1] = get_attr[target=const_inp]
  279. # %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {})
  280. # return add
  281. root_const_gm = torch.fx.GraphModule(split, const_gm.graph)
  282. # The order of placeholders in the const_gm graph should match the order of
  283. # args in the outer module, so we can simply use an index for the
  284. # placeholder mapping
  285. ph_idx = 0
  286. for node in root_const_gm.graph.nodes:
  287. if node.op == "output":
  288. multiple_outputs = isinstance(node.args[0], tuple)
  289. continue
  290. if node.op != "placeholder":
  291. continue
  292. if ph_idx >= len(call_const_gm_args):
  293. raise AssertionError(
  294. f"Placeholder index {ph_idx} out of range for args "
  295. f"(len={len(call_const_gm_args)})"
  296. )
  297. in_node = call_const_gm_args[ph_idx]
  298. ph_idx += 1
  299. if in_node.op != "get_attr":
  300. raise AssertionError(f"Expected get_attr, got {in_node.op}")
  301. with root_const_gm.graph.inserting_before(node):
  302. new_node = root_const_gm.graph.get_attr(in_node.target)
  303. new_node.meta = node.meta.copy()
  304. node.replace_all_uses_with(new_node)
  305. root_const_gm.graph.erase_node(node)
  306. if "multiple_outputs" not in locals():
  307. raise AssertionError("multiple_outputs not set in loop")
  308. # Now find the call to const_gm inside split, and replace it with a getattr to the
  309. # folded tensor(s) that result from constant folding. Note that we don't need to
  310. # worry about whether this is one or more tensors because the original graph
  311. # correctly uses getitem to extract individual tensors if there are multiple folded.
  312. fx_const_folded_attrs_name = get_unique_attr_name_in_module(
  313. mod_traced, "_FX_CONST_FOLDED_ATTRS"
  314. )
  315. setattr(
  316. split,
  317. fx_const_folded_attrs_name,
  318. torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(), # type: ignore[possibly-undefined]
  319. )
  320. for node in split.graph.nodes:
  321. if node.op == "call_module" and node.target == const_mod_name:
  322. with node.graph.inserting_before(node):
  323. folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name)
  324. folded_attrs.meta = node.meta.copy()
  325. node.replace_all_uses_with(folded_attrs)
  326. break
  327. # Finally, inline the non-constant submod (if it exists) into the split submod.
  328. # This is so that the original caller who may have passed in a graph module will
  329. # get back out a graph module whose graph is traced to the same granularity.
  330. if hasattr(split, non_const_mod_name):
  331. _inline_module(split, non_const_mod_name)
  332. split.graph.eliminate_dead_code()
  333. return FoldedGraphModule(
  334. split,
  335. split.graph,
  336. root_const_gm.graph,
  337. fx_const_folded_attrs_name,
  338. device_for_folded_attrs,
  339. )