merge_matmul.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # mypy: allow-untyped-defs
  2. import itertools
  3. import operator
  4. import torch
  5. from torch.fx._symbolic_trace import symbolic_trace
  6. from torch.fx.node import Node
  7. from torch.fx.passes.tools_common import legalize_graph
  8. def split_result_tensors(
  9. result: torch.Tensor, inputs: list[torch.Tensor]
  10. ) -> tuple[torch.Tensor, ...]:
  11. """
  12. A free function for use in the merge_matmul graph transformation below that
  13. splits the output from a merged matmul into the individual results for each
  14. input tensor.
  15. Arguments:
  16. result: The merged matmul result tensor.
  17. inputs: The list of inputs that were merged into one for the matmul.
  18. Returns:
  19. List of matmul results for each input tensor.
  20. """
  21. # When fx tracer is running, x.shape[0] will be torch.fx.Attribute but we
  22. # need an int even when tracing
  23. if isinstance(result, torch.fx.Proxy):
  24. splits = [0] * len(inputs)
  25. else:
  26. splits = [x.shape[0] for x in inputs]
  27. return torch.split(result, splits)
  28. def may_depend_on(a: Node, b: Node, search_depth: int = 6):
  29. """
  30. Determine if one node depends on another in a torch.fx.Graph.
  31. Arguments:
  32. a: The node that may have a dependency on b.
  33. b: The node that a may have a dependency on.
  34. search_depth: In the case of an indirect dependency, this function
  35. searches upto this many nodes away in search of a
  36. data dependency. If none is found, the function
  37. makes the conservative assumption that there is a
  38. dependency.
  39. Returns:
  40. True if a may depend on b, False if it definitely does not.
  41. """
  42. # Equivalence is defined as dependence.
  43. if a == b:
  44. return True
  45. # If a has no inputs, it cannot depend on b.
  46. if len(a.all_input_nodes) == 0:
  47. return False
  48. # If the search depth has been exhausted and no conclusion has been
  49. # reached, assume that there is a data dependency.
  50. if search_depth == 0:
  51. return True
  52. # Recursively check all inputs of a.
  53. for inp in a.all_input_nodes:
  54. if may_depend_on(inp, b, search_depth - 1):
  55. return True
  56. return False
  57. def are_nodes_independent(nodes: list[Node]):
  58. """
  59. Check if all of the given nodes are pairwise-data independent.
  60. Arguments:
  61. nodes: The nodes to check for data dependencies.
  62. Returns:
  63. True if any pair in nodes has a data dependency.
  64. """
  65. # For each pair in nodes:
  66. for i, j in itertools.combinations(nodes, 2):
  67. if may_depend_on(i, j) or may_depend_on(j, i):
  68. return False
  69. return True
  70. def merge_matmul(in_mod: torch.nn.Module):
  71. """
  72. A graph transformation that merges matrix multiplication operations that share the same right-hand
  73. side operand into one large matrix multiplication.
  74. ____ _________ _________
  75. ---- | | | | M| A * C |
  76. M| A | T| B | * K| C | = |---------|
  77. ---- , | | | | T| B * C |
  78. K ---- --------- ---------
  79. K R R
  80. """
  81. gm = symbolic_trace(in_mod)
  82. rhs_users: dict[Node, list[Node]] = {}
  83. lhs_users: dict[Node, list[Node]] = {}
  84. # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
  85. # the matmul of which they are the LHS/RHS.
  86. for node in gm.graph.nodes:
  87. if node.op != "call_function" or node.target is not torch.matmul:
  88. continue
  89. lhs, rhs = node.args
  90. # TODO: Properly handle aliasing caused by get_attr. For now,
  91. # use the attribute name as the operand if the node is a
  92. # get_attr.
  93. lhs = lhs.target if lhs.op == "get_attr" else lhs
  94. rhs = rhs.target if rhs.op == "get_attr" else rhs
  95. lhs_users.setdefault(lhs, []).append(node)
  96. rhs_users.setdefault(rhs, []).append(node)
  97. for rhs, mms in rhs_users.items():
  98. # There must be at least matmuls for a merge to make sense.
  99. if len(mms) < 2:
  100. continue
  101. # All matmuls must not depend on each other directly or indirectly
  102. # in order for the merge to be possible.
  103. if not are_nodes_independent(mms):
  104. continue
  105. lhs_vals = [mm.args[0] for mm in mms]
  106. # Merge the matmul.
  107. # Collect a list of LHS operands and the single RHS operand.
  108. lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
  109. rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs
  110. # Concatenate all the LHS operands.
  111. merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})
  112. # Multiply the concatenated LHS operands with the one RHS. This will produce
  113. # the same results as all the individual matmuls involving rhs in the original graph,
  114. # but they will all be concatenated together.
  115. merge_mm = gm.graph.call_function(
  116. torch.matmul,
  117. (
  118. merge_mm_cat,
  119. rhs,
  120. ),
  121. {},
  122. )
  123. # Split the result of the merged matmul using the shapes of the LHS operands
  124. # to ascertain how large each chunk should be.
  125. merge_mm_split = gm.graph.call_function(
  126. split_result_tensors, (merge_mm, lhs), {}
  127. )
  128. merge_mm_res = [
  129. gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
  130. for out in range(len(lhs))
  131. ]
  132. # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul.
  133. for old, new in zip(mms, merge_mm_res):
  134. old.replace_all_uses_with(new)
  135. gm.graph.erase_node(old)
  136. # All of the new nodes created above were inserted at the end, so we need to sort
  137. # the nodes topologically to make sure all definitions precede uses.
  138. legalize_graph(gm)
  139. gm.recompile()
  140. gm.graph.lint()
  141. return gm