onnx_model_gpt2.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. import onnx
  7. from fusion_gpt_attention import FusionGptAttention
  8. from fusion_gpt_attention_megatron import FusionGptAttentionMegatron
  9. from fusion_gpt_attention_no_past import FusionGptAttentionNoPast
  10. from fusion_rotary_attention import FusionRotaryAttention
  11. from onnx_model_bert import BertOnnxModel
  12. logger = logging.getLogger(__name__)
  13. class Gpt2OnnxModel(BertOnnxModel):
  14. def __init__(self, model, num_heads, hidden_size):
  15. super().__init__(model, num_heads, hidden_size)
  16. def fuse_attention(self):
  17. if len(self.model.graph.input) == 1 or len(self.model.graph.output) == 1:
  18. fusion = FusionGptAttentionNoPast(self, self.num_heads)
  19. fusion.apply()
  20. else:
  21. fusion = FusionGptAttention(self, self.num_heads)
  22. fusion.apply()
  23. fusion = FusionGptAttentionMegatron(self, self.num_heads)
  24. fusion.apply()
  25. fusion = FusionRotaryAttention(self, self.hidden_size, self.num_heads)
  26. fusion.apply()
  27. def postprocess(self):
  28. """
  29. Remove extra reshape nodes.
  30. """
  31. logger.debug("start postprocessing...")
  32. input_name_to_nodes = self.input_name_to_nodes()
  33. output_name_to_node = self.output_name_to_node()
  34. reshape_count = 0
  35. for gemm_node in self.get_nodes_by_op_type("Gemm"):
  36. reshape_after_gemm = self.find_first_child_by_type(
  37. gemm_node, "Reshape", input_name_to_nodes, recursive=False
  38. )
  39. nodes = self.match_parent_path(gemm_node, ["Reshape", "FastGelu"], [0, 0], output_name_to_node)
  40. if nodes is None:
  41. nodes = self.match_parent_path(
  42. gemm_node,
  43. ["Reshape", "LayerNormalization"],
  44. [0, 0],
  45. output_name_to_node,
  46. )
  47. if nodes is None:
  48. nodes = self.match_parent_path(
  49. gemm_node,
  50. ["Reshape", "SkipLayerNormalization"],
  51. [0, 0],
  52. output_name_to_node,
  53. )
  54. if nodes is None:
  55. continue
  56. (reshape_before_gemm, root_node) = nodes
  57. matmul_node_name = self.create_node_name("MatMul", "FullyConnect_MatMul")
  58. matmul_node = onnx.helper.make_node(
  59. "MatMul",
  60. inputs=[matmul_node_name + "_input", gemm_node.input[1]],
  61. outputs=[matmul_node_name + "_output"],
  62. name=matmul_node_name,
  63. )
  64. add_node_name = self.create_node_name("Add", "FullyConnect_Add")
  65. add_node = onnx.helper.make_node(
  66. "Add",
  67. inputs=[matmul_node_name + "_output", gemm_node.input[2]],
  68. outputs=[add_node_name + "_output"],
  69. name=add_node_name,
  70. )
  71. self.replace_input_of_all_nodes(reshape_after_gemm.output[0], add_node_name + "_output")
  72. # Link root node output with MatMul
  73. self.replace_input_of_all_nodes(root_node.output[0], matmul_node_name + "_input")
  74. root_node.output[0] = matmul_node_name + "_input"
  75. self.replace_input_of_all_nodes(reshape_after_gemm.output[0], add_node_name + "_output")
  76. self.add_node(matmul_node)
  77. self.add_node(add_node)
  78. reshape_count += 2
  79. self.prune_graph()
  80. logger.info(f"postprocess: remove Reshape count: {reshape_count}")