fusion_attention_clip.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_attention import AttentionMask, FusionAttention
  7. from fusion_options import AttentionMaskFormat
  8. from onnx import NodeProto
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionAttentionClip(FusionAttention):
  12. """
  13. Fuse Attention subgraph of Clip into one Attention node.
  14. """
  15. def __init__(
  16. self,
  17. model: OnnxModel,
  18. hidden_size: int,
  19. num_heads: int,
  20. ):
  21. attention_mask = AttentionMask(model)
  22. attention_mask.mask_format = AttentionMaskFormat.NoMask
  23. super().__init__(
  24. model,
  25. hidden_size,
  26. num_heads,
  27. attention_mask,
  28. use_multi_head_attention=False,
  29. search_op_types=["SkipLayerNormalization"],
  30. )
  31. def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> tuple[int, int]:
  32. """Detect num_heads and hidden_size for ONNX model from MiDaS
  33. Args:
  34. reshape_q (NodeProto): reshape node for q
  35. Returns:
  36. Tuple[int, int]: num_heads and hidden_size
  37. """
  38. concat = self.model.match_parent(reshape_q, "Concat", 1)
  39. if concat is None or len(concat.input) != 4:
  40. return self.num_heads, self.hidden_size
  41. # The shape is a tensor like [?, ?, num_heads, head_size]
  42. num_head_value = self.model.get_constant_value(concat.input[2])
  43. if num_head_value is None:
  44. return self.num_heads, self.hidden_size # Fall back to user specified value
  45. if len(num_head_value) != 1 or num_head_value[0] <= 0:
  46. return self.num_heads, self.hidden_size # Fall back to user specified value
  47. num_heads = num_head_value[0]
  48. head_size_value = self.model.get_constant_value(concat.input[3])
  49. if head_size_value is None:
  50. return self.num_heads, self.hidden_size # Fall back to user specified value
  51. if len(head_size_value) != 1 or head_size_value[0] <= 0:
  52. return self.num_heads, self.hidden_size # Fall back to user specified value
  53. head_size = head_size_value[0]
  54. hidden_size = num_heads * head_size
  55. if self.num_heads > 0 and num_heads != self.num_heads:
  56. if self.num_heads_warning:
  57. logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
  58. self.num_heads_warning = False # Do not show the warning more than once
  59. if self.hidden_size > 0 and hidden_size != self.hidden_size:
  60. if self.hidden_size_warning:
  61. logger.warning(
  62. f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
  63. )
  64. self.hidden_size_warning = False # Do not show the warning more than once
  65. return num_heads, hidden_size
  66. def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
  67. skip_input_index = None
  68. node_before_layer_norm = None
  69. for i in [1, 0]:
  70. parent = self.model.match_parent(normalize_node, "SkipLayerNormalization", i)
  71. if parent is not None:
  72. skip_input_index = i
  73. node_before_layer_norm = parent
  74. root_input = None
  75. if node_before_layer_norm is not None:
  76. root_input = node_before_layer_norm.output[0]
  77. else:
  78. # Deal with the first attention after the embedding layer.
  79. for i in [0, 1]:
  80. node_before_layer_norm = None
  81. node_before_layer_norm_1 = self.model.match_parent(normalize_node, "Add", i)
  82. node_before_layer_norm_2 = self.model.match_parent(normalize_node, "LayerNormalization", i)
  83. if node_before_layer_norm_1 is not None:
  84. # Add -----------+
  85. # | |
  86. # LayerNorm |
  87. # | |
  88. # LayerNorm |
  89. # | |
  90. # Attention subgraph |
  91. # | |
  92. # SkipLayerNorm ------+
  93. node_before_layer_norm = node_before_layer_norm_1
  94. elif node_before_layer_norm_2 is not None:
  95. # Add
  96. # |
  97. # LayerNorm --------+
  98. # | |
  99. # LayerNorm |
  100. # | |
  101. # Attention subgraph |
  102. # | |
  103. # SkipLayerNorm ------+
  104. node_before_layer_norm = node_before_layer_norm_2
  105. if node_before_layer_norm is None:
  106. continue
  107. child = self.model.find_first_child_by_type(
  108. node_before_layer_norm,
  109. "LayerNormalization",
  110. input_name_to_nodes,
  111. False,
  112. )
  113. if child is None:
  114. continue
  115. root_input = child.output[0]
  116. skip_input_index = i
  117. break
  118. if skip_input_index is None:
  119. return
  120. qkv_nodes = self.model.match_parent_path(
  121. normalize_node,
  122. ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
  123. [1 - skip_input_index, None, None, 0, 0, 0],
  124. )
  125. if qkv_nodes is None:
  126. qkv_nodes = self.model.match_parent_path(
  127. normalize_node,
  128. ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
  129. [1, None, 0, 0, 0],
  130. )
  131. if qkv_nodes is None:
  132. logger.debug("fuse_attention: failed to match qkv path")
  133. return
  134. reshape_qkv, transpose_qkv, matmul_qkv = (
  135. qkv_nodes[2],
  136. qkv_nodes[3],
  137. qkv_nodes[-1],
  138. )
  139. v_nodes = self.model.match_parent_path(
  140. matmul_qkv,
  141. ["Reshape", "Transpose", "Reshape", "Add", "MatMul"],
  142. [1, 0, 0, 0, None],
  143. )
  144. if v_nodes is None:
  145. v_nodes = self.model.match_parent_path(
  146. matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None]
  147. )
  148. if v_nodes is None:
  149. logger.debug("fuse_attention: failed to match v path")
  150. return
  151. add_v, matmul_v = v_nodes[-2], v_nodes[-1]
  152. causal_mask_input_index = None
  153. add_mask = None
  154. add_mask_indices = []
  155. qk_nodes = self.model.match_parent_path(
  156. matmul_qkv,
  157. ["Softmax", "Reshape", "Add", "Reshape", "MatMul"],
  158. [0, 0, 0, None, 0],
  159. return_indice=add_mask_indices,
  160. )
  161. if qk_nodes is None:
  162. qk_nodes = self.model.match_parent_path(
  163. matmul_qkv,
  164. ["Softmax", "MatMul"],
  165. [0, 0],
  166. )
  167. if qk_nodes is None:
  168. qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
  169. if qk_nodes is not None:
  170. add_mask = qk_nodes[1]
  171. else:
  172. # If attention mask is not used, we can still match the qk path.
  173. qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
  174. if qk_nodes is None:
  175. # Cast nodes are added in the model for fp16.
  176. qk_nodes = self.model.match_parent_path(
  177. matmul_qkv,
  178. ["Cast", "Cast", "Softmax", "Add", "Mul", "MatMul"],
  179. [0, 0, 0, 0, 0, 0],
  180. )
  181. if qk_nodes is not None:
  182. add_mask = qk_nodes[3]
  183. else:
  184. # If attention mask is not used, we can still match the qk path.
  185. qk_nodes = self.model.match_parent_path(
  186. matmul_qkv,
  187. ["Cast", "Cast", "Softmax", "Mul", "MatMul"],
  188. [0, 0, 0, 0, 0],
  189. )
  190. if qk_nodes is None:
  191. logger.debug("fuse_attention: failed to match qk path")
  192. return
  193. else:
  194. assert len(add_mask_indices) == 1
  195. causal_mask_input_index = 1 - add_mask_indices[0]
  196. add_mask = qk_nodes[2]
  197. matmul_qk = qk_nodes[-1]
  198. q_nodes = self.model.match_parent_path(
  199. matmul_qk,
  200. ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"],
  201. [0, 0, 0, 0, None, None],
  202. )
  203. if q_nodes is None:
  204. q_nodes = self.model.match_parent_path(
  205. matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None]
  206. )
  207. if q_nodes is None:
  208. logger.debug("fuse_attention: failed to match q path")
  209. return
  210. reshape_q = q_nodes[1]
  211. else:
  212. reshape_q = q_nodes[2]
  213. add_q, matmul_q = q_nodes[-2], q_nodes[-1]
  214. k_nodes = self.model.match_parent_path(
  215. matmul_qk,
  216. ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"],
  217. [1, 0, 0, 0, 0, None],
  218. )
  219. if k_nodes is None:
  220. k_nodes = self.model.match_parent_path(
  221. matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None]
  222. )
  223. if k_nodes is None:
  224. logger.debug("fuse_attention: failed to match k path")
  225. return
  226. add_k, matmul_k = k_nodes[-2], k_nodes[-1]
  227. if matmul_q.input[0] != root_input or matmul_k.input[0] != root_input or matmul_v.input[0] != root_input:
  228. logger.debug("fuse_attention: expect to have same input to q, k and v matmul")
  229. return
  230. num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
  231. if num_heads <= 0 or hidden_size <= 0:
  232. logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
  233. return
  234. attention_last_node = reshape_qkv
  235. add_qk = ""
  236. causal_mask_nodes_1 = None
  237. causal_mask_nodes_2 = None
  238. if add_mask is not None:
  239. if add_mask.input[1] == "attention_mask":
  240. add_qk = add_mask.input[1]
  241. else:
  242. # 4D Add after Q x K'
  243. add_qk_nodes = self.model.match_parent_path(
  244. add_mask,
  245. [
  246. "Where",
  247. "Sub",
  248. "Cast",
  249. "Expand",
  250. "Unsqueeze",
  251. "Unsqueeze",
  252. "Reshape",
  253. "Reshape",
  254. "Cast",
  255. ],
  256. [1, 2, 1, 0, 0, 0, 0, 0, 0],
  257. )
  258. if add_qk_nodes is not None:
  259. add_qk = add_mask.input[1]
  260. else:
  261. # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path
  262. # of computing causal mask.
  263. causal_mask_nodes_1 = self.model.match_parent_path(
  264. add_mask,
  265. ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
  266. [causal_mask_input_index, 0, 0, 0, 0, 0],
  267. )
  268. # If the model is exported with batch_size == 1, there is no Concat node
  269. causal_mask_nodes_2 = self.model.match_parent_path(
  270. add_mask,
  271. ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
  272. [causal_mask_input_index, 0, 0, 0, 0],
  273. )
  274. if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None:
  275. logger.debug("fuse_attention: failed to match causal mask subgraph")
  276. return
  277. new_node = self.create_attention_node(
  278. mask_index=None,
  279. q_matmul=matmul_q,
  280. k_matmul=matmul_k,
  281. v_matmul=matmul_v,
  282. q_add=add_q,
  283. k_add=add_k,
  284. v_add=add_v,
  285. num_heads=num_heads,
  286. hidden_size=hidden_size,
  287. first_input=root_input,
  288. output=attention_last_node.output[0],
  289. add_qk_str=add_qk,
  290. scale=None,
  291. causal=(causal_mask_nodes_1 is not None) or (causal_mask_nodes_2 is not None),
  292. )
  293. if new_node is None:
  294. logger.debug("fuse_attention: failed to create fused node")
  295. return
  296. self.nodes_to_add.append(new_node)
  297. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  298. self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
  299. # Use prune graph to remove nodes since they are shared by all attention nodes.
  300. self.prune_graph = True