fusion_attention_vae.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. import numpy as np
  7. from fusion_base import Fusion
  8. from onnx import NodeProto, TensorProto, helper, numpy_helper
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionAttentionVae(Fusion):
  12. """
  13. Fuse Attention subgraph of Vae Decoder into one Attention node.
  14. """
  15. def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int):
  16. super().__init__(model, "Attention", ["Softmax"])
  17. self.hidden_size = hidden_size
  18. self.num_heads = num_heads
  19. # Flags to show warning only once
  20. self.num_heads_warning = True
  21. self.hidden_size_warning = True
  22. def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, add_q: NodeProto) -> tuple[int, int]:
  23. """Detect num_heads and hidden_size from a reshape node.
  24. Args:
  25. reshape_q (NodeProto): reshape node for Q
  26. add_q (NodeProto): add node for Q
  27. Returns:
  28. Tuple[int, int]: num_heads and hidden_size
  29. """
  30. concat = self.model.get_parent(reshape_q, 1)
  31. if concat is None or len(concat.input) != 4:
  32. return self.num_heads, self.hidden_size # Fall back to user specified value
  33. value = self.model.get_constant_value(concat.input[2])
  34. if not (value is not None and isinstance(value, np.ndarray) and value.size == 1):
  35. return self.num_heads, self.hidden_size # Fall back to user specified value
  36. num_heads = int(value)
  37. if num_heads <= 0:
  38. return self.num_heads, self.hidden_size # Fall back to user specified value
  39. _, bias = self.model.get_constant_input(add_q)
  40. if (bias is None) or (not isinstance(bias, np.ndarray)) or bias.ndim != 1:
  41. return self.num_heads, self.hidden_size # Fall back to user specified value
  42. hidden_size = bias.shape[0]
  43. if self.num_heads > 0 and num_heads != self.num_heads:
  44. if self.num_heads_warning:
  45. logger.warning(
  46. "Detected number of attention heads is %d. Ignore --num_heads %d", num_heads, self.num_heads
  47. )
  48. self.num_heads_warning = False # Do not show the warning more than once
  49. if self.hidden_size > 0 and hidden_size != self.hidden_size:
  50. if self.hidden_size_warning:
  51. logger.warning("Detected hidden size is %d. Ignore --hidden_size %d", hidden_size, self.hidden_size)
  52. self.hidden_size_warning = False # Do not show the warning more than once
  53. return num_heads, hidden_size
  54. def create_attention_node(
  55. self,
  56. q_matmul: NodeProto,
  57. q_add: NodeProto,
  58. k_matmul: NodeProto,
  59. k_add: NodeProto,
  60. v_matmul: NodeProto,
  61. v_add: NodeProto,
  62. num_heads: int,
  63. hidden_size: int,
  64. input_name: str,
  65. output_name: str,
  66. ) -> NodeProto | None:
  67. """Create an Attention node.
  68. Args:
  69. q_matmul (NodeProto): MatMul node in fully connection for Q
  70. q_add (NodeProto): Add bias node in fully connection for Q
  71. k_matmul (NodeProto): MatMul node in fully connection for K
  72. k_add (NodeProto): Add bias node in fully connection for K
  73. v_matmul (NodeProto): MatMul node in fully connection for V
  74. v_add (NodeProto): Add bias node in fully connection for V
  75. num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
  76. hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
  77. input_name (str): input name
  78. output_name (str): output name
  79. Returns:
  80. Union[NodeProto, None]: the node created or None if failed.
  81. """
  82. if q_matmul.input[0] != input_name or k_matmul.input[0] != input_name or v_matmul.input[0] != input_name:
  83. logger.debug(
  84. "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
  85. q_matmul.input[0],
  86. k_matmul.input[0],
  87. v_matmul.input[0],
  88. )
  89. return None
  90. if hidden_size > 0 and (hidden_size % num_heads) != 0:
  91. logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads)
  92. return None
  93. q_weight_tensor = self.model.get_initializer(q_matmul.input[1])
  94. k_weight_tensor = self.model.get_initializer(k_matmul.input[1])
  95. v_weight_tensor = self.model.get_initializer(v_matmul.input[1])
  96. if not (q_weight_tensor and k_weight_tensor and v_weight_tensor):
  97. return None
  98. q_bias_tensor = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
  99. k_bias_tensor = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
  100. v_bias_tensor = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
  101. q_bias = numpy_helper.to_array(q_bias_tensor)
  102. k_bias = numpy_helper.to_array(k_bias_tensor)
  103. v_bias = numpy_helper.to_array(v_bias_tensor)
  104. q_bias_shape = np.prod(q_bias.shape)
  105. k_bias_shape = np.prod(k_bias.shape)
  106. v_bias_shape = np.prod(v_bias.shape)
  107. # Sometimes weights are stored in fp16
  108. if q_weight_tensor.data_type == 10:
  109. logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
  110. return None
  111. q_weight = numpy_helper.to_array(q_weight_tensor)
  112. k_weight = numpy_helper.to_array(k_weight_tensor)
  113. v_weight = numpy_helper.to_array(v_weight_tensor)
  114. # assert q and k have same shape as expected
  115. if q_weight.shape != k_weight.shape or q_weight.shape != v_weight.shape:
  116. return None
  117. qw_in_size = q_weight.shape[0]
  118. kw_in_size = k_weight.shape[0]
  119. vw_in_size = v_weight.shape[0]
  120. assert qw_in_size == kw_in_size and kw_in_size == vw_in_size
  121. if hidden_size > 0 and hidden_size != qw_in_size:
  122. raise ValueError(
  123. f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
  124. "Please provide a correct input hidden size or pass in 0"
  125. )
  126. # All the matrices can have the same shape or q, k matrics can have the same shape with v being different
  127. # For 2d weights, the shapes would be [in_size, out_size].
  128. # For 3d weights, shape would be [in_size, a, b] where a*b = out_size
  129. qw_out_size = np.prod(q_weight.shape[1:])
  130. qkv_weight = np.stack((q_weight, k_weight, v_weight), axis=1)
  131. qkv_weight_dim = 3 * int(qw_out_size)
  132. attention_node_name = self.model.create_node_name("Attention")
  133. assert q_bias_shape == k_bias_shape == v_bias_shape
  134. qkv_bias_dim = 0
  135. qkv_bias = np.stack((q_bias, k_bias, v_bias), axis=0)
  136. qkv_bias_dim = 3 * q_bias_shape
  137. self.add_initializer(
  138. name=attention_node_name + "_qkv_weight",
  139. data_type=TensorProto.FLOAT,
  140. dims=[qw_in_size, qkv_weight_dim],
  141. vals=qkv_weight,
  142. )
  143. # No bias, use zeros
  144. qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
  145. qkv_bias_dim = 3 * hidden_size
  146. self.add_initializer(
  147. name=attention_node_name + "_qkv_bias",
  148. data_type=TensorProto.FLOAT,
  149. dims=[qkv_bias_dim],
  150. vals=qkv_bias,
  151. )
  152. attention_inputs = [
  153. input_name,
  154. attention_node_name + "_qkv_weight",
  155. attention_node_name + "_qkv_bias",
  156. ]
  157. attention_node = helper.make_node(
  158. "Attention",
  159. inputs=attention_inputs,
  160. outputs=[output_name],
  161. name=attention_node_name,
  162. )
  163. attention_node.domain = "com.microsoft"
  164. attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
  165. self.increase_counter("Attention (self attention)")
  166. return attention_node
  167. def fuse(self, softmax_node, input_name_to_nodes, output_name_to_node):
  168. matmul_qkv = self.model.find_first_child_by_type(softmax_node, "MatMul", input_name_to_nodes, recursive=False)
  169. if matmul_qkv is None:
  170. return
  171. reshape_qkv = self.model.find_first_child_by_type(matmul_qkv, "Reshape", input_name_to_nodes, recursive=False)
  172. if reshape_qkv is None:
  173. return
  174. transpose_qkv = self.model.find_first_child_by_type(
  175. reshape_qkv, "Transpose", input_name_to_nodes, recursive=False
  176. )
  177. if transpose_qkv is None:
  178. return
  179. reshape_out = self.model.find_first_child_by_type(
  180. transpose_qkv, "Reshape", input_name_to_nodes, recursive=False
  181. )
  182. if reshape_out is None:
  183. return
  184. matmul_out = self.model.find_first_child_by_type(reshape_out, "MatMul", input_name_to_nodes, recursive=False)
  185. if matmul_out is None:
  186. return
  187. add_out = self.model.find_first_child_by_type(matmul_out, "Add", input_name_to_nodes, recursive=False)
  188. if add_out is None:
  189. return
  190. transpose_out = self.model.find_first_child_by_type(add_out, "Transpose", input_name_to_nodes, recursive=False)
  191. if transpose_out is None:
  192. return
  193. v_nodes = self.model.match_parent_path(
  194. matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None]
  195. )
  196. if v_nodes is None:
  197. logger.debug("fuse_attention: failed to match v path")
  198. return
  199. (_, _, _, add_v, matmul_v) = v_nodes
  200. qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
  201. if qk_nodes is not None:
  202. (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
  203. else:
  204. logger.debug("fuse_attention: failed to match qk path")
  205. return
  206. q_nodes = self.model.match_parent_path(
  207. matmul_qk, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0, None]
  208. )
  209. if q_nodes is None:
  210. logger.debug("fuse_attention: failed to match q path")
  211. return
  212. (_, _transpose_q, reshape_q, add_q, matmul_q) = q_nodes
  213. k_nodes = self.model.match_parent_path(
  214. matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None]
  215. )
  216. if k_nodes is None:
  217. logger.debug("fuse_attention: failed to match k path")
  218. return
  219. (_, _, _, _, add_k, matmul_k) = k_nodes
  220. attention_last_node = reshape_out
  221. q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, add_q)
  222. if q_num_heads <= 0:
  223. logger.debug("fuse_attention: failed to detect num_heads")
  224. return
  225. # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
  226. new_node = self.create_attention_node(
  227. matmul_q,
  228. add_q,
  229. matmul_k,
  230. add_k,
  231. matmul_v,
  232. add_v,
  233. q_num_heads,
  234. q_hidden_size,
  235. matmul_q.input[0],
  236. attention_last_node.output[0],
  237. )
  238. if new_node is None:
  239. return
  240. self.nodes_to_add.append(new_node)
  241. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  242. self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
  243. # Use prune graph to remove nodes since they are shared by all attention nodes.
  244. self.prune_graph = True