onnx_model_tnlr.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. from fusion_attention import AttentionMask, FusionAttention
  7. from fusion_utils import NumpyHelper
  8. from onnx import NodeProto, helper
  9. from onnx_model import OnnxModel
  10. from onnx_model_bert import BertOnnxModel
  11. logger = logging.getLogger(__name__)
  12. class FusionTnlrAttention(FusionAttention):
  13. """
  14. Fuse TNLR Attention subgraph into one Attention node.
  15. TNLR Attention has extra addition after qk nodes and adopts [S, B, NH] as I/O shape.
  16. """
  17. def __init__(
  18. self,
  19. model: OnnxModel,
  20. hidden_size: int,
  21. num_heads: int,
  22. attention_mask: AttentionMask,
  23. ):
  24. super().__init__(model, hidden_size, num_heads, attention_mask)
  25. def create_attention_node(
  26. self,
  27. mask_index: str,
  28. matmul: NodeProto,
  29. add: NodeProto,
  30. num_heads: int,
  31. hidden_size: int,
  32. input: str,
  33. output: str,
  34. add_qk_str: str,
  35. ) -> NodeProto | None:
  36. assert num_heads > 0
  37. if hidden_size > 0 and (hidden_size % num_heads) != 0:
  38. logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
  39. return None
  40. weight = self.model.get_initializer(matmul.input[1])
  41. bias = self.model.get_initializer(add.input[1]) or self.model.get_initializer(add.input[0])
  42. if weight is None or bias is None:
  43. return None
  44. qkv_weight = NumpyHelper.to_array(weight)
  45. qkv_bias = NumpyHelper.to_array(bias)
  46. attention_node_name = self.model.create_node_name("Attention")
  47. tensor_dtype = weight.data_type
  48. np_type = helper.tensor_dtype_to_np_dtype(tensor_dtype)
  49. weight = helper.make_tensor(
  50. name=attention_node_name + "_qkv_weight",
  51. data_type=tensor_dtype,
  52. dims=[hidden_size, 3 * hidden_size],
  53. vals=qkv_weight.astype(np_type).tobytes(),
  54. raw=True,
  55. )
  56. self.model.add_initializer(weight, self.this_graph_name)
  57. bias = helper.make_tensor(
  58. name=attention_node_name + "_qkv_bias",
  59. data_type=tensor_dtype,
  60. dims=[3 * hidden_size],
  61. vals=qkv_bias.astype(np_type).tobytes(),
  62. raw=True,
  63. )
  64. self.model.add_initializer(bias, self.this_graph_name)
  65. attention_inputs = [
  66. input,
  67. attention_node_name + "_qkv_weight",
  68. attention_node_name + "_qkv_bias",
  69. ]
  70. if mask_index is not None:
  71. attention_inputs.append(mask_index)
  72. else:
  73. attention_inputs.append("")
  74. if add_qk_str is not None:
  75. attention_inputs.append("")
  76. attention_inputs.append(add_qk_str)
  77. attention_node = helper.make_node(
  78. "Attention",
  79. inputs=attention_inputs,
  80. outputs=[output],
  81. name=attention_node_name,
  82. )
  83. attention_node.domain = "com.microsoft"
  84. attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
  85. return attention_node
  86. def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
  87. # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm
  88. # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern
  89. start_node = normalize_node
  90. if normalize_node.op_type != "SkipLayerNormalization":
  91. return
  92. # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
  93. qkv_nodes = self.model.match_parent_path(
  94. start_node,
  95. ["Where", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
  96. [1, 1, 1, 0, 0, 0],
  97. )
  98. if qkv_nodes is not None:
  99. (_, _, matmul_below, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
  100. else:
  101. return
  102. other_inputs = []
  103. for _i, input in enumerate(start_node.input):
  104. if input not in output_name_to_node:
  105. continue
  106. if input == qkv_nodes[0].output[0]:
  107. continue
  108. other_inputs.append(input)
  109. if len(other_inputs) != 1:
  110. return
  111. root_input = other_inputs[0]
  112. v_nodes = self.model.match_parent_path(
  113. matmul_qkv,
  114. ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
  115. [1, 0, 0, 0, 1],
  116. )
  117. if v_nodes is None:
  118. return
  119. (_, _, _, add, matmul) = v_nodes
  120. upper_nodes = self.model.match_parent_path(matmul, ["Transpose"], [0])
  121. transpose = upper_nodes[0]
  122. qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
  123. if qk_nodes is None:
  124. return
  125. (_, add_qk, matmul_qk) = qk_nodes
  126. q_nodes = self.model.match_parent_path(
  127. matmul_qk,
  128. ["Mul", "Transpose", "Reshape", "Slice", "Add", "MatMul"],
  129. [0, 0, 0, 0, 0, 1],
  130. )
  131. if q_nodes is None:
  132. return
  133. add = q_nodes[-2]
  134. matmul = q_nodes[-1]
  135. k_nodes = self.model.match_parent_path(
  136. matmul_qk,
  137. ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
  138. [1, 0, 0, 0, 1],
  139. )
  140. if k_nodes is None:
  141. return
  142. add = k_nodes[-2]
  143. matmul = k_nodes[-1]
  144. relative_position_bias_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0])
  145. if relative_position_bias_nodes is None:
  146. return
  147. if matmul.input[0] == root_input:
  148. mask_index = None
  149. attention_last_node = reshape_qkv
  150. # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
  151. # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately
  152. new_node = self.create_attention_node(
  153. mask_index,
  154. matmul,
  155. add,
  156. self.num_heads,
  157. self.hidden_size,
  158. root_input,
  159. attention_last_node.output[0],
  160. relative_position_bias_nodes[0].input[0],
  161. )
  162. if new_node is None:
  163. return
  164. self.nodes_to_add.append(new_node)
  165. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  166. # Add a transpose node after the attention node
  167. back_transpose = helper.make_node(
  168. "Transpose",
  169. ["back_transpose_in_" + new_node.name],
  170. [new_node.output[0]],
  171. "back_transpose_" + new_node.name,
  172. perm=[1, 0, 2],
  173. )
  174. self.model.add_node(back_transpose, self.this_graph_name)
  175. new_node.input[0] = transpose.input[0]
  176. new_node.output[0] = "back_transpose_in_" + new_node.name
  177. self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
  178. self.nodes_to_remove.extend(qk_nodes)
  179. self.nodes_to_remove.extend(q_nodes)
  180. self.nodes_to_remove.extend(k_nodes)
  181. self.nodes_to_remove.extend(v_nodes)
  182. # Use prune graph to remove mask nodes since they are shared by all attention nodes.
  183. # self.nodes_to_remove.extend(mask_nodes)
  184. self.prune_graph = True
  185. class TnlrOnnxModel(BertOnnxModel):
  186. def __init__(self, model, num_heads, hidden_size):
  187. super().__init__(model, num_heads, hidden_size)
  188. self.attention_mask = AttentionMask(self)
  189. self.attention_fusion = FusionTnlrAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
  190. def fuse_attention(self):
  191. self.attention_fusion.apply()