# onnx_model_phi.py (35 KB)
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
  5. from logging import getLogger
  6. import numpy as np
  7. from dynamo_onnx_helper import DynamoOnnxHelper
  8. from fusion_base import Fusion
  9. from fusion_options import AttentionOpType, FusionOptions
  10. from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
  11. from fusion_utils import NumpyHelper
  12. from onnx import ModelProto, NodeProto, TensorProto, helper, numpy_helper
  13. from onnx_model import OnnxModel
  14. logger = getLogger(__name__)
  15. class ProcessGemmWFunc:
  16. def __call__(self, x):
  17. return np.transpose(x, (1, 0))
  18. class ProcessMatMulQFunc:
  19. def __call__(self, x):
  20. return np.transpose(np.split(x, 3, 0)[0], (1, 0))
  21. class ProcessMatMulKFunc:
  22. def __call__(self, x):
  23. return np.transpose(np.split(x, 3, 0)[1], (1, 0))
  24. class ProcessMatMulVFunc:
  25. def __call__(self, x):
  26. return np.transpose(np.split(x, 3, 0)[2], (1, 0))
  27. class ProcessBiasQFunc:
  28. def __call__(self, x):
  29. x = np.split(x, 3, -1)[0]
  30. return x
  31. class ProcessBiasKFunc:
  32. def __call__(self, x):
  33. x = np.split(x, 3, -1)[1]
  34. return x
  35. class ProcessBiasVFunc:
  36. def __call__(self, x):
  37. x = np.split(x, 3, -1)[2]
  38. return x
  39. class ProcessRotCacheFunc:
  40. def __call__(self, x):
  41. # half rotary embedding
  42. assert len(x.shape) == 2
  43. if x.shape[1] == 32:
  44. return x[:, 0:16]
  45. return x
  46. # TODO: move to a separate file
  47. class Fission(Fusion):
  48. def __init__(
  49. self,
  50. model: OnnxModel,
  51. nodes_to_find: list[str],
  52. ):
  53. super().__init__(model, "DONOTUSE", nodes_to_find)
  54. def set_attention_op_type(self, attn_op_type: AttentionOpType):
  55. self.attn_op_type = attn_op_type
  56. def get_uname(self, layer_id, name):
  57. return name + "_" + str(layer_id)
  58. def get_edge_by_name(self, edges, name):
  59. for edge in edges:
  60. if edge == name or edge.endswith(name) or edge.startswith(name):
  61. return edge
  62. raise ValueError(f"Edge {name} not found")
  63. def get_input_by_name(self, node, name):
  64. return self.get_edge_by_name(node.input, name)
  65. def get_output_by_name(self, node, name):
  66. return self.get_edge_by_name(node.output, name)
  67. def process_initializer(self, initializer_name, functor, custom_name=None):
  68. i = self.model.get_initializer(initializer_name)
  69. i_np_array = NumpyHelper.to_array(i)
  70. processed_i_np_array = functor(i_np_array)
  71. new_tensor = helper.make_tensor(
  72. initializer_name + "_processed" if custom_name is None else custom_name,
  73. data_type=TensorProto.FLOAT,
  74. dims=processed_i_np_array.shape,
  75. vals=processed_i_np_array.flatten().tobytes(),
  76. raw=True,
  77. )
  78. self.model.add_initializer(new_tensor, self.this_graph_name)
  79. return new_tensor.name
  80. def add_fp32_value_info(self, name):
  81. new_value_info = self.model.graph().value_info.add()
  82. new_value_info.name = name
  83. new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
  84. def add_int64_value_info(self, name):
  85. new_value_info = self.model.graph().value_info.add()
  86. new_value_info.name = name
  87. new_value_info.type.tensor_type.elem_type = TensorProto.INT64
  88. def replace_fp32_value_info(self, name, shape):
  89. for value_info in self.model.graph().value_info:
  90. if value_info.name == name:
  91. self.model.graph().value_info.remove(value_info)
  92. break
  93. new_value_info = helper.make_tensor_value_info(
  94. name,
  95. elem_type=TensorProto.FLOAT,
  96. shape=shape,
  97. )
  98. self.model.graph().value_info.extend([new_value_info])
  99. def set_unique_name_and_add_nodes(
  100. self, subgraph_nodes: list[NodeProto], layer_id: int, layer_known_edges_names: list[str]
  101. ):
  102. for new_node in subgraph_nodes:
  103. for i, name in enumerate(new_node.input):
  104. if name == "":
  105. continue
  106. elif name not in layer_known_edges_names:
  107. new_node.input[i] = self.get_uname(layer_id, name)
  108. self.add_fp32_value_info(new_node.input[i])
  109. for i, name in enumerate(new_node.output):
  110. if name == "":
  111. continue
  112. elif name not in layer_known_edges_names:
  113. new_node.output[i] = self.get_uname(layer_id, name)
  114. self.add_fp32_value_info(new_node.output[i])
  115. new_node.name = self.get_uname(layer_id, new_node.name)
  116. self.nodes_to_add.append(new_node)
  117. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  118. def layernorm(self, inputs: list[str], outputs: list[str], prefix: str = ""):
  119. assert len(inputs) == 3
  120. assert len(outputs) == 1
  121. node = helper.make_node(
  122. "LayerNormalization",
  123. inputs=inputs,
  124. outputs=outputs,
  125. name=prefix + "_LayerNormalization",
  126. epsilon=9.999999747378752e-06,
  127. )
  128. return [node]
  129. def gemm(self, inputs: list[str], outputs: list[str], prefix: str = ""):
  130. assert len(inputs) == 3
  131. assert len(outputs) == 1
  132. matmul = helper.make_node(
  133. "MatMul",
  134. inputs=[inputs[0], inputs[1]],
  135. outputs=[prefix + "matmul_out"],
  136. name=prefix + "MatMul",
  137. )
  138. add = helper.make_node(
  139. "Add",
  140. inputs=[prefix + "matmul_out", inputs[2]],
  141. outputs=outputs,
  142. name=prefix + "Bias",
  143. )
  144. return [matmul, add]
  145. def rotary(self, inputs: list[str], outputs: list[str], prefix: str = "", rot_dim=32, num_heads=32):
  146. assert len(inputs) == 4
  147. assert len(outputs) == 1
  148. node = helper.make_node(
  149. "RotaryEmbedding",
  150. inputs=inputs,
  151. outputs=outputs,
  152. name=prefix + "RotaryEmbedding",
  153. domain="com.microsoft",
  154. rotary_embedding_dim=rot_dim,
  155. num_heads=num_heads,
  156. )
  157. return [node]
  158. def fastgelu(self, inputs: list[str], outputs: list[str], prefix: str = ""):
  159. assert len(inputs) == 1
  160. assert len(outputs) == 1
  161. node = helper.make_node(
  162. "FastGelu",
  163. inputs=inputs,
  164. outputs=outputs,
  165. name=prefix + "FastGelu",
  166. domain="com.microsoft",
  167. )
  168. return [node]
  169. def add(self, inputs: list[str], outputs: list[str], prefix: str = ""):
  170. assert len(inputs) == 2
  171. assert len(outputs) == 1
  172. node = helper.make_node(
  173. "Add",
  174. inputs=inputs,
  175. outputs=outputs,
  176. name=prefix + "Add",
  177. )
  178. return [node]
  179. def mha(self, inputs: list[str], outputs: list[str], prefix: str = "", num_heads=32):
  180. assert len(inputs) == 8
  181. assert len(outputs) == 3
  182. node = helper.make_node(
  183. "MultiHeadAttention",
  184. inputs=inputs,
  185. outputs=outputs,
  186. name=prefix + "MultiHeadAttention",
  187. domain="com.microsoft",
  188. num_heads=num_heads,
  189. unidirectional=1,
  190. )
  191. return [node]
  192. def gqa(self, inputs: list[str], outputs: list[str], prefix: str = "", num_heads=32):
  193. assert len(inputs) == 7
  194. assert len(outputs) == 3
  195. node = helper.make_node(
  196. "GroupQueryAttention",
  197. inputs=inputs,
  198. outputs=outputs,
  199. name=prefix + "GroupQueryAttention",
  200. domain="com.microsoft",
  201. num_heads=num_heads,
  202. kv_num_heads=num_heads,
  203. )
  204. return [node]
  205. def attention(self, inputs: list[str], outputs: list[str], prefix: str = "", num_heads=32):
  206. assert len(inputs) == 5
  207. assert len(outputs) == 2
  208. node = helper.make_node(
  209. "Attention",
  210. inputs=inputs,
  211. outputs=outputs,
  212. name=prefix + "Attention",
  213. domain="com.microsoft",
  214. num_heads=num_heads,
  215. unidirectional=1,
  216. do_rotary=1,
  217. rotary_embedding_dim=32,
  218. )
  219. return [node]
  220. def paged_attn(
  221. self,
  222. inputs: list[str],
  223. outputs: list[str],
  224. prefix: str = "",
  225. num_heads=32,
  226. head_size=80,
  227. scale=0.11180339753627777,
  228. ):
  229. assert len(inputs) == 6
  230. assert len(outputs) == 1
  231. node = helper.make_node(
  232. "PagedAttention",
  233. inputs=inputs,
  234. outputs=outputs,
  235. name=prefix + "PagedAttention",
  236. domain="vllm.ort.ext",
  237. num_heads=num_heads,
  238. num_kv_heads=num_heads,
  239. head_size=head_size,
  240. scale=scale,
  241. )
  242. return [node]
  243. class Phi2PreProcessor(DynamoOnnxHelper):
  244. def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
  245. super().__init__(model)
  246. self.num_hidden_layers = 32
  247. self.num_attention_heads = num_heads
  248. self.hidden_size = hidden_size
  249. self.func_name = "modeling_phi_PhiModel_model_1"
  250. def get_phi2_edge_dict(self) -> dict:
  251. edge_dict = {}
  252. edge_dict["lm_head_1"] = "logits"
  253. edge_dict["l_input_ids_"] = "input_ids"
  254. edge_dict["key_states"] = "past_key_0"
  255. edge_dict["value_states"] = "past_value_0"
  256. for i in range(1, self.num_hidden_layers, 1):
  257. edge_dict[f"key_states_{i}"] = f"past_key_{i}"
  258. edge_dict[f"value_states_{i}"] = f"past_value_{i}"
  259. edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}"
  260. edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}"
  261. outputs = [o.name for o in self.model.graph.output]
  262. if "model_layers_0_1_1" in outputs and "model_layers_0_1_2" in outputs:
  263. edge_dict["model_layers_0_1_1"] = "present_key_0"
  264. edge_dict["model_layers_0_1_2"] = "present_value_0"
  265. else:
  266. assert "model_layers_0_1" in outputs and "model_layers_0_1_1" in outputs
  267. edge_dict["model_layers_0_1"] = "present_key_0"
  268. edge_dict["model_layers_0_1_1"] = "present_value_0"
  269. return edge_dict
  270. def simplify_phi2_op_type(self):
  271. phi2_transformer_layer_name = "modeling_phi_PhiDecoderLayer_model_layers"
  272. for node in self.model.graph.node:
  273. index = node.op_type.find(phi2_transformer_layer_name)
  274. if index != -1:
  275. node.op_type = node.op_type[index:]
  276. def process_graph_io(self, attn_op_type: AttentionOpType):
  277. self.use_attn = attn_op_type == AttentionOpType.Attention
  278. self.use_vllm = attn_op_type == AttentionOpType.PagedAttention
  279. graph = self.model.graph
  280. new_inputs = []
  281. for vi in graph.input:
  282. if "input_ids" in vi.name:
  283. vi_iid = helper.make_tensor_value_info(
  284. vi.name,
  285. elem_type=TensorProto.INT32 if not self.use_vllm else TensorProto.INT64,
  286. shape=["batch_size", "seq_len"],
  287. )
  288. vi_step = helper.make_tensor_value_info(
  289. "step",
  290. elem_type=TensorProto.INT64,
  291. shape=[1],
  292. )
  293. vi_pid = helper.make_tensor_value_info(
  294. "position_ids",
  295. elem_type=TensorProto.INT64,
  296. shape=["batch_size", "seq_len"],
  297. )
  298. vi_mask = helper.make_tensor_value_info(
  299. "attention_mask",
  300. elem_type=TensorProto.INT32,
  301. shape=["batch_size", "seq_len"],
  302. )
  303. vi_meta = helper.make_tensor_value_info(
  304. "input_metadata",
  305. elem_type=TensorProto.INT64,
  306. shape=[1],
  307. )
  308. (
  309. new_inputs.extend([vi_iid, vi_step, vi_mask])
  310. if not self.use_vllm
  311. else new_inputs.extend([vi_iid, vi_pid, vi_meta])
  312. )
  313. if self.use_attn:
  314. if "past_key" in vi.name:
  315. vi_cache = helper.make_tensor_value_info(
  316. vi.name.replace("past_key", "past"),
  317. elem_type=vi.type.tensor_type.elem_type,
  318. shape=[
  319. 2,
  320. "batch_size",
  321. self.num_attention_heads,
  322. "past_seq_len",
  323. self.hidden_size // self.num_attention_heads,
  324. ],
  325. )
  326. new_inputs.extend([vi_cache])
  327. elif self.use_vllm:
  328. if "past_key" in vi.name:
  329. vi_cache = helper.make_tensor_value_info(
  330. vi.name,
  331. elem_type=vi.type.tensor_type.elem_type,
  332. shape=["num_blocks", "num_heads", "head_size_x", "block_size", "block_x"],
  333. )
  334. new_inputs.extend([vi_cache])
  335. if "past_value" in vi.name:
  336. vi_cache = helper.make_tensor_value_info(
  337. vi.name,
  338. elem_type=vi.type.tensor_type.elem_type,
  339. shape=[
  340. "num_blocks",
  341. "num_heads",
  342. "head_size",
  343. "block_size",
  344. ],
  345. )
  346. new_inputs.extend([vi_cache])
  347. else:
  348. if "past_key" in vi.name or "past_value" in vi.name:
  349. vi_cache = helper.make_tensor_value_info(
  350. vi.name,
  351. elem_type=vi.type.tensor_type.elem_type,
  352. shape=[
  353. "batch_size",
  354. self.num_attention_heads,
  355. "past_seq_len",
  356. self.hidden_size // self.num_attention_heads,
  357. ],
  358. )
  359. new_inputs.extend([vi_cache])
  360. graph.ClearField("input")
  361. graph.input.extend(new_inputs)
  362. new_outputs = []
  363. for i, vi in enumerate(graph.output):
  364. if i == 0:
  365. new_outputs.extend([vi])
  366. else:
  367. if self.use_attn:
  368. if "present_key" in vi.name:
  369. vi_cache = helper.make_tensor_value_info(
  370. vi.name.replace("present_key", "present"),
  371. elem_type=vi.type.tensor_type.elem_type,
  372. shape=[
  373. 2,
  374. "batch_size",
  375. self.num_attention_heads,
  376. "total_seq_len",
  377. self.hidden_size // self.num_attention_heads,
  378. ],
  379. )
  380. new_outputs.extend([vi_cache])
  381. elif self.use_vllm:
  382. pass
  383. else:
  384. vi_cache = helper.make_tensor_value_info(
  385. vi.name,
  386. elem_type=vi.type.tensor_type.elem_type,
  387. shape=[
  388. "batch_size",
  389. self.num_attention_heads,
  390. "total_seq_len",
  391. self.hidden_size // self.num_attention_heads,
  392. ],
  393. )
  394. new_outputs.extend([vi_cache])
  395. graph.ClearField("output")
  396. graph.output.extend(new_outputs)
  397. def preprocess_onnx(self, attn_op_type: AttentionOpType):
  398. function_name = None
  399. for func in self.model.functions:
  400. if func.name.endswith(self.func_name):
  401. function_name = func.name
  402. break
  403. assert function_name is not None
  404. self.unroll_function(function_name)
  405. self.update_edges(self.get_phi2_edge_dict())
  406. self.simplify_phi2_op_type()
  407. self.remove_dropout_layer()
  408. if attn_op_type == AttentionOpType.PagedAttention:
  409. self.remove_lm_head_layer()
  410. self.process_graph_io(attn_op_type)
  411. class FissionTransformerEmbeddingPhi(Fission):
  412. def __init__(
  413. self,
  414. model: OnnxModel,
  415. ):
  416. super().__init__(model, ["torch_nn_modules_sparse_Embedding_model_embed_tokens_1"])
  417. def fuse(self, node, input_name_to_nodes, output_name_to_node):
  418. logger.info("Optimizing %s...", node.name)
  419. assert len(node.input) == 2
  420. assert len(node.output) == 1
  421. input = node.input[0]
  422. output = node.output[0]
  423. embedding = self.get_input_by_name(node, "embed_tokens.weight")
  424. layer_known_edges_names = [input, output, embedding]
  425. subgraph_nodes = [
  426. helper.make_node(
  427. "Gather",
  428. inputs=[embedding, input],
  429. outputs=[output],
  430. name="Embedding_Gather",
  431. ),
  432. ]
  433. self.set_unique_name_and_add_nodes(subgraph_nodes, 0, layer_known_edges_names)
  434. self.nodes_to_remove.append(node)
  435. self.prune_graph = True
  436. class FissionTransformerLayerNormPhi(Fission):
  437. def __init__(
  438. self,
  439. model: OnnxModel,
  440. ):
  441. super().__init__(model, ["torch_nn_modules_normalization_LayerNorm_model_final_layernorm_1"])
  442. def fuse(self, node, input_name_to_nodes, output_name_to_node):
  443. logger.info("Optimizing %s...", node.name)
  444. assert len(node.input) == 3
  445. assert len(node.output) == 1
  446. input = node.input[0]
  447. output = node.output[0]
  448. ln_weight = self.get_input_by_name(node, "final_layernorm.weight")
  449. ln_bias = self.get_input_by_name(node, "final_layernorm.bias")
  450. layer_known_edges_names = [input, output, ln_weight, ln_bias]
  451. subgraph_nodes = []
  452. subgraph_nodes.extend(self.layernorm([input, ln_weight, ln_bias], [output], "Final"))
  453. self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)
  454. self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
  455. self.replace_fp32_value_info(output, ["batch_size", "seq_len", "hidden_size"])
  456. self.nodes_to_remove.append(node)
  457. self.prune_graph = True
  458. class FissionTransformerCausalLMHeadPhi(Fission):
  459. def __init__(
  460. self,
  461. model: OnnxModel,
  462. ):
  463. super().__init__(model, ["torch_nn_modules_linear_Linear_lm_head_1"])
  464. def fuse(self, node, input_name_to_nodes, output_name_to_node):
  465. logger.info("Optimizing %s...", node.name)
  466. assert len(node.input) == 5
  467. assert len(node.output) == 1
  468. input = node.input[2]
  469. output = node.output[0]
  470. fc_weight = self.process_initializer(self.get_input_by_name(node, "lm_head.weight"), ProcessGemmWFunc())
  471. fc_bias = self.get_input_by_name(node, "lm_head.bias")
  472. layer_known_edges_names = [input, output, fc_weight, fc_bias]
  473. subgraph_nodes = []
  474. subgraph_nodes.extend(self.gemm([input, fc_weight, fc_bias], [output], "LMHead_"))
  475. self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)
  476. self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
  477. self.replace_fp32_value_info(output, ["batch_size", "seq_len", 51200])
  478. self.nodes_to_remove.append(node)
  479. self.prune_graph = True
  480. class FissionTransformerBlockPhi(Fission):
  481. def __init__(
  482. self,
  483. model: OnnxModel,
  484. num_heads: int,
  485. ):
  486. self.num_heads = num_heads
  487. max_num_layers = 32
  488. self.func_to_layer_id = {}
  489. nodes_to_find = []
  490. for layer in range(max_num_layers):
  491. func_name = f"modeling_phi_PhiDecoderLayer_model_layers_{layer}_1"
  492. nodes_to_find.append(func_name)
  493. self.func_to_layer_id[func_name] = layer
  494. super().__init__(model, nodes_to_find)
  495. def get_layer_id(self, node):
  496. return self.func_to_layer_id[node.op_type]
  497. def get_gqa_aux_nodes(self):
  498. gqa_aux_nodes = [
  499. helper.make_node(
  500. "Cast",
  501. inputs=["attention_mask"],
  502. outputs=["mask_int64"],
  503. name="Cast_gqa_aux_0",
  504. to=TensorProto.INT64,
  505. ),
  506. helper.make_node(
  507. "ReduceSum",
  508. inputs=["mask_int64", "one"],
  509. outputs=["mask_row_sums"],
  510. name="ReduceSum_gqa_aux",
  511. ),
  512. helper.make_node(
  513. "Sub",
  514. inputs=["mask_row_sums", "one"],
  515. outputs=["seqlens_k_int64"],
  516. name="Sub_gqa_aux",
  517. ),
  518. helper.make_node(
  519. "Cast",
  520. inputs=["seqlens_k_int64"],
  521. outputs=["seqlens_k"],
  522. name="Cast_gqa_aux_1",
  523. to=TensorProto.INT32,
  524. ),
  525. helper.make_node("Shape", inputs=["mask_int64"], outputs=["mask_shape"], name="Shape_gqa_aux_0"),
  526. helper.make_node(
  527. "Gather",
  528. inputs=["mask_shape", "one"],
  529. outputs=["total_seq_len_int64"],
  530. name="Gather_gqa_aux_0",
  531. axis=0,
  532. ),
  533. helper.make_node(
  534. "Cast",
  535. inputs=["total_seq_len_int64"],
  536. outputs=["total_sequence_length"],
  537. name="Cast_gqa_aux_2",
  538. to=TensorProto.INT32,
  539. ),
  540. ]
  541. return gqa_aux_nodes
  542. def pack_qkv_gemm(self, q_w, k_w, v_w, q_b, k_b, v_b, weight_name, bias_name):
  543. q_weight = self.model.get_initializer(q_w)
  544. k_weight = self.model.get_initializer(k_w)
  545. v_weight = self.model.get_initializer(v_w)
  546. qw = np.transpose(NumpyHelper.to_array(q_weight), (1, 0))
  547. kw = np.transpose(NumpyHelper.to_array(k_weight), (1, 0))
  548. vw = np.transpose(NumpyHelper.to_array(v_weight), (1, 0))
  549. qkv_weight = np.stack((qw, kw, vw), axis=1)
  550. q_bias = self.model.get_initializer(q_b)
  551. k_bias = self.model.get_initializer(k_b)
  552. v_bias = self.model.get_initializer(v_b)
  553. qb = NumpyHelper.to_array(q_bias)
  554. kb = NumpyHelper.to_array(k_bias)
  555. vb = NumpyHelper.to_array(v_bias)
  556. qkv_bias = np.stack((qb, kb, vb), axis=0)
  557. hidden_size = qkv_weight.shape[0]
  558. weight = helper.make_tensor(
  559. weight_name,
  560. data_type=TensorProto.FLOAT,
  561. dims=[hidden_size, hidden_size * 3],
  562. vals=qkv_weight.flatten().tobytes(),
  563. raw=True,
  564. )
  565. self.model.add_initializer(weight, self.this_graph_name)
  566. bias = helper.make_tensor(
  567. bias_name,
  568. data_type=TensorProto.FLOAT,
  569. dims=[hidden_size * 3],
  570. vals=qkv_bias.flatten().tobytes(),
  571. raw=True,
  572. )
  573. self.model.add_initializer(bias, self.this_graph_name)
  574. self.add_fp32_value_info(weight.name)
  575. self.add_fp32_value_info(bias.name)
  576. return weight_name, bias_name
  577. def fuse(
  578. self,
  579. node,
  580. input_name_to_nodes,
  581. output_name_to_node,
  582. ):
  583. logger.info("Optimizing %s...", node.name)
  584. logger.info(f"AttentionOpType: {self.attn_op_type}")
  585. layer_id = self.get_layer_id(node)
  586. i_hidden_states = node.input[0]
  587. i_key_cache = self.get_input_by_name(node, "past_key")
  588. i_value_cache = self.get_input_by_name(node, "past_value")
  589. o_hidden_states = node.output[-1]
  590. o_key_cache = self.get_output_by_name(node, "present_key")
  591. o_value_cache = self.get_output_by_name(node, "present_value")
  592. ln_weight = self.get_input_by_name(node, "input_layernorm.weight")
  593. ln_bias = self.get_input_by_name(node, "input_layernorm.bias")
  594. attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = (
  595. None,
  596. None,
  597. None,
  598. None,
  599. None,
  600. None,
  601. )
  602. attn_qkv_weight, attn_qkv_bias = None, None
  603. cos_cache, sin_cache = None, None
  604. if self.attn_op_type != AttentionOpType.Attention:
  605. attn_q_weight = self.process_initializer(
  606. self.get_input_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc()
  607. )
  608. attn_k_weight = self.process_initializer(
  609. self.get_input_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc()
  610. )
  611. attn_v_weight = self.process_initializer(
  612. self.get_input_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc()
  613. )
  614. attn_q_bias = self.get_input_by_name(node, "self_attn.q_proj.bias")
  615. attn_k_bias = self.get_input_by_name(node, "self_attn.k_proj.bias")
  616. attn_v_bias = self.get_input_by_name(node, "self_attn.v_proj.bias")
  617. cos_cache = self.process_initializer(
  618. self.get_input_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()
  619. )
  620. sin_cache = self.process_initializer(
  621. self.get_input_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
  622. )
  623. else:
  624. attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm(
  625. self.get_input_by_name(node, "self_attn.q_proj.weight"),
  626. self.get_input_by_name(node, "self_attn.k_proj.weight"),
  627. self.get_input_by_name(node, "self_attn.v_proj.weight"),
  628. self.get_input_by_name(node, "self_attn.q_proj.bias"),
  629. self.get_input_by_name(node, "self_attn.k_proj.bias"),
  630. self.get_input_by_name(node, "self_attn.v_proj.bias"),
  631. self.get_uname(layer_id, "attn_qkv_weight"),
  632. self.get_uname(layer_id, "attn_qkv_bias"),
  633. )
  634. attn_out_weight = self.process_initializer(
  635. self.get_input_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
  636. )
  637. attn_out_bias = self.get_input_by_name(node, "self_attn.dense.bias")
  638. mlp_fc1_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc())
  639. mlp_fc2_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc())
  640. mlp_fc1_bias = self.get_input_by_name(node, "mlp.fc1.bias")
  641. mlp_fc2_bias = self.get_input_by_name(node, "mlp.fc2.bias")
  642. layer_known_edges_names = []
  643. layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache])
  644. layer_known_edges_names.extend([o_hidden_states, o_key_cache, o_value_cache])
  645. layer_known_edges_names.extend([ln_weight, ln_bias])
  646. if self.attn_op_type != AttentionOpType.Attention:
  647. layer_known_edges_names.extend(
  648. [
  649. attn_q_weight,
  650. attn_q_bias,
  651. attn_k_weight,
  652. attn_k_bias,
  653. attn_v_weight,
  654. attn_v_bias,
  655. cos_cache,
  656. sin_cache,
  657. ]
  658. )
  659. else:
  660. layer_known_edges_names.extend([attn_qkv_weight, attn_qkv_bias])
  661. layer_known_edges_names.extend(
  662. [attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias]
  663. )
  664. layer_known_edges_names.extend(
  665. ["attention_mask", "step", "seqlens_k", "total_sequence_length", "input_metadata", "position_ids"]
  666. )
  667. subgraph_nodes = []
  668. subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"]))
  669. subgraph_nodes.extend(self.gemm(["attn_out", attn_out_weight, attn_out_bias], ["attn_add_out"], "OutProj_"))
  670. subgraph_nodes.extend(self.gemm(["ln_out", mlp_fc1_weight, mlp_fc1_bias], ["fc1_out"], "FC1_"))
  671. subgraph_nodes.extend(self.fastgelu(["fc1_out"], ["gelu_out"]))
  672. subgraph_nodes.extend(self.gemm(["gelu_out", mlp_fc2_weight, mlp_fc2_bias], ["fc2_out"], "FC2_"))
  673. subgraph_nodes.extend(self.add(["attn_add_out", "fc2_out"], ["residual_1_out"], "Residual_1"))
  674. subgraph_nodes.extend(self.add([i_hidden_states, "residual_1_out"], [o_hidden_states], "Residual_2"))
  675. if self.attn_op_type != AttentionOpType.Attention:
  676. subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
  677. subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
  678. subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_"))
  679. # vllm engine requires full position ids as the input
  680. pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step"
  681. subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_"))
  682. subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_"))
  683. if self.attn_op_type == AttentionOpType.MultiHeadAttention:
  684. subgraph_nodes.extend(
  685. self.mha(
  686. ["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache],
  687. ["attn_out", o_key_cache, o_value_cache],
  688. )
  689. )
  690. elif self.attn_op_type == AttentionOpType.GroupQueryAttention:
  691. subgraph_nodes.extend(
  692. self.gqa(
  693. [
  694. "query_rot",
  695. "key_rot",
  696. "value",
  697. i_key_cache,
  698. i_value_cache,
  699. "seqlens_k",
  700. "total_sequence_length",
  701. ],
  702. ["attn_out", o_key_cache, o_value_cache],
  703. )
  704. )
  705. if layer_id == 0:
  706. gqa_aux_nodes = self.get_gqa_aux_nodes()
  707. for new_node in gqa_aux_nodes:
  708. self.nodes_to_add.append(new_node)
  709. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  710. self.model.add_initializer(
  711. numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name
  712. )
  713. elif self.attn_op_type == AttentionOpType.PagedAttention:
  714. subgraph_nodes.extend(
  715. self.paged_attn(
  716. ["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "input_metadata"],
  717. ["attn_out"],
  718. )
  719. )
  720. else:
  721. past_name = f"past_{layer_id}"
  722. present_name = f"present_{layer_id}"
  723. layer_known_edges_names.extend([past_name, present_name])
  724. subgraph_nodes.extend(
  725. self.attention(
  726. ["ln_out", attn_qkv_weight, attn_qkv_bias, "attention_mask", past_name], ["attn_out", present_name]
  727. )
  728. )
  729. self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names)
  730. self.replace_fp32_value_info(i_hidden_states, ["batch_size", "seq_len", "hidden_size"])
  731. self.replace_fp32_value_info(o_hidden_states, ["batch_size", "seq_len", "hidden_size"])
  732. self.nodes_to_remove.append(node)
  733. self.prune_graph = True
  734. class PhiOnnxModel(OnnxModel):
  735. def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
  736. super().__init__(model)
  737. self.phi2_preprocessor = Phi2PreProcessor(self.model, num_heads, hidden_size)
  738. self.fission_transformer_block = FissionTransformerBlockPhi(self, num_heads)
  739. self.fission_causal_lm_head = FissionTransformerCausalLMHeadPhi(self)
  740. self.fission_transformer_layernorm = FissionTransformerLayerNormPhi(self)
  741. self.fission_transformer_embedding = FissionTransformerEmbeddingPhi(self)
  742. def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False):
  743. assert options is not None
  744. attn_op_type = options.attention_op_type
  745. self.fission_transformer_block.set_attention_op_type(attn_op_type)
  746. self.phi2_preprocessor.preprocess_onnx(attn_op_type)
  747. self.fission_transformer_block.apply()
  748. self.fission_transformer_layernorm.apply()
  749. self.fission_causal_lm_head.apply()
  750. self.fission_transformer_embedding.apply()
  751. super().prune_graph()
  752. # SLN ctor is placed here intentionally to delay the symbolic shape inference
  753. self.fuse_sln = FusionSkipLayerNormalization(self)
  754. self.fuse_bias_sln = FusionBiasSkipLayerNormalization(self)
  755. self.fuse_sln.apply()
  756. self.fuse_bias_sln.apply()
  757. def get_fused_operator_statistics(self):
  758. """
  759. Returns node count of fused operators.
  760. """
  761. op_count = {}
  762. ops = [
  763. "Attention",
  764. "MultiHeadAttention",
  765. "GroupQueryAttention",
  766. "PagedAttention",
  767. "Gelu",
  768. "BiasGelu",
  769. "FastGelu",
  770. "LayerNormalization",
  771. "SkipLayerNormalization",
  772. ]
  773. for op in ops:
  774. nodes = self.get_nodes_by_op_type(op)
  775. op_count[op] = len(nodes)
  776. logger.info(f"Optimized operators: {op_count}")
  777. return op_count
  778. def is_fully_optimized(self, fused_op_count=None):
  779. """
  780. Returns True when the model is fully optimized.
  781. """
  782. if fused_op_count is None:
  783. fused_op_count = self.get_fused_operator_statistics()
  784. def op_count(op_name: str):
  785. return fused_op_count.get(op_name) or 0
  786. attention = (
  787. op_count("Attention")
  788. + op_count("MultiHeadAttention")
  789. + op_count("GroupQueryAttention")
  790. + op_count("PagedAttention")
  791. )
  792. gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
  793. layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")
  794. is_perfect = (attention > 0) and (attention == gelu) and (layer_norm >= attention)
  795. if layer_norm == 0:
  796. logger.debug("Layer Normalization not fused")
  797. if gelu == 0:
  798. logger.debug("Gelu (or FastGelu) not fused")
  799. if attention == 0:
  800. logger.warning("Attention (or MultiHeadAttention) not fused")
  801. return is_perfect