_fuser.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import warnings
  4. import torch
  5. @contextlib.contextmanager
  6. def optimized_execution(should_optimize):
  7. """Context manager that controls whether the JIT's executor will run optimizations before executing a function."""
  8. stored_flag = torch._C._get_graph_executor_optimize()
  9. torch._C._set_graph_executor_optimize(should_optimize)
  10. try:
  11. yield
  12. finally:
  13. torch._C._set_graph_executor_optimize(stored_flag)
  14. @contextlib.contextmanager
  15. def fuser(name):
  16. """Context manager that facilitates switching between backend fusers.
  17. Valid names:
  18. * ``fuser0`` - enables only legacy fuser
  19. * ``fuser1`` - enables only NNC
  20. * ``fuser2`` - enables only nvFuser
  21. * ``fuser3`` - enables oneDNN Graph
  22. """
  23. old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
  24. old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
  25. old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
  26. old_nvfuser_state = torch._C._jit_nvfuser_enabled()
  27. old_llga_state = torch._C._jit_llga_enabled()
  28. if name == "fuser0": # legacy fuser
  29. torch._C._jit_override_can_fuse_on_cpu(True)
  30. torch._C._jit_override_can_fuse_on_gpu(True)
  31. torch._C._jit_set_texpr_fuser_enabled(False)
  32. torch._C._jit_set_nvfuser_enabled(False)
  33. torch._C._jit_set_llga_enabled(False)
  34. elif name == "fuser1": # NNC
  35. old_profiling_executor = torch._C._jit_set_profiling_executor(True)
  36. old_profiling_mode = torch._C._get_graph_executor_optimize(True)
  37. torch._C._jit_override_can_fuse_on_cpu(True)
  38. torch._C._jit_override_can_fuse_on_gpu(True)
  39. torch._C._jit_set_texpr_fuser_enabled(True)
  40. torch._C._jit_set_nvfuser_enabled(False)
  41. torch._C._jit_set_llga_enabled(False)
  42. elif name == "fuser2": # nvFuser
  43. torch._C._jit_override_can_fuse_on_cpu(False)
  44. torch._C._jit_override_can_fuse_on_gpu(False)
  45. torch._C._jit_set_texpr_fuser_enabled(False)
  46. torch._C._jit_set_nvfuser_enabled(True)
  47. torch._C._jit_set_llga_enabled(False)
  48. elif name == "fuser3": # oneDNN Graph
  49. old_profiling_executor = torch._C._jit_set_profiling_executor(True)
  50. old_profiling_mode = torch._C._get_graph_executor_optimize(True)
  51. torch._C._jit_override_can_fuse_on_cpu(True)
  52. torch._C._jit_override_can_fuse_on_gpu(False)
  53. torch._C._jit_set_texpr_fuser_enabled(True)
  54. torch._C._jit_set_nvfuser_enabled(False)
  55. torch._C._jit_set_llga_enabled(True)
  56. elif name == "none": # Turn Pytorch fuser off
  57. torch._C._jit_override_can_fuse_on_cpu(False)
  58. torch._C._jit_override_can_fuse_on_gpu(False)
  59. torch._C._jit_set_texpr_fuser_enabled(False)
  60. torch._C._jit_set_nvfuser_enabled(False)
  61. torch._C._jit_set_llga_enabled(False)
  62. else:
  63. raise Exception(f"unrecognized fuser option (name: {name})") # noqa: TRY002
  64. try:
  65. yield
  66. finally:
  67. if name in ["fuser1", "fuser3"]: # NNC or oneDNN Graph
  68. torch._C._jit_set_profiling_executor(old_profiling_executor) # type: ignore[possibly-undefined]
  69. torch._C._get_graph_executor_optimize(old_profiling_mode) # type: ignore[possibly-undefined]
  70. # recover the previous values
  71. torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
  72. torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
  73. torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
  74. torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
  75. torch._C._jit_set_llga_enabled(old_llga_state)
