onnx_model_bart.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. from fusion_attention import AttentionMask
  7. from fusion_bart_attention import FusionBartAttention
  8. from fusion_options import FusionOptions
  9. from fusion_reshape import FusionReshape
  10. from onnx import numpy_helper
  11. from onnx_model import OnnxModel
  12. from onnx_model_bert import BertOnnxModel
  13. logger = logging.getLogger(__name__)
  14. class FusionBartReshape(FusionReshape):
  15. def __init__(self, model: OnnxModel):
  16. super().__init__(model)
  17. def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node):
  18. if reshape_node.input[1] not in output_name_to_node:
  19. return
  20. concat_node = output_name_to_node[reshape_node.input[1]]
  21. if concat_node.op_type != "Concat" or len(concat_node.input) != 4:
  22. return
  23. path0 = self.model.match_parent_path(
  24. concat_node,
  25. ["Unsqueeze", "Gather", "Shape"],
  26. [0, 0, 0],
  27. output_name_to_node,
  28. )
  29. if path0 is None:
  30. return
  31. (_, gather_0, shape_0) = path0
  32. shape = []
  33. gather_value = self.model.get_constant_value(gather_0.input[1])
  34. if gather_value == 0:
  35. shape.append(0)
  36. path1 = self.model.match_parent_path(
  37. concat_node,
  38. ["Unsqueeze", "Gather", "Shape"],
  39. [1, 0, 0],
  40. output_name_to_node,
  41. )
  42. if path1 is None:
  43. input_1_proto = self.model.get_initializer(concat_node.input[1])
  44. input_2_proto = self.model.get_initializer(concat_node.input[2])
  45. input_3_proto = self.model.get_initializer(concat_node.input[3])
  46. if input_1_proto is None or input_2_proto is None or input_3_proto is None:
  47. return
  48. input_1 = numpy_helper.to_array(input_1_proto)
  49. input_2 = numpy_helper.to_array(input_2_proto)
  50. input_3 = numpy_helper.to_array(input_3_proto)
  51. if len(input_1) != 1 or len(input_2) != 1 or len(input_3) != 1:
  52. return
  53. if not (input_1[0] == -1 and input_2[0] > 0 and input_3[0] > 0):
  54. return
  55. shape.extend(input_1)
  56. shape.extend(input_2)
  57. shape.extend(input_3)
  58. gemm_path_with_bias = self.model.match_parent_path(
  59. reshape_node, ["Add", "MatMul"], [0, 1], output_name_to_node
  60. )
  61. gemm_path_no_bias = self.model.match_parent_path(reshape_node, ["MatMul"], [0], output_name_to_node)
  62. if gemm_path_with_bias is not None:
  63. gemm_path = gemm_path_with_bias
  64. elif gemm_path_no_bias is not None:
  65. gemm_path = gemm_path_no_bias
  66. else:
  67. return
  68. top_matmul = gemm_path[-1]
  69. root_input = top_matmul.input[0]
  70. self.replace_reshape_node(shape, reshape_node, concat_node)
  71. else:
  72. (_, gather_1, shape_1) = path1
  73. gather_value = self.model.get_constant_value(gather_1.input[1])
  74. if gather_value == 1:
  75. shape.append(0)
  76. input_2_proto = self.model.get_initializer(concat_node.input[2])
  77. input_3_proto = self.model.get_initializer(concat_node.input[3])
  78. if input_2_proto is None or input_3_proto is None:
  79. return
  80. input_2 = numpy_helper.to_array(input_2_proto)
  81. input_3 = numpy_helper.to_array(input_3_proto)
  82. if len(input_2) != 1 or len(input_3) != 1:
  83. return
  84. if not (input_2[0] > 0 and input_3[0] > 0):
  85. return
  86. shape.extend(input_2)
  87. shape.extend(input_3)
  88. gemm_path = self.model.match_parent_path(
  89. reshape_node, ["Mul", "Add", "MatMul"], [0, 0, 1], output_name_to_node
  90. )
  91. if gemm_path is None:
  92. return
  93. top_matmul = gemm_path[-1]
  94. root_input = top_matmul.input[0]
  95. if shape_0.input[0] != root_input or shape_1.input[0] != root_input:
  96. return
  97. self.replace_reshape_node(shape, reshape_node, concat_node)
  98. class BartOnnxModel(BertOnnxModel):
  99. def __init__(self, model, num_heads, hidden_size, model_impl="hf"):
  100. super().__init__(model, num_heads, hidden_size)
  101. self.attention_mask = AttentionMask(self)
  102. self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
  103. self.bart_reshape_fusion_preprocess = FusionBartReshape(self)
  104. def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False):
  105. self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention
  106. self.attention_fusion.disable_multi_head_attention_bias = (
  107. False if options is None else options.disable_multi_head_attention_bias
  108. )
  109. super().optimize(options, add_dynamic_axes)
  110. def fuse_attention(self):
  111. self.attention_fusion.apply()
  112. def preprocess(self):
  113. self.adjust_reshape_and_expand()
  114. self.bart_reshape_fusion_preprocess.apply()