# fusion_attention_sam2.py
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger

import numpy as np
from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import NodeProto, helper, numpy_helper
from onnx_model import OnnxModel

# Module-level logger used by the fusion class below for debug and warning messages.
logger = getLogger(__name__)
  12. class FusionMultiHeadAttentionSam2(Fusion):
  13. """
  14. Fuse MultiHeadAttention subgraph of Segment Anything v2 (SAM2).
  15. """
  16. def __init__(
  17. self,
  18. model: OnnxModel,
  19. hidden_size: int,
  20. num_heads: int,
  21. ):
  22. super().__init__(model, "MultiHeadAttention", ["LayerNormalization"])
  23. self.hidden_size = hidden_size
  24. self.num_heads = num_heads
  25. # Flags to show warning only once
  26. self.num_heads_warning = True
  27. self.hidden_size_warning = True
  28. def get_decoder_num_heads(self, reshape_q: NodeProto) -> int:
  29. """Detect num_heads from a reshape node.
  30. Args:
  31. reshape_q (NodeProto): reshape node for Q
  32. Returns:
  33. int: num_heads, or 0 if not found
  34. """
  35. num_heads = 0
  36. # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
  37. shape_value = self.model.get_constant_value(reshape_q.input[1])
  38. if shape_value is not None:
  39. if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [4]:
  40. num_heads = int(shape_value[2])
  41. if isinstance(num_heads, int) and num_heads > 0:
  42. return num_heads
  43. return 0
  44. def get_encoder_num_heads(self, reshape_in: NodeProto) -> int:
  45. """Detect num_heads from a reshape node.
  46. Args:
  47. reshape_q (NodeProto): reshape node for Q
  48. Returns:
  49. int: num_heads, or 0 if not found
  50. """
  51. num_heads = 0
  52. shape_value = self.model.get_constant_value(reshape_in.input[1])
  53. if shape_value is not None:
  54. if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [5]:
  55. num_heads = int(shape_value[3])
  56. else:
  57. concat_shape = self.model.match_parent(reshape_in, "Concat", 1)
  58. if concat_shape is not None and len(concat_shape.input) == 5:
  59. # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
  60. shape_value = self.model.get_constant_value(concat_shape.input[3])
  61. if shape_value is not None:
  62. if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [1]:
  63. num_heads = int(shape_value[0])
  64. if isinstance(num_heads, int) and num_heads > 0:
  65. return num_heads
  66. return 0
  67. def get_hidden_size(self, layernorm_node):
  68. """Detect hidden_size from LayerNormalization node.
  69. Args:
  70. layernorm_node (NodeProto): LayerNormalization node before Q, K and V
  71. Returns:
  72. int: hidden_size, or 0 if not found
  73. """
  74. layernorm_bias = self.model.get_initializer(layernorm_node.input[2])
  75. if layernorm_bias:
  76. return NumpyHelper.to_array(layernorm_bias).shape[0]
  77. return 0
  78. def get_num_heads_and_hidden_size(
  79. self, reshape_q: NodeProto, layernorm_node: NodeProto, is_encoder: bool = False
  80. ) -> tuple[int, int]:
  81. """Detect num_heads and hidden_size.
  82. Args:
  83. reshape_q (NodeProto): reshape node for Q
  84. layernorm_node (NodeProto): LayerNormalization node before Q, K, V
  85. Returns:
  86. Tuple[int, int]: num_heads and hidden_size
  87. """
  88. if is_encoder:
  89. num_heads = self.get_encoder_num_heads(reshape_q)
  90. else:
  91. num_heads = self.get_decoder_num_heads(reshape_q)
  92. if num_heads <= 0:
  93. num_heads = self.num_heads # Fall back to user specified value
  94. if self.num_heads > 0 and num_heads != self.num_heads:
  95. if self.num_heads_warning:
  96. logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
  97. self.num_heads_warning = False # Do not show the warning more than once
  98. hidden_size = self.get_hidden_size(layernorm_node)
  99. if hidden_size <= 0:
  100. hidden_size = self.hidden_size # Fall back to user specified value
  101. if self.hidden_size > 0 and hidden_size != self.hidden_size:
  102. if self.hidden_size_warning:
  103. logger.warning(
  104. f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
  105. )
  106. self.hidden_size_warning = False # Do not show the warning more than once
  107. return num_heads, hidden_size
  108. def create_attention_node(
  109. self,
  110. q_matmul: NodeProto,
  111. q_add: NodeProto,
  112. k_matmul: NodeProto,
  113. k_add: NodeProto,
  114. v_matmul: NodeProto,
  115. v_add: NodeProto,
  116. num_heads: int,
  117. hidden_size: int,
  118. output: str,
  119. ) -> NodeProto | None:
  120. """Create an Attention node.
  121. Args:
  122. q_matmul (NodeProto): MatMul node in fully connection for Q
  123. q_add (NodeProto): Add bias node in fully connection for Q
  124. k_matmul (NodeProto): MatMul node in fully connection for K
  125. k_add (NodeProto): Add bias node in fully connection for K
  126. v_matmul (NodeProto): MatMul node in fully connection for V
  127. v_add (NodeProto): Add bias node in fully connection for V
  128. num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
  129. hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
  130. output (str): output name
  131. Returns:
  132. Union[NodeProto, None]: the node created or None if failed.
  133. """
  134. if hidden_size > 0 and (hidden_size % num_heads) != 0:
  135. logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
  136. return None
  137. q_weight = self.model.get_initializer(q_matmul.input[1])
  138. k_weight = self.model.get_initializer(k_matmul.input[1])
  139. v_weight = self.model.get_initializer(v_matmul.input[1])
  140. if not (q_weight and k_weight and v_weight):
  141. return None
  142. qw = NumpyHelper.to_array(q_weight)
  143. kw = NumpyHelper.to_array(k_weight)
  144. vw = NumpyHelper.to_array(v_weight)
  145. logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")
  146. attention_node_name = self.model.create_node_name("MultiHeadAttention")
  147. attention_inputs = [
  148. q_add.output[0],
  149. k_add.output[0],
  150. v_add.output[0],
  151. ]
  152. attention_node = helper.make_node(
  153. "MultiHeadAttention",
  154. inputs=attention_inputs,
  155. outputs=[output],
  156. name=attention_node_name,
  157. )
  158. attention_node.domain = "com.microsoft"
  159. attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
  160. counter_name = "MultiHeadAttention ({})".format("cross attention")
  161. self.increase_counter(counter_name)
  162. return attention_node
  163. def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
  164. if self.fuse_sam_encoder_pattern(normalize_node, input_name_to_nodes, output_name_to_node):
  165. return
  166. match_qkv = self.match_attention_subgraph(normalize_node)
  167. if match_qkv is None:
  168. if normalize_node.input[0] not in output_name_to_node:
  169. return
  170. skip_add = output_name_to_node[normalize_node.input[0]]
  171. if skip_add.op_type != "Add":
  172. return
  173. match_qkv = self.match_attention_subgraph(skip_add)
  174. if match_qkv is None:
  175. return
  176. reshape_qkv, transpose_qkv, reshape_q, matmul_q, add_q, matmul_k, add_k, matmul_v, add_v = match_qkv
  177. attention_last_node = reshape_qkv
  178. q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, False)
  179. if q_num_heads <= 0:
  180. logger.debug("fuse_attention: failed to detect num_heads")
  181. return
  182. # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
  183. new_node = self.create_attention_node(
  184. matmul_q,
  185. add_q,
  186. matmul_k,
  187. add_k,
  188. matmul_v,
  189. add_v,
  190. q_num_heads,
  191. q_hidden_size,
  192. output=attention_last_node.output[0],
  193. )
  194. if new_node is None:
  195. return
  196. self.nodes_to_add.append(new_node)
  197. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  198. self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
  199. # Use prune graph to remove nodes since they are shared by all attention nodes.
  200. self.prune_graph = True
  201. def match_attention_subgraph(self, node_after_output_projection):
  202. """Match Q, K and V paths exported by PyTorch 2.*"""
  203. qkv_nodes = self.model.match_parent_path(
  204. node_after_output_projection,
  205. ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
  206. [None, None, None, 0, 0],
  207. )
  208. if qkv_nodes is None:
  209. return None
  210. (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
  211. v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
  212. if v_nodes is None:
  213. logger.debug("fuse_attention: failed to match v path")
  214. return None
  215. (_, _, add_v, matmul_v) = v_nodes
  216. qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
  217. if qk_nodes is not None:
  218. (_softmax_qk, matmul_qk) = qk_nodes
  219. else:
  220. logger.debug("fuse_attention: failed to match qk path")
  221. return None
  222. q_nodes = self.model.match_parent_path(
  223. matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [0, None, 0, 0, None]
  224. )
  225. if q_nodes is None:
  226. logger.debug("fuse_attention: failed to match q path")
  227. return None
  228. (mul_q, _transpose_q, reshape_q, add_q, matmul_q) = q_nodes
  229. k_nodes = self.model.match_parent_path(
  230. matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [1, None, 0, 0, None]
  231. )
  232. if k_nodes is None:
  233. logger.debug("fuse_attention: failed to match k path")
  234. return None
  235. (_mul_k, _, _, add_k, matmul_k) = k_nodes
  236. # The scalar for Q and K is sqrt(1.0/sqrt(head_size)).
  237. mul_q_nodes = self.model.match_parent_path(
  238. mul_q,
  239. ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
  240. [None, 0, 1, 0, 0, 0, 0, 0],
  241. )
  242. if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
  243. logger.debug("fuse_attention: failed to match mul_q path")
  244. return None
  245. return reshape_qkv, transpose_qkv, reshape_q, matmul_q, add_q, matmul_k, add_k, matmul_v, add_v
  246. # --------------------------------------------------------
  247. # The following are for SAM encoder
  248. # --------------------------------------------------------
  249. def fuse_sam_encoder_pattern(self, normalize_node, input_name_to_nodes, output_name_to_node) -> bool:
  250. # SAM encoder attention layer pattern:
  251. # Add -----------+
  252. # | |
  253. # LayerNorm |
  254. # | |
  255. # Reshape |
  256. # | |
  257. # Transpose |
  258. # | |
  259. # MatMul |
  260. # | |
  261. # Add |
  262. # | |
  263. # Reshape |
  264. # | |
  265. # Split |
  266. # | |
  267. # Self Attention subgraph |
  268. # | |
  269. # Reshape |
  270. # | |
  271. # Transpose |
  272. # | |
  273. # Reshape |
  274. # | |
  275. # Add ----------+
  276. # |
  277. # LayerNorm (starts from here)
  278. nodes = self.model.match_parent_path(
  279. normalize_node,
  280. ["Add", "Reshape", "Transpose", "Reshape"],
  281. [0, None, 0, 0],
  282. )
  283. if nodes is None:
  284. nodes = self.model.match_parent_path(
  285. normalize_node,
  286. ["Add", "Slice", "Slice", "Reshape", "Transpose", "Reshape"],
  287. [0, None, 0, 0, 0, 0],
  288. )
  289. if nodes is None:
  290. nodes = self.model.match_parent_path(
  291. normalize_node,
  292. ["Add"],
  293. [0],
  294. )
  295. if nodes is None:
  296. return False
  297. node_after_output_projection = nodes[-1]
  298. matched_sdpa = self.match_sam_encoder_attention_subgraph(
  299. node_after_output_projection, input_index=1 if len(nodes) == 1 else None
  300. )
  301. if matched_sdpa is None:
  302. return False
  303. reshape_out, transpose_out, split_qkv, transpose_q, transpose_k, transpose_v = matched_sdpa
  304. # B, S, N, H => B, N, S, H
  305. permutation_q = OnnxModel.get_node_attribute(transpose_q, "perm")
  306. if (not isinstance(permutation_q, list)) or permutation_q != [0, 2, 1, 3]:
  307. return False
  308. # B, S, N, H => B, N, H, S
  309. permutation_k = OnnxModel.get_node_attribute(transpose_k, "perm")
  310. if (not isinstance(permutation_k, list)) or permutation_k != [0, 2, 3, 1]:
  311. return False
  312. # B, S, N, H => B, N, S, H
  313. permutation_v = OnnxModel.get_node_attribute(transpose_v, "perm")
  314. if (not isinstance(permutation_v, list)) or permutation_v != [0, 2, 1, 3]:
  315. return False
  316. input_projection_nodes = self.model.match_parent_path(
  317. split_qkv,
  318. ["Reshape", "Add", "MatMul"],
  319. [0, 0, None],
  320. )
  321. if input_projection_nodes is None:
  322. return False
  323. reshape_in, add_in, matmul_in = input_projection_nodes
  324. q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_in, normalize_node, True)
  325. if q_num_heads <= 0:
  326. logger.debug("fuse_attention: failed to detect num_heads")
  327. return False
  328. # Add a shape to convert 4D BxSxNxH to 3D BxSxD, which is required by MHA operator.
  329. new_dims_name = "bsnh_to_bsd_reshape_dims"
  330. new_dims = self.model.get_initializer(new_dims_name)
  331. if new_dims is None:
  332. new_dims = numpy_helper.from_array(np.array([0, 0, -1], dtype="int64"), name=new_dims_name)
  333. self.model.add_initializer(new_dims, self.this_graph_name)
  334. reshape_q_name = self.model.create_node_name("Reshape")
  335. reshape_q = helper.make_node(
  336. "Reshape",
  337. inputs=[transpose_q.input[0], new_dims_name],
  338. outputs=[transpose_q.input[0] + "_BSD"],
  339. name=reshape_q_name,
  340. )
  341. self.nodes_to_add.append(reshape_q)
  342. self.node_name_to_graph_name[reshape_q.name] = self.this_graph_name
  343. # Reuse the transpose_q node to transpose K from BSNH to BNSH. Here we update the input and output of the node.
  344. transpose_k_bnsh = transpose_q
  345. transpose_k_bnsh.input[0] = transpose_k.input[0]
  346. transpose_k_bnsh.output[0] = transpose_k.input[0] + "_BNSH"
  347. logger.debug(f"Found MHA: {q_num_heads=} {q_hidden_size=}")
  348. # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
  349. new_node = self.create_mha_node(
  350. reshape_q,
  351. transpose_k_bnsh,
  352. transpose_v,
  353. q_num_heads,
  354. )
  355. if new_node is None:
  356. return False
  357. # Update the input of the next node that consumes the output of the MHA.
  358. assert len(self.model.get_children(transpose_out, input_name_to_nodes)) == 1
  359. reshape_out.input[0] = new_node.output[0]
  360. self.nodes_to_add.append(new_node)
  361. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  362. self.nodes_to_remove.extend([transpose_out])
  363. # Use prune graph to remove nodes since they are shared by all attention nodes.
  364. self.prune_graph = True
  365. return True
  366. def match_sam_encoder_attention_subgraph(self, node_after_output_projection, input_index=None):
  367. """Match SDPA pattern in SAM2 enconder.*"""
  368. # nodes of output projection and the second MatMul in SDPA.
  369. out_nodes = self.model.match_parent_path(
  370. node_after_output_projection,
  371. ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
  372. [input_index, None, None, 0, 0],
  373. )
  374. if out_nodes is None:
  375. return None
  376. (_, _, reshape_out, transpose_out, matmul_qk_v) = out_nodes
  377. # Split and Reshape is for packed QKV
  378. v_nodes = self.model.match_parent_path(matmul_qk_v, ["Transpose", "Squeeze", "Split", "Reshape"], [1, 0, 0, 0])
  379. if v_nodes is None:
  380. logger.debug("failed to match v path")
  381. return None
  382. (transpose_v, _, split_qkv, reshape_qkv) = v_nodes
  383. qk_nodes = self.model.match_parent_path(matmul_qk_v, ["Softmax", "MatMul"], [0, 0])
  384. if qk_nodes is not None:
  385. (_softmax_qk, matmul_qk) = qk_nodes
  386. else:
  387. logger.debug("failed to match qk path")
  388. return None
  389. q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Squeeze", "Split"], [0, None, 0, 0])
  390. if q_nodes is None:
  391. q_nodes = self.model.match_parent_path(
  392. matmul_qk,
  393. ["Mul", "Transpose", "Reshape", "Transpose", "MaxPool", "Transpose", "Reshape", "Squeeze", "Split"],
  394. [0, None, 0, 0, 0, 0, 0, 0, 0],
  395. )
  396. if q_nodes is None:
  397. logger.debug("failed to match q path")
  398. return None
  399. if q_nodes[-1] != split_qkv:
  400. return None
  401. transpose_q = q_nodes[1]
  402. k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Squeeze", "Split"], [1, None, 0, 0])
  403. if k_nodes is None:
  404. logger.debug("failed to match k path")
  405. return None
  406. if k_nodes[-1] != split_qkv:
  407. return None
  408. (mul_k, transpose_k, _squeeze_k, _) = k_nodes
  409. return reshape_out, transpose_out, split_qkv, transpose_q, transpose_k, transpose_v
  410. def create_mha_node(
  411. self,
  412. reshape_q: NodeProto,
  413. transpose_k: NodeProto,
  414. transpose_v: NodeProto,
  415. num_heads: int,
  416. ) -> NodeProto:
  417. """Create a MultiHeadAttention node for SAM2 encoder.
  418. Args:
  419. reshape_q (NodeProto): Reshape node for Q, output is 3D BxSxNH format
  420. transpose_k (NodeProto): Transpose node for K, output is BNSH format
  421. transpose_v (NodeProto): Transpose node for V, output is BNSH format
  422. num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
  423. Returns:
  424. NodeProto: the MultiHeadAttention node created.
  425. """
  426. attention_node_name = self.model.create_node_name("MultiHeadAttention")
  427. inputs = [
  428. reshape_q.output[0],
  429. transpose_k.output[0],
  430. transpose_v.output[0],
  431. ]
  432. # Create a new output name since the shape is 3D, which is different from the original output shape (4D).
  433. output = attention_node_name + "_out"
  434. attention_node = helper.make_node(
  435. "MultiHeadAttention",
  436. inputs=inputs,
  437. outputs=[output],
  438. name=attention_node_name,
  439. )
  440. attention_node.domain = "com.microsoft"
  441. attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
  442. counter_name = "MultiHeadAttention ({})".format("self attention")
  443. self.increase_counter(counter_name)
  444. return attention_node