fuse.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. from collections.abc import Callable
  4. from typing import Any
  5. from torch.ao.quantization.backend_config import (
  6. BackendConfig,
  7. get_native_backend_config,
  8. )
  9. from torch.ao.quantization.backend_config.utils import (
  10. get_fuser_method_mapping,
  11. get_fusion_pattern_to_extra_inputs_getter,
  12. get_fusion_pattern_to_root_node_getter,
  13. )
  14. from torch.ao.quantization.utils import NodePattern, Pattern
  15. from torch.fx import GraphModule, map_arg, Node
  16. from torch.fx.graph import Graph
  17. from .custom_config import FuseCustomConfig
  18. from .fuse_handler import _get_fusion_pattern_to_fuse_handler_cls, FuseHandler
  19. from .match_utils import _is_match, MatchAllNode
  20. from .pattern_utils import _sorted_patterns_dict
# Names exported by ``from <this module> import *``.
__all__ = [
    "fuse",
    # TODO: We should make this private in the future
    # This is currently needed for test_public_bindings for some reason
    "FuseHandler",
]
  27. def fuse(
  28. model: GraphModule,
  29. is_qat: bool,
  30. fuse_custom_config: FuseCustomConfig | dict[str, Any] | None = None,
  31. backend_config: BackendConfig | dict[str, Any] | None = None,
  32. ) -> GraphModule:
  33. if fuse_custom_config is None:
  34. fuse_custom_config = FuseCustomConfig()
  35. if isinstance(fuse_custom_config, dict):
  36. warnings.warn(
  37. "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
  38. "in a future version. Please pass in a FuseCustomConfig instead.",
  39. FutureWarning,
  40. stacklevel=2,
  41. )
  42. fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)
  43. if isinstance(backend_config, dict):
  44. warnings.warn(
  45. "Passing a backend_config_dict to prepare is deprecated and will not be supported "
  46. "in a future version. Please pass in a BackendConfig instead.",
  47. FutureWarning,
  48. stacklevel=2,
  49. )
  50. backend_config = BackendConfig.from_dict(backend_config)
  51. named_modules = dict(model.named_modules())
  52. if backend_config is None:
  53. backend_config = get_native_backend_config()
  54. fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict(
  55. _get_fusion_pattern_to_fuse_handler_cls(backend_config)
  56. )
  57. fuser_method_mapping = get_fuser_method_mapping(backend_config)
  58. fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(
  59. backend_config
  60. )
  61. fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(
  62. backend_config
  63. )
  64. # find fusion
  65. fusion_pairs = _find_matches(model, model.graph, fusion_pattern_to_fuse_handler_cls)
  66. # TODO: change this to inplace changes to graph, since we no longer construct
  67. # new GraphModule anymore
  68. fused_graph = Graph()
  69. env: dict[Any, Any] = {}
  70. def load_arg(a):
  71. return map_arg(a, lambda node: env[node.name])
  72. def default_root_node_getter(node_pattern):
  73. while not isinstance(node_pattern[-1], Node):
  74. node_pattern = node_pattern[-1]
  75. return node_pattern[-1]
  76. for node in model.graph.nodes:
  77. (
  78. maybe_last_node,
  79. pattern,
  80. matched_node_pattern,
  81. obj,
  82. node_to_subpattern,
  83. ) = fusion_pairs.get(node.name, (None, None, None, None, None))
  84. # get the corresponding subpattern for the current node
  85. if node_to_subpattern is not None:
  86. node_subpattern = node_to_subpattern.get(node, None)
  87. else:
  88. node_subpattern = None
  89. if maybe_last_node is node:
  90. if obj is None:
  91. raise AssertionError(
  92. "fuse handler object must not be None for matched root node"
  93. )
  94. root_node_getter = fusion_pattern_to_root_node_getter.get(
  95. pattern, default_root_node_getter
  96. )
  97. root_node = root_node_getter(matched_node_pattern) # type: ignore[index]
  98. extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(
  99. pattern, None
  100. )
  101. extra_inputs = []
  102. if extra_inputs_getter is not None:
  103. extra_inputs = extra_inputs_getter(matched_node_pattern)
  104. # TODO: add validation that root_node is a module and has the same type
  105. # as the root_module in the configuration
  106. env[node.name] = obj.fuse(
  107. load_arg,
  108. named_modules,
  109. fused_graph,
  110. root_node,
  111. extra_inputs,
  112. matched_node_pattern, # type: ignore[arg-type]
  113. fuse_custom_config,
  114. fuser_method_mapping,
  115. is_qat,
  116. )
  117. elif maybe_last_node is None or node_subpattern is MatchAllNode:
  118. env[node.name] = fused_graph.node_copy(node, load_arg)
  119. # node matched in patterns and is not root is removed here
  120. model = GraphModule(model, fused_graph)
  121. return model
  122. def _find_matches(
  123. root: GraphModule,
  124. graph: Graph,
  125. pattern_to_fuse_handler_cls: dict[Pattern, Callable],
  126. ) -> dict[str, tuple[Node, Pattern, NodePattern, FuseHandler, dict[Node, Any]]]:
  127. modules = dict(root.named_modules())
  128. # node name -> (root_node, match_value)
  129. match_map: dict[
  130. str, tuple[Node, Pattern, NodePattern, FuseHandler, dict[Node, Any]]
  131. ] = {}
  132. # a map from node to the matched subpattern
  133. node_to_subpattern: dict[Node, Any] = {}
  134. # TODO: dedup with quantization matching function in match_utils.py
  135. def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
  136. if isinstance(pattern, tuple):
  137. s, *args = pattern
  138. current_node_pattern: list[Node] = []
  139. apply_match(s, node, match, current_node_pattern, node_to_subpattern)
  140. for subpattern, arg in zip(args, node.args):
  141. apply_match(
  142. subpattern, arg, match, current_node_pattern, node_to_subpattern
  143. )
  144. matched_node_pattern.append(tuple(current_node_pattern))
  145. else:
  146. # the first pattern matches will take precedence
  147. if node.name not in match_map:
  148. matched_node_pattern.append(node)
  149. # MatchAllNode here is actually MatchAllInputNode which should not
  150. # be added to match_map
  151. if pattern is not MatchAllNode:
  152. node_to_subpattern[node] = pattern
  153. root_node, pattern, handler = match
  154. match_map[node.name] = (
  155. root_node,
  156. pattern,
  157. matched_node_pattern,
  158. handler,
  159. node_to_subpattern,
  160. )
  161. for node in reversed(graph.nodes):
  162. if node.name not in match_map:
  163. for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
  164. matched_node_pattern: list[Node] = []
  165. if _is_match(modules, node, pattern):
  166. apply_match(
  167. pattern,
  168. node,
  169. (node, pattern, fuse_handler_cls(node)),
  170. matched_node_pattern,
  171. node_to_subpattern,
  172. )
  173. break
  174. return match_map