# fusion_gpt_attention.py (21 KB)
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. import numpy as np
  7. from fusion_base import Fusion
  8. from fusion_utils import FusionUtils
  9. from onnx import helper
  10. from onnx_model import OnnxModel
  11. logger = getLogger(__name__)
  12. class FusionGptAttentionPastBase(Fusion):
  13. """Base class for GPT Attention Fusion with past state"""
  14. def __init__(self, model: OnnxModel, num_heads: int):
  15. super().__init__(model, "Attention", ["LayerNormalization", "SkipLayerNormalization"], "with past")
  16. self.num_heads = num_heads
  17. self.utils = FusionUtils(model)
  18. self.casted_attention_mask = {} # map from name of attention mask to the name that casted to int32
  19. self.mask_filter_value = None
  20. def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node):
  21. # Pattern 1:
  22. # {past}
  23. # / \
  24. # / \
  25. # Gather(axes=0, indices=0) Gather(indices=1)
  26. # | |
  27. # Transpose (perm=0,1,3,2) |
  28. # | |
  29. # Concat_k Concat_v
  30. # | /
  31. # Transpose (perm=0,1,3,2) /
  32. # | /
  33. # Unsqueeze Unsqueeze
  34. # \ /
  35. # \ /
  36. # Concat
  37. # |
  38. # {present}
  39. gather = self.model.get_parent(concat_v, 0, output_name_to_node)
  40. if gather is None or gather.op_type != "Gather":
  41. logger.debug("match_past_pattern_1: expect Gather for past")
  42. return None
  43. if self.model.find_constant_input(gather, 1) != 1:
  44. logger.debug("match_past_pattern_1: expect indices=1 for Gather of past")
  45. return None
  46. past = gather.input[0]
  47. parent = self.model.get_parent(concat_k, 0, output_name_to_node)
  48. if parent and parent.op_type == "Gather":
  49. gather_past_k = parent
  50. else:
  51. past_k_nodes = self.model.match_parent_path(concat_k, ["Transpose", "Gather"], [0, 0])
  52. if past_k_nodes is None:
  53. logger.debug("match_past_pattern_1: failed match Transpose and Gather")
  54. return None
  55. gather_past_k = past_k_nodes[-1]
  56. if self.model.find_constant_input(gather_past_k, 0) != 1:
  57. logger.debug("match_past_pattern_1: expect indices=0 for Gather k of past")
  58. return None
  59. past_k = gather_past_k.input[0]
  60. if past != past_k:
  61. logger.debug("match_past_pattern_1: expect past to be same")
  62. return None
  63. return past
  64. def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node):
  65. # Pattern 2:
  66. # Split (QKV)
  67. # / | |
  68. # / | +----------------------+
  69. # | |
  70. # | {past} |
  71. # | | |
  72. # Reshape Split Reshape
  73. # | / \ |
  74. # Transpose_k Squeeze Squeeze Transpose_v
  75. # | | \ /
  76. # +------|---+ \ /
  77. # | | \ /
  78. # Concat_k Concat_v
  79. # | |
  80. # Unsqueeze Unsqueeze
  81. # \ /
  82. # Concat
  83. # |
  84. # {present}
  85. #
  86. squeeze = self.model.get_parent(concat_v, 0, output_name_to_node)
  87. if squeeze is None or squeeze.op_type != "Squeeze":
  88. logger.debug("match_past_pattern_2: expect Squeeze as parent of concat_v")
  89. return None
  90. split = self.model.get_parent(squeeze, 0, output_name_to_node)
  91. if split is None or split.op_type != "Split":
  92. logger.debug("match_past_pattern_2: expect Split for past path")
  93. return None
  94. opset_version = self.model.get_opset_version()
  95. if opset_version < 13:
  96. if not FusionUtils.check_node_attribute(squeeze, "axes", [0]):
  97. logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
  98. return None
  99. if not FusionUtils.check_node_attribute(split, "split", [1, 1]):
  100. logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
  101. return None
  102. else:
  103. if not self.utils.check_node_input_value(squeeze, 1, [0]):
  104. logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
  105. return None
  106. if not self.utils.check_node_input_value(split, 1, [1, 1]):
  107. logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
  108. return None
  109. if not FusionUtils.check_node_attribute(split, "axis", 0, default_value=0):
  110. logger.debug("match_past_pattern_2: attribute axis of Split are not expected in past path")
  111. return None
  112. past = split.input[0]
  113. past_k_nodes = self.model.match_parent_path(concat_k, ["Squeeze", "Split"], [0, 0])
  114. if past_k_nodes is None:
  115. logger.debug("match_past_pattern_2: failed to match past_k_nodes path")
  116. return None
  117. past_k = past_k_nodes[-1].input[0]
  118. if past != past_k:
  119. logger.info("match_past_pattern_2: expect past to be same")
  120. return None
  121. return past
  122. def match_present(self, concat_v, input_name_to_nodes):
  123. unsqueeze_present_v = self.model.find_first_child_by_type(
  124. concat_v, "Unsqueeze", input_name_to_nodes, recursive=False
  125. )
  126. if not unsqueeze_present_v:
  127. logger.info("expect unsqueeze for present")
  128. return None
  129. concat_present = self.model.find_first_child_by_type(
  130. unsqueeze_present_v, "Concat", input_name_to_nodes, recursive=False
  131. )
  132. if not concat_present:
  133. logger.info("expect concat for present")
  134. return None
  135. present = concat_present.output[0]
  136. return present
  137. def cast_attention_mask(self, input_name):
  138. if input_name in self.casted_attention_mask:
  139. attention_mask_input_name = self.casted_attention_mask[input_name]
  140. elif self.model.find_graph_input(input_name):
  141. casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
  142. self.casted_attention_mask[input_name] = attention_mask_input_name
  143. else:
  144. attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name)
  145. self.casted_attention_mask[input_name] = attention_mask_input_name
  146. return attention_mask_input_name
  147. class FusionGptAttention(FusionGptAttentionPastBase):
  148. """
  149. Fuse GPT-2 Attention with past state subgraph into one Attention node.
  150. """
  151. def __init__(self, model: OnnxModel, num_heads: int):
  152. super().__init__(model, num_heads)
  153. def create_attention_node(
  154. self,
  155. fc_weight,
  156. fc_bias,
  157. gemm_qkv,
  158. past,
  159. present,
  160. input,
  161. output,
  162. mask,
  163. is_unidirectional,
  164. ):
  165. attention_node_name = self.model.create_node_name("GptAttention")
  166. attention_node = helper.make_node(
  167. "Attention",
  168. inputs=[input, fc_weight, fc_bias, mask, past],
  169. outputs=[attention_node_name + "_output", present],
  170. name=attention_node_name,
  171. )
  172. attention_node.domain = "com.microsoft"
  173. attention_node.attribute.extend(
  174. [
  175. helper.make_attribute("num_heads", self.num_heads),
  176. helper.make_attribute("unidirectional", 1 if is_unidirectional else 0),
  177. ]
  178. )
  179. if self.mask_filter_value is not None:
  180. attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
  181. matmul_node = helper.make_node(
  182. "MatMul",
  183. inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
  184. outputs=[attention_node_name + "_matmul_output"],
  185. name=attention_node_name + "_matmul",
  186. )
  187. add_node = helper.make_node(
  188. "Add",
  189. inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
  190. outputs=[output],
  191. name=attention_node_name + "_add",
  192. )
  193. self.nodes_to_add.extend([attention_node, matmul_node, add_node])
  194. self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
  195. self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
  196. self.node_name_to_graph_name[add_node.name] = self.this_graph_name
  197. def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
  198. past = None
  199. present = None
  200. return_indice = []
  201. is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization"
  202. qkv_nodes = None
  203. if not is_normalize_node_skiplayernorm:
  204. qkv_nodes = self.model.match_parent_path(
  205. normalize_node,
  206. ["Add", "Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
  207. [0, None, 0, 0, 0, 0, 0],
  208. output_name_to_node=output_name_to_node,
  209. return_indice=return_indice,
  210. )
  211. else:
  212. qkv_nodes = self.model.match_parent_path(
  213. normalize_node,
  214. ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
  215. [None, 0, 0, 0, 0, 0],
  216. output_name_to_node=output_name_to_node,
  217. return_indice=return_indice,
  218. )
  219. if qkv_nodes is None:
  220. return
  221. another_input = None
  222. if not is_normalize_node_skiplayernorm:
  223. (
  224. add_qkv,
  225. reshape_qkv,
  226. gemm_qkv,
  227. reshape_1,
  228. reshape_2,
  229. transpose_qkv,
  230. matmul_qkv,
  231. ) = qkv_nodes
  232. another_input = add_qkv.input[1 - return_indice[0]]
  233. else:
  234. (
  235. reshape_qkv,
  236. gemm_qkv,
  237. reshape_1,
  238. reshape_2,
  239. transpose_qkv,
  240. matmul_qkv,
  241. ) = qkv_nodes
  242. v_nodes = self.model.match_parent_path(matmul_qkv, ["Concat", "Transpose", "Reshape", "Split"], [1, 1, 0, 0])
  243. if v_nodes is None:
  244. logger.debug("fuse_attention: failed to match v path")
  245. return
  246. (concat_v, transpose_v, reshape_v, split_fc) = v_nodes
  247. # Try match pattern using Gemm + LayerNormalization
  248. fc_nodes = self.model.match_parent_path(
  249. split_fc,
  250. ["Reshape", "Gemm", "Reshape", "LayerNormalization"],
  251. [0, 0, 0, 0],
  252. output_name_to_node,
  253. )
  254. # Try match pattern using Gemm + SkipLayerNormalization
  255. if fc_nodes is None:
  256. fc_nodes = self.model.match_parent_path(
  257. split_fc,
  258. ["Reshape", "Gemm", "Reshape", "SkipLayerNormalization"],
  259. [0, 0, 0, 0],
  260. output_name_to_node,
  261. )
  262. # Try match pattern using MatMul
  263. if fc_nodes is None:
  264. # LayerNormalization
  265. fc_nodes = self.model.match_parent_path(
  266. split_fc,
  267. ["Add", "MatMul", "LayerNormalization"],
  268. [0, None, 0],
  269. output_name_to_node,
  270. )
  271. # SkipLayerNormalization
  272. if fc_nodes is None:
  273. fc_nodes = self.model.match_parent_path(
  274. split_fc,
  275. ["Add", "MatMul", "SkipLayerNormalization"],
  276. [0, None, 0],
  277. output_name_to_node,
  278. )
  279. if fc_nodes is None:
  280. logger.debug("fuse_attention: failed to match fc path")
  281. return
  282. fc_weight = fc_nodes[1].input[1]
  283. i, _ = self.model.get_constant_input(fc_nodes[0])
  284. fc_bias = fc_nodes[0].input[i]
  285. else:
  286. fc_weight = fc_nodes[1].input[1]
  287. fc_bias = fc_nodes[1].input[2]
  288. layernorm_before_attention = fc_nodes[-1]
  289. # `another_input` will be non-None only if
  290. # (1) SkipLayerNorm fusion wasn't turned ON
  291. # (2) SkipLayerNorm fusion was turned ON but upstream layer's LayerNorm + Add was not
  292. # fused into a SkipLayerNorm. This can happen if the shapes to the Add node are different.
  293. # So, keep the following check if SkipLayerNorm fusion is turned ON or OFF.
  294. if another_input is not None and another_input not in layernorm_before_attention.input:
  295. logger.debug("Upstream Add and (Skip)LayerNormalization shall have one same input")
  296. return
  297. is_unidirectional = True
  298. slice_mask = None
  299. input_mask_nodes = None
  300. concat_k_to_match = None
  301. qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "Div", "MatMul"], [0, 0, 0, 0, 0])
  302. if qk_nodes is not None:
  303. (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
  304. mask_nodes = self.model.match_parent_path(
  305. sub_qk,
  306. [
  307. "Mul",
  308. "Sub",
  309. "Slice",
  310. "Slice",
  311. "Unsqueeze",
  312. "Sub",
  313. "Squeeze",
  314. "Slice",
  315. "Shape",
  316. "Div",
  317. ],
  318. [1, 0, 1, 0, 1, 0, 0, 0, 0, 0],
  319. )
  320. if mask_nodes is None:
  321. logger.debug("fuse_attention: failed to match unidirectional mask path")
  322. return
  323. div_mask = mask_nodes[-1]
  324. slice_mask = mask_nodes[3]
  325. if div_qk != div_mask:
  326. logger.debug("fuse_attention: skip since div_qk != div_mask")
  327. return
  328. if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
  329. _, mul_val = self.model.get_constant_input(mask_nodes[0])
  330. if mul_val != -10000:
  331. self.mask_filter_value = -mul_val
  332. else:
  333. # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
  334. i, qk_nodes, _ = self.model.match_parent_paths(
  335. matmul_qkv,
  336. [
  337. (["Softmax", "Where", "Div", "MatMul"], [0, 0, 1, 0]),
  338. (["Softmax", "Add", "Where", "Div", "MatMul"], [0, 0, None, 1, 0]),
  339. ],
  340. output_name_to_node,
  341. )
  342. if qk_nodes is None:
  343. logger.debug("fuse_attention: failed to match qk nodes")
  344. return
  345. where_qk = qk_nodes[-3]
  346. div_qk = qk_nodes[-2]
  347. matmul_qk = qk_nodes[-1]
  348. if i == 1:
  349. add_qk = qk_nodes[1]
  350. _, input_mask_nodes, _ = self.model.match_parent_paths(
  351. add_qk,
  352. [
  353. (
  354. ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze", "Reshape"],
  355. [None, 0, 1, 0, 0, 0],
  356. ),
  357. (
  358. ["Mul", "Sub", "Unsqueeze", "Unsqueeze", "Reshape"],
  359. [None, 0, 1, 0, 0],
  360. ),
  361. (
  362. ["Mul", "Sub", "Unsqueeze", "Unsqueeze"],
  363. [None, 0, 1, 0],
  364. ), # useless cast and reshape are removed.
  365. ],
  366. output_name_to_node,
  367. )
  368. if input_mask_nodes is None:
  369. logger.debug("fuse_attention: failed to match input attention mask path")
  370. return
  371. if len(input_mask_nodes) > 1 and input_mask_nodes[0].op_type == "Mul":
  372. _, mul_val = self.model.get_constant_input(input_mask_nodes[0])
  373. if mul_val != -10000:
  374. self.mask_filter_value = mul_val
  375. i, mask_nodes, _ = self.model.match_parent_paths(
  376. where_qk,
  377. [
  378. (
  379. ["Cast", "Slice", "Slice", "Unsqueeze", "Sub", "Squeeze", "Slice", "Shape"],
  380. [0, 0, 0, 1, 0, 0, 0, 0],
  381. ),
  382. # For Transformers >= 4.27, causal mask uses torch.bool instead of torch.uint8, so no Cast to bool.
  383. (
  384. ["Slice", "Slice", "Unsqueeze", "Sub", "Squeeze", "Slice", "Shape"],
  385. [0, 0, 1, 0, 0, 0, 0],
  386. ),
  387. ],
  388. output_name_to_node,
  389. )
  390. if mask_nodes is None:
  391. # TODO: match mask path for GPT2LMHeadModel_BeamSearchStep.
  392. logger.debug("fuse_attention: failed to match mask path")
  393. return
  394. slice_mask = mask_nodes[2 if i == 0 else 1]
  395. div_or_concat = self.model.get_parent(mask_nodes[-1], 0, output_name_to_node)
  396. if div_or_concat.op_type == "Div":
  397. div_mask = div_or_concat
  398. if div_qk != div_mask:
  399. logger.debug("fuse_attention: skip since div_qk != div_mask")
  400. return
  401. elif div_or_concat.op_type == "Concat":
  402. concat_k_to_match = div_or_concat
  403. else:
  404. logger.debug("fuse_attention: failed to match mask path")
  405. # Validate that the mask data is either lower triangular (unidirectional) or all ones
  406. mask_data = self.model.get_constant_value(slice_mask.input[0])
  407. if not (
  408. isinstance(mask_data, np.ndarray)
  409. and len(mask_data.shape) == 4
  410. and mask_data.shape[:2] == (1, 1)
  411. and mask_data.shape[2] == mask_data.shape[3]
  412. ):
  413. logger.debug("fuse_attention: skip since mask shape is not 1x1xWxW")
  414. return
  415. if np.allclose(mask_data, np.ones_like(mask_data)):
  416. is_unidirectional = False
  417. elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
  418. logger.debug("fuse_attention: skip since mask is neither lower triangular nor ones")
  419. return
  420. q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [0, 0, 0])
  421. if q_nodes is None:
  422. logger.debug("fuse_attention: failed to match q path")
  423. return
  424. (transpose_q, reshape_q, split_q) = q_nodes
  425. if split_fc != split_q:
  426. logger.debug("fuse_attention: skip since split_fc != split_q")
  427. return
  428. k_nodes = self.model.match_parent_path(matmul_qk, ["Concat", "Transpose", "Reshape", "Split"], [1, 1, 0, 0])
  429. if k_nodes is None:
  430. # This pattern is from pytorch 1.7.1 and transformers 4.6.1
  431. k_nodes = self.model.match_parent_path(
  432. matmul_qk,
  433. ["Transpose", "Concat", "Transpose", "Reshape", "Split"],
  434. [1, 0, 1, 0, 0],
  435. )
  436. if k_nodes is None:
  437. logger.debug("fuse_attention: failed to match k path")
  438. return
  439. else:
  440. (_, concat_k, transpose_k, reshape_k, split_k) = k_nodes
  441. else:
  442. (concat_k, transpose_k, reshape_k, split_k) = k_nodes
  443. if split_fc != split_k:
  444. logger.debug("fuse_attention: skip since split_fc != split_k")
  445. return
  446. if concat_k_to_match and concat_k != concat_k_to_match:
  447. logger.debug("fuse_attention: skip since concat_k != concat_k_to_match")
  448. return
  449. attention_mask_input_name = ""
  450. if input_mask_nodes is not None:
  451. input_name = input_mask_nodes[-1].input[0]
  452. attention_mask_input_name = self.cast_attention_mask(input_name)
  453. # Match past and present paths
  454. past = self.match_past_pattern_1(concat_k, concat_v, output_name_to_node) or self.match_past_pattern_2(
  455. concat_k, concat_v, output_name_to_node
  456. )
  457. if past is None:
  458. logger.info("fuse_attention: failed to match past path")
  459. return
  460. if not self.model.find_graph_input(past):
  461. logger.debug("past is not graph input.")
  462. # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select beam index so it is not graph input.
  463. present = self.match_present(concat_v, input_name_to_nodes)
  464. if present is None:
  465. logger.info("fuse_attention: failed to match present path")
  466. return
  467. if not self.model.find_graph_output(present):
  468. logger.info("expect present to be graph output")
  469. return
  470. self.create_attention_node(
  471. fc_weight,
  472. fc_bias,
  473. gemm_qkv,
  474. past,
  475. present,
  476. layernorm_before_attention.output[0],
  477. reshape_qkv.output[0],
  478. attention_mask_input_name,
  479. is_unidirectional,
  480. )
  481. # we rely on prune_graph() to clean old subgraph nodes:
  482. # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
  483. self.prune_graph = True