compilers.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. from __future__ import annotations
  2. import copy
  3. import logging
  4. import os
  5. import pickle
  6. import random
  7. from contextlib import contextmanager
  8. from functools import partial
  9. from typing import Any, TYPE_CHECKING
  10. from typing_extensions import ParamSpec, TypeVar
  11. import sympy
  12. import torch
  13. import torch.fx as fx
  14. import torch.nn as nn
  15. import torch.utils._pytree as pytree
  16. from torch import SymInt
  17. from torch._decomp import get_decompositions
  18. from torch.fx.experimental.symbolic_shapes import bind_symbols
  19. from .aot_autograd import aot_function, aot_module, make_boxed_compiler
  20. from .compile_utils import strip_overloads
  21. from .partitioners import (
  22. default_partition,
  23. draw_graph,
  24. min_cut_rematerialization_partition,
  25. )
  26. if TYPE_CHECKING:
  27. from collections.abc import Callable, Generator, Sequence
  28. from torch.fx.node import Node
  29. from torch.types import IntLikeType
  30. _P = ParamSpec("_P")
  31. _R = TypeVar("_R")
  32. log = logging.getLogger(__name__)
  33. # These canonicalization are needed here (and not decompositions), as the ops
  34. # we're trying to canonicalize to CompositeImplicitAutograd.
  35. def _canonicalize(fx_g: fx.GraphModule) -> fx.GraphModule:
  36. for node in fx_g.graph.find_nodes(
  37. op="call_function", target=torch.ops.aten._to_copy
  38. ):
  39. node.target = torch.ops.aten.to
  40. fx_g.recompile()
  41. return fx_g
  42. @contextmanager
  43. def _disable_jit_autocast() -> Generator[None, None, None]:
  44. # pyrefly: ignore [missing-attribute]
  45. old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
  46. try:
  47. yield
  48. finally:
  49. # pyrefly: ignore [missing-attribute]
  50. torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
  51. @make_boxed_compiler
  52. def ts_compile(fx_g: fx.GraphModule, inps: Sequence[Any]) -> torch.jit.ScriptModule:
  53. """
  54. Compiles the :attr:`fx_g` with Torchscript compiler.
  55. .. warning::
  56. This API is experimental and likely to change.
  57. Args:
  58. fx_g(fx.GraphModule): The input Fx graph module to be compiled.
  59. Returns:
  60. Torch scripted model.
  61. """
  62. with _disable_jit_autocast():
  63. strip_overloads(fx_g)
  64. for node in fx_g.graph.find_nodes(
  65. op="call_function", target=torch.ops.aten._to_copy
  66. ):
  67. if len(node.args) == 1 and len(node.kwargs) == 1 and "dtype" in node.kwargs:
  68. node.target = torch.ops.aten.to
  69. for node in fx_g.graph.nodes:
  70. new_kwargs = {}
  71. for k, v in node.kwargs.items():
  72. if isinstance(v, torch.device):
  73. v = v.type
  74. new_kwargs[k] = v
  75. node.kwargs = new_kwargs
  76. fx_g.graph.lint()
  77. fx_g.recompile()
  78. f = torch.jit.script(fx_g)
  79. # pyrefly: ignore [missing-attribute]
  80. torch._C._jit_pass_remove_mutation(f.graph)
  81. f = torch.jit.freeze(f.eval())
  82. f = torch.jit.optimize_for_inference(f)
  83. if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps):
  84. f(*inps)
  85. return f
  86. def _draw_graph_compile(
  87. fx_g: fx.GraphModule, _: Any, name: str, clear_meta: bool = True
  88. ) -> fx.GraphModule:
  89. print(fx_g.code)
  90. draw_graph(fx_g, name, clear_meta=clear_meta)
  91. return fx_g
  92. def draw_graph_compile(
  93. name: str,
  94. ) -> Callable[[fx.GraphModule, list[Any]], fx.GraphModule]:
  95. return make_boxed_compiler(partial(_draw_graph_compile, name=name))
  96. @make_boxed_compiler
  97. def nop(fx_g: fx.GraphModule, _: Any) -> fx.GraphModule:
  98. """
  99. Returns the :attr:`fx_g` Fx graph module as it is. This is a no-op compiler
  100. and can be used to check accuracy.
  101. .. warning::
  102. This API is experimental and likely to change.
  103. """
  104. return fx_g
  105. class DebugInterpreter(fx.Interpreter):
  106. def run(
  107. self,
  108. *args: Any,
  109. initial_env: dict[Node, Any] | None = None,
  110. enable_io_processing: bool = True,
  111. ) -> Any:
  112. self.symbol_mapping = bind_symbols(
  113. # pyrefly: ignore[bad-argument-type]
  114. self.module,
  115. *args,
  116. )
  117. return super().run(
  118. *args, initial_env=initial_env, enable_io_processing=enable_io_processing
  119. )
  120. def run_node(self, n: Node) -> Any:
  121. def subst_symint(ni: IntLikeType) -> int:
  122. if not isinstance(ni, SymInt):
  123. return ni
  124. r = sympy.expand(ni.node.expr.xreplace(self.symbol_mapping))
  125. if not r.is_number:
  126. raise AssertionError(f"expected r to be a number, got {r}")
  127. return int(r)
  128. def subst_symint_tuple(nis: tuple[IntLikeType, ...]) -> tuple[int, ...]:
  129. return tuple(subst_symint(ni) for ni in nis)
  130. def check_significant_strides(a: torch.Tensor, b: torch.Tensor) -> bool:
  131. if subst_symint(a.numel()) > 0:
  132. for idx in range(a.ndim):
  133. if (
  134. subst_symint(a.stride(idx)) != b.stride(idx)
  135. and subst_symint(a.size(idx)) > 1
  136. ):
  137. return False
  138. return True
  139. def check(nv: torch.Tensor, rv: torch.Tensor, desc: Callable[[], str]) -> None:
  140. if not callable(desc):
  141. raise AssertionError(f"expected desc to be callable, got {type(desc)}")
  142. if nv.dtype != rv.dtype:
  143. raise AssertionError(f"{desc()}: {nv.dtype} != {rv.dtype}")
  144. if subst_symint_tuple(nv.size()) != rv.size():
  145. raise AssertionError(
  146. f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}"
  147. )
  148. same_strides = check_significant_strides(nv, rv)
  149. if not same_strides:
  150. raise AssertionError(
  151. f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}"
  152. )
  153. r = super().run_node(n)
  154. if "val" in n.meta:
  155. n_vals, _n_spec = pytree.tree_flatten(n.meta["val"])
  156. r_vals, _r_spec = pytree.tree_flatten(r)
  157. # TODO: There is some sort of problem where we record that an
  158. # operator returned a tuple/list, and then later it turns out the
  159. # real version of the operator returned a list/tuple. Need to
  160. # figure out what's actually going on here, the error itself is
  161. # harmless enough as we only getitem out the outputs.
  162. # assert n_spec == r_spec, f"{n_spec} != {r_spec}"
  163. if len(n_vals) != len(r_vals):
  164. raise AssertionError(f"{len(n_vals)} != {len(r_vals)}")
  165. for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
  166. if not isinstance(rv, torch.Tensor):
  167. continue
  168. check(nv, rv, lambda: f"output {i} where {self.symbol_mapping}")
  169. return r
  170. @make_boxed_compiler
  171. def debug_nop(
  172. fx_g: fx.GraphModule, _: Any
  173. ) -> Callable[[DebugInterpreter, Any, dict[Node, Any] | None, bool], Any]:
  174. """
  175. Returns a (slow) interpreter over the FX graph module that also checks
  176. various debugging properties (e.g., that tracing strides matched real
  177. strides.)
  178. """
  179. return DebugInterpreter(fx_g).run
  180. @make_boxed_compiler
  181. def simple_ts_compile(fx_g: fx.GraphModule, _: Any) -> torch.jit.ScriptModule:
  182. strip_overloads(fx_g)
  183. f = torch.jit.script(fx_g)
  184. f = torch.jit.freeze(f.eval())
  185. return f
  186. def nnc_jit(f: Callable[..., Any]) -> Callable[..., Any]:
  187. return aot_function(f, simple_ts_compile)
  188. aten = torch.ops.aten
  189. default_decompositions = {
  190. aten.detach,
  191. aten.gelu_backward,
  192. aten.leaky_relu_backward,
  193. aten.sigmoid_backward,
  194. aten.threshold_backward,
  195. aten.hardtanh_backward,
  196. aten.hardsigmoid_backward,
  197. aten.hardswish_backward,
  198. aten.tanh_backward,
  199. aten.silu_backward,
  200. aten.elu_backward,
  201. aten.cudnn_batch_norm,
  202. aten.cudnn_batch_norm_backward,
  203. aten.masked_fill.Scalar,
  204. aten.masked_fill.Tensor,
  205. aten.elu,
  206. aten.leaky_relu,
  207. aten.hardtanh,
  208. aten.hardswish,
  209. aten.hardsigmoid,
  210. aten.conj_physical,
  211. aten.is_same_size,
  212. }
  213. # pyrefly: ignore[bad-argument-type]
  214. default_decompositions = get_decompositions(default_decompositions)
  215. @make_boxed_compiler
  216. def print_compile(fx_g: fx.GraphModule, _: Any) -> fx.GraphModule:
  217. print(fx_g.code)
  218. return fx_g
  219. def memory_efficient_fusion(
  220. fn: Callable[_P, _R] | nn.Module,
  221. **kwargs: Any,
  222. ) -> Callable[_P, _R] | nn.Module:
  223. """
  224. Wrapper function over :func:`aot_function` and :func:`aot_module` to perform
  225. memory efficient fusion. It uses the
  226. :func:`min_cut_rematerialization_partition` partitioner to perform efficient
  227. recomputation. It uses NVFuser to compile the generated forward and backward
  228. graphs.
  229. .. warning::
  230. This API is experimental and likely to change.
  231. Args:
  232. fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module``
  233. that takes one or more arguments. Must return one or more Tensors.
  234. **kwargs: Any other overrides you want to make to the settings
  235. Returns:
  236. Returns a ``Callable`` or ``nn.Module`` that retains the eager behavior
  237. of the original :attr:`fn`, but whose forward and backward graphs have
  238. gone through recomputation optimizations, and the graphs have been
  239. compiled with nvfuser.
  240. """
  241. config = {
  242. "fw_compiler": ts_compile,
  243. "bw_compiler": ts_compile,
  244. "partition_fn": min_cut_rematerialization_partition,
  245. "decompositions": default_decompositions,
  246. }
  247. config.update(kwargs)
  248. if isinstance(fn, torch.nn.Module):
  249. return aot_module(fn, **config) # pyrefly: ignore[bad-argument-type]
  250. else:
  251. return aot_function(fn, **config) # pyrefly: ignore[bad-argument-type]
  252. def debug_compile(
  253. fx_g: fx.GraphModule, inps: Sequence[torch.Tensor]
  254. ) -> torch.jit.ScriptModule:
  255. fx_g.to_folder("foo")
  256. print(
  257. f"""
  258. ##############################################################
  259. # To minimize FX graph, copy and paste the below and run it #
  260. ##############################################################
  261. import torch
  262. import torch.fx as fx
  263. from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess
  264. inps = {[(i.shape, i.dtype) for i in inps]}
  265. inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
  266. from foo import FxModule
  267. mod = FxModule().cuda()
  268. with torch.jit.fuser("fuser2"):
  269. # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess
  270. minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
  271. """
  272. )
  273. from foo import FxModule # pyrefly: ignore[missing-import]
  274. FxModule().cuda()(*inps)
  275. return ts_compile(fx_g, inps)
  276. graph_index: int = 0
  277. def get_inputs(input_data_path: str) -> list[torch.Tensor]:
  278. """
  279. Return a random input for the given inputs meta generated from _save_fx_default.
  280. """
  281. inputs: list[torch.Tensor] = []
  282. with open(input_data_path, "rb") as f:
  283. inputs_meta = pickle.load(f)
  284. inputs = []
  285. for meta in inputs_meta:
  286. if len(meta) == 1:
  287. type = meta
  288. input_ = type(random.random())
  289. else:
  290. type, shape, _stride, dtype, device = meta
  291. if dtype in {
  292. torch.int,
  293. torch.int32,
  294. torch.int64,
  295. torch.bool,
  296. torch.int,
  297. torch.uint8,
  298. int,
  299. float,
  300. }:
  301. input_ = torch.randint(0, 1, shape, dtype=dtype, device=device)
  302. else:
  303. input_ = torch.rand(shape, dtype=dtype, device=device)
  304. inputs.append(input_)
  305. return inputs
  306. def _save_fx_default(
  307. current_name: str,
  308. folder_name: str,
  309. dump_example_input: bool,
  310. gm: torch.fx.GraphModule,
  311. example_inputs: list[torch.Tensor],
  312. ) -> nn.Module:
  313. """
  314. The forward, backward, and joint computation graph will be stored in
  315. {folder_name}/{current_name}/{current_name}_forward_{graph_index},
  316. {folder_name}/{current_name}/{current_name}_backward_{graph_index}, and
  317. {folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively.
  318. The input shape of the graphs will be stored in the .input files.
  319. These files can be loaded with pickle,
  320. and is a list of format (type, shape, stride, dtype, device).
  321. In the case of type = int or float, it is just (type,).
  322. For joint graph input, it is a nested list [[],[]]
  323. where the two inner lists have the same format.
  324. If dump_example_input is True, example_inputs will be stored in .pt file.
  325. Since each function might produce multiple graphs,
  326. the graph_index is used to distinguish difference graphs
  327. """
  328. from functorch.compile import aot_module_simplified
  329. def get_input_meta(args: Any) -> list[Any]:
  330. input_meta = []
  331. if len(args) > 0 and isinstance(args[0], tuple): # joint input
  332. input_meta += get_input_meta(args[0])
  333. input_meta += get_input_meta(args[1])
  334. return input_meta
  335. for arg in args:
  336. if type(arg) is int or type(arg) is float:
  337. input_meta.append((type(arg),))
  338. else:
  339. input_meta.append(
  340. (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)
  341. )
  342. return input_meta
  343. def graph_saver_helper(
  344. gm_to_save: fx.GraphModule, args: Any, type_name: str
  345. ) -> None:
  346. global graph_index
  347. if len(gm_to_save.graph.nodes) == 0:
  348. log.log(
  349. logging.WARNING,
  350. "No nodes in graph {%s}_{%s}_{%s}.",
  351. current_name,
  352. type_name,
  353. graph_index,
  354. )
  355. return
  356. gm = copy.deepcopy(gm_to_save)
  357. gm.graph.set_codegen(torch.fx.graph.CodeGen()) # remove codegen
  358. gm.recompile()
  359. input_meta = get_input_meta(args)
  360. os.makedirs(f"{folder_name}/{current_name}", exist_ok=True)
  361. gm.to_folder(
  362. f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}"
  363. )
  364. with open(
  365. f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input",
  366. "wb",
  367. ) as f:
  368. pickle.dump(input_meta, f)
  369. if dump_example_input:
  370. torch.save(
  371. args,
  372. f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt", # noqa: B950
  373. ) # noqa: E501
  374. def graph_saver_forward(
  375. gm: fx.GraphModule, example_inputs: list[torch.Tensor]
  376. ) -> fx.GraphModule:
  377. graph_saver_helper(gm, example_inputs, "forward")
  378. return gm
  379. def graph_saver_backward(
  380. gm: fx.GraphModule, example_inputs: list[torch.Tensor]
  381. ) -> fx.GraphModule:
  382. graph_saver_helper(gm, example_inputs, "backward")
  383. global graph_index
  384. graph_index += 1
  385. return gm
  386. def graph_saver_joint(
  387. gm: fx.GraphModule, joint_args: list[torch.Tensor]
  388. ) -> tuple[fx.GraphModule, fx.GraphModule]:
  389. graph_saver_helper(gm, joint_args, "joint")
  390. return default_partition(gm, joint_args) # pyrefly: ignore[missing-argument]
  391. # pyrefly: ignore[bad-return]
  392. return aot_module_simplified(
  393. gm,
  394. example_inputs,
  395. fw_compiler=graph_saver_forward, # pyrefly: ignore[bad-argument-type]
  396. bw_compiler=graph_saver_backward, # pyrefly: ignore[bad-argument-type]
  397. partition_fn=graph_saver_joint,
  398. decompositions=default_decompositions, # pyrefly: ignore[bad-argument-type]
  399. )
  400. # WARNING: This isn't tested anywhere!!
  401. def graph_dumper_aot(
  402. current_name: str, folder_name: str, dump_example_input: bool = False
  403. ) -> Callable[[bool, nn.Module], Any]:
  404. """
  405. Dump the forward, backward, and joint computation graph.
  406. Example Usage:
  407. save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input = False)
  408. optimize_ctx = torchdynamo.optimize(
  409. save_fx_func
  410. )
  411. with torch.enable_grad():
  412. with optimize_ctx:
  413. result = forward_and_backward_pass(model, example_inputs)
  414. """
  415. global graph_index
  416. graph_index = 0
  417. return partial(_save_fx_default, current_name, folder_name, dump_example_input)