# Module-level alias for `torch._C._last_executed_optimized_graph`. It is
# called by the fallback path of `_script_method_graph_for`, which runs the
# callable and then returns what this binding reports as the recorded
# optimized graph.
last_executed_optimized_graph = torch._C._last_executed_optimized_graph
  77. def _get_differentiable_graph_node(node, diff_node) -> None:
  78. if node.kind() == "prim::DifferentiableGraph":
  79. diff_node.append(node)
  80. else:
  81. for block in node.blocks():
  82. for n in block.nodes():
  83. _get_differentiable_graph_node(n, diff_node)
  84. def _graph_for(self, *args, **kwargs):
  85. return _script_method_graph_for(self, self, *args, **kwargs)
  86. def _script_method_graph_for(self, parent, *args, **kwargs):
  87. try:
  88. dbs = parent.get_debug_state()
  89. eps = list(dbs.execution_plans.values())
  90. if len(eps) != 1:
  91. raise AssertionError(f"Expected exactly 1 execution plan, got {len(eps)}")
  92. graph = eps[0].graph.copy()
  93. # graph_executor_states for differentiable node
  94. fw_states = eps[0].code.differentiable_op_executor_states()
  95. diff_nodes: list[torch._C.Node] = []
  96. for n in graph.nodes():
  97. _get_differentiable_graph_node(n, diff_nodes)
  98. if len(fw_states) != len(diff_nodes):
  99. raise AssertionError(
  100. f"Expected fw_states ({len(fw_states)}) and diff_nodes "
  101. f"({len(diff_nodes)}) to have the same length"
  102. )
  103. # swap each differentiable graph with optimized graph in their execution plan
  104. for n, state in zip(diff_nodes, fw_states):
  105. fw_execution_plans = list(state.execution_plans.values())
  106. # we can only update the subgraph when there's a unique execution
  107. # plan. Avoid assert here so we would skip the ones that can't be
  108. # updated while try the best effort to update other nodes.
  109. if len(fw_execution_plans) == 1:
  110. n.g_("Subgraph", fw_execution_plans[0].graph)
  111. return graph
  112. except Exception:
  113. # fallback approach, we just ran the graph and return the recorded optimized
  114. # graph
  115. self(*args, **kwargs)
  116. return last_executed_optimized_graph()
  117. def set_fusion_strategy(strategy: list[tuple[str, int]]):
  118. """Set the type and number of specializations that can occur during fusion.
  119. .. deprecated:: 2.5
  120. TorchScript is deprecated, please use ``torch.compile`` instead.
  121. Usage: provide a list of pairs (type, depth) where type is one of "STATIC" or "DYNAMIC"
  122. and depth is an integer.
  123. Behavior - static vs dynamic:
  124. In STATIC fusion, fused ops are compiled to have fixed input shapes. The shape is determined
  125. based on some initial profiling runs.
  126. In DYNAMIC fusion, fused ops are compiled to have variable input shapes, so that multiple
  127. shapes are possible.
  128. In both cases, we also recompile on new striding behavior, device, or dtype.
  129. Behavior - fallback functions & depth:
  130. When an input doesn't match the format required by the specialized compiled op, it will run
  131. a fallback function. Fallback functions are recursively be compiled and specialized based
  132. on the observed tensor shapes. Since compilation can be slow, the "depth" parameter is provided to
  133. limit the number of specializations that can be compiled, before giving up on recompiling and
  134. falling back to a completely un-fused, un-specialized implementation.
  135. The list of (type, depth) pairs controls the type of specializations and the number of
  136. specializations. For example: [("STATIC", 2), ("DYNAMIC", 2)] indicates that the first
  137. two specializations will use static fusions, the following two specializations will use
  138. dynamic fusion, and any inputs that satisfy none of the 4 options will run an
  139. unfused implementation.
  140. NB: in the future, if more as more fusion backends are added there may be more granular
  141. apis for specific fusers.
  142. """
  143. warnings.warn(
  144. "`torch.jit.set_fusion_strategy` is deprecated. Please use `torch.compile` instead.",
  145. DeprecationWarning,
  146. )
  147. return torch._C._jit_set_fusion_strategy(strategy)