fusion_attention_unet.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. import numpy as np
  7. from fusion_base import Fusion
  8. from fusion_utils import NumpyHelper
  9. from onnx import NodeProto, TensorProto, helper
  10. from onnx_model import OnnxModel
  11. logger = getLogger(__name__)
  12. class FusionAttentionUnet(Fusion):
  13. """
  14. Fuse Attention subgraph of UNet into one Attention node.
  15. """
  def __init__(
      self,
      model: OnnxModel,
      hidden_size: int,
      num_heads: int,
      is_cross_attention: bool,
      enable_packed_qkv: bool,
      enable_packed_kv: bool,
  ):
      """Initialize the UNet attention fusion.

      Args:
          model (OnnxModel): model to optimize.
          hidden_size (int): user specified hidden size. It serves as a fallback when the value
              cannot be detected from the graph (see get_num_heads_and_hidden_size).
          num_heads (int): user specified number of attention heads. It serves as a fallback when
              the value cannot be detected from the graph (see get_num_heads_and_hidden_size).
          is_cross_attention (bool): fuse cross attention (K/V computed from a different input than Q)
              instead of self attention.
          enable_packed_qkv (bool): for self attention, pack Q/K/V weights into one MatMul and emit
              MultiHeadAttention with a packed QKV input.
          enable_packed_kv (bool): for cross attention, pack K/V weights into one MatMul and emit
              MultiHeadAttention with a packed KV input.
      """
      # create_attention_node emits an Attention node only for self attention without packed QKV;
      # every other configuration emits MultiHeadAttention. Pass the matching op type to the base class.
      fused_op_type = "Attention" if not is_cross_attention and not enable_packed_qkv else "MultiHeadAttention"
      super().__init__(model, fused_op_type, ["LayerNormalization"])
      self.hidden_size = hidden_size
      self.num_heads = num_heads
      self.is_cross_attention = is_cross_attention

      # Note: pack Q/K/V or K/V weights into one tensor make it harder for updating initializers for LoRA.
      # To support LoRA, it is better to use separated Q, K and V inputs in offline optimization,
      # and CUDA operator pre-packs those tensors to preferred format based on available kernels.
      # In this way, we can support LoRA and get optimal performance at same time.
      self.enable_packed_qkv = enable_packed_qkv
      self.enable_packed_kv = enable_packed_kv

      # Flags to show the "detected value differs from user value" warnings only once.
      self.num_heads_warning = True
      self.hidden_size_warning = True
  def get_num_heads(self, reshape_q: NodeProto, is_torch2: bool = False) -> int:
      """Detect num_heads from a reshape node.

      Args:
          reshape_q (NodeProto): reshape node for Q
          is_torch2 (bool): graph pattern is from PyTorch 2.*

      Returns:
          int: num_heads, or 0 if not found
      """
      num_heads = 0
      if is_torch2:
          # The target shape of the Reshape comes from a Concat with 4 inputs; the third input is
          # taken as num_heads (layout presumed [batch, seq_len, num_heads, head_size]).
          reshape_parent = self.model.get_parent(reshape_q, 1)
          if reshape_parent and reshape_parent.op_type == "Concat" and len(reshape_parent.input) == 4:
              value = self.model.get_constant_value(reshape_parent.input[2])
              # Extract the single element explicitly: calling int() on an array with ndim > 0
              # is deprecated since NumPy 1.25.
              if isinstance(value, np.ndarray) and value.size == 1:
                  num_heads = int(value.reshape(-1)[0])
      else:
          # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
          q_shape_value = self.model.get_constant_value(reshape_q.input[1])
          if isinstance(q_shape_value, np.ndarray) and list(q_shape_value.shape) == [4]:
              num_heads = int(q_shape_value[2])

      return num_heads if num_heads > 0 else 0
  def get_hidden_size(self, layernorm_node) -> int:
      """Detect hidden_size from LayerNormalization node.

      The bias (input 2) is inspected first. Bias is an optional input of LayerNormalization, so
      the scale (input 1) is used as a fallback. Both have the shape of the normalized dimensions,
      i.e. [hidden_size] when normalizing over the last axis.

      Args:
          layernorm_node (NodeProto): LayerNormalization node before Q, K and V

      Returns:
          int: hidden_size, or 0 if not found
      """
      node_inputs = layernorm_node.input
      for input_index in (2, 1):
          # Optional inputs can be missing entirely or present with an empty name.
          if input_index >= len(node_inputs) or not node_inputs[input_index]:
              continue
          initializer = self.model.get_initializer(node_inputs[input_index])
          if initializer is None:
              continue
          shape = NumpyHelper.to_array(initializer).shape
          if shape:
              return int(shape[0])
      return 0
  def get_num_heads_and_hidden_size(
      self, reshape_q: NodeProto, layernorm_node: NodeProto, is_torch2: bool = False
  ) -> tuple[int, int]:
      """Resolve num_heads and hidden_size, preferring values detected from the graph.

      A positive detected value always wins. The value given to the constructor is used only
      when detection yields no positive value. When both are positive and disagree, a warning
      is logged, at most once per setting for this fusion object.

      Args:
          reshape_q (NodeProto): reshape node for Q
          layernorm_node (NodeProto): LayerNormalization node before Q, K, V
          is_torch2 (bool): graph pattern is from PyTorch 2.*

      Returns:
          Tuple[int, int]: num_heads and hidden_size
      """
      detected_heads = self.get_num_heads(reshape_q, is_torch2)
      if detected_heads <= 0:
          # Nothing detected: fall back to the user specified value; there is nothing to warn about.
          detected_heads = self.num_heads
      elif self.num_heads > 0 and detected_heads != self.num_heads and self.num_heads_warning:
          logger.warning(
              f"--num_heads is {self.num_heads}. Detected value is {detected_heads}. Using detected value."
          )
          self.num_heads_warning = False  # warn only once

      detected_hidden = self.get_hidden_size(layernorm_node)
      if detected_hidden <= 0:
          # Nothing detected: fall back to the user specified value; there is nothing to warn about.
          detected_hidden = self.hidden_size
      elif self.hidden_size > 0 and detected_hidden != self.hidden_size and self.hidden_size_warning:
          logger.warning(
              f"--hidden_size is {self.hidden_size}. Detected value is {detected_hidden}. Using detected value."
          )
          self.hidden_size_warning = False  # warn only once

      return detected_heads, detected_hidden
  def create_attention_node(
      self,
      q_matmul: NodeProto,
      k_matmul: NodeProto,
      v_matmul: NodeProto,
      num_heads: int,
      hidden_size: int,
      input: str,
      output: str,
  ) -> NodeProto | None:
      """Create an Attention or MultiHeadAttention node in the com.microsoft domain.

      The node depends on the configuration:
        * self attention, not packed: Attention whose Q/K/V weights are merged into one
          initializer, plus a zero bias;
        * self attention, packed QKV: MultiHeadAttention fed by a new MatMul + Reshape that
          produces a packed [B, S, N, 3, H] QKV tensor. The Q/K/V MatMuls are scheduled for removal;
        * cross attention, packed KV: MultiHeadAttention fed by the Q MatMul output and a new
          MatMul + Reshape that produces a packed [B, S_kv, N, 2, H] KV tensor. The K/V MatMuls are
          scheduled for removal;
        * cross attention, not packed: MultiHeadAttention fed by the Q, K and V MatMul outputs and
          a zero bias.

      Args:
          q_matmul (NodeProto): MatMul node in fully connection for Q
          k_matmul (NodeProto): MatMul node in fully connection for K
          v_matmul (NodeProto): MatMul node in fully connection for V
          num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
          hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
              When it is 0, the check against the input dimension of the weights is skipped.
          input (str): input name
          output (str): output name

      Returns:
          Union[NodeProto, None]: the node created or None if failed.

      Raises:
          ValueError: for self attention, when hidden_size > 0 differs from the input dimension
              of the Q/K/V weights.
      """
      is_self_attention = not self.is_cross_attention

      if is_self_attention:
          if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
              logger.debug(
                  "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
                  q_matmul.input[0],
                  k_matmul.input[0],
                  v_matmul.input[0],
              )
              return None
      elif q_matmul.input[0] != input or k_matmul.input[0] != v_matmul.input[0] or k_matmul.input[0] == input:
          logger.debug(
              "For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s",
              q_matmul.input[0],
              k_matmul.input[0],
              v_matmul.input[0],
          )
          return None

      # A non-positive num_heads (e.g. neither detected nor specified) cannot be used as a divisor below.
      if num_heads <= 0:
          logger.debug("num_heads shall be positive. Got %s", num_heads)
          return None

      if hidden_size > 0 and (hidden_size % num_heads) != 0:
          logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
          return None

      q_weight = self.model.get_initializer(q_matmul.input[1])
      k_weight = self.model.get_initializer(k_matmul.input[1])
      v_weight = self.model.get_initializer(v_matmul.input[1])
      if q_weight is None or k_weight is None or v_weight is None:
          return None

      # Weights may be stored in fp16; new weight/bias initializers reuse the data type of the Q weight.
      float_type = q_weight.data_type
      qw = NumpyHelper.to_array(q_weight)
      kw = NumpyHelper.to_array(k_weight)
      vw = NumpyHelper.to_array(v_weight)
      logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")

      # For 2d weights the shape is [in_size, out_size]; for 3d weights it is [in_size, a, b] with a*b = out_size.
      qw_out_size = int(np.prod(qw.shape[1:]))
      kw_out_size = int(np.prod(kw.shape[1:]))
      vw_out_size = int(np.prod(vw.shape[1:]))
      # The zero bias (when used) must match the concatenated Q/K/V output sizes. The input hidden
      # size may be 0 (unknown) or differ from the output size.
      qkv_bias_dim = qw_out_size + kw_out_size + vw_out_size

      if is_self_attention:
          # Q, K and V project the same input, so their weights must have identical shapes.
          if qw.shape != kw.shape or qw.shape != vw.shape:
              return None

          qw_in_size = qw.shape[0]
          if hidden_size > 0 and hidden_size != qw_in_size:
              raise ValueError(
                  f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
                  "Please provide a correct input hidden size or pass in 0"
              )

          if self.enable_packed_qkv:
              if qw_out_size % num_heads != 0:
                  logger.debug("q/k/v output size %s is not a multiple of num of heads %s", qw_out_size, num_heads)
                  return None
              attention_node_name = self.model.create_node_name("MultiHeadAttention")
              self._add_packed_projection(
                  k_matmul.input[0],
                  [qw, kw, vw],
                  num_heads,
                  float_type,
                  "MatMul_QKV",
                  attention_node_name + "_qkv_input",
              )
              self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul])
              op_type = "MultiHeadAttention"
              attention_inputs = [attention_node_name + "_qkv_input"]
              counter_name = "MultiHeadAttention (self attention with packed qkv)"
          else:
              attention_node_name = self.model.create_node_name("Attention")
              # Stacking on axis 1 lays out each row as [q_row, k_row, v_row], which flattens to the
              # merged [in_size, 3 * out_size] weight.
              qkv_weight = np.stack((qw, kw, vw), axis=1)
              self.add_initializer(
                  name=attention_node_name + "_qkv_weight",
                  data_type=float_type,
                  dims=[qw_in_size, 3 * qw_out_size],
                  vals=qkv_weight,
              )
              # No bias in the original graph, use zeros.
              self.add_initializer(
                  name=attention_node_name + "_qkv_bias",
                  data_type=float_type,
                  dims=[qkv_bias_dim],
                  vals=np.zeros(qkv_bias_dim, dtype=np.float32),
              )
              op_type = "Attention"
              attention_inputs = [
                  input,
                  attention_node_name + "_qkv_weight",
                  attention_node_name + "_qkv_bias",
              ]
              counter_name = "Attention (self attention)"
      else:  # cross attention
          op_type = "MultiHeadAttention"
          if self.enable_packed_kv:
              # K and V are packed into one MatMul, so their weights must match each other. Q must have
              # the same output size so that Q and the packed K/V share the same head size.
              if kw.shape != vw.shape or qw_out_size != kw_out_size or kw_out_size % num_heads != 0:
                  logger.debug(
                      "Cannot pack k/v weights: qw=%s kw=%s vw=%s num_heads=%s",
                      qw.shape,
                      kw.shape,
                      vw.shape,
                      num_heads,
                  )
                  return None
              attention_node_name = self.model.create_node_name("MultiHeadAttention")
              self._add_packed_projection(
                  k_matmul.input[0],
                  [kw, vw],
                  num_heads,
                  float_type,
                  "MatMul_KV",
                  attention_node_name + "_kv_input",
              )
              self.nodes_to_remove.extend([k_matmul, v_matmul])
              attention_inputs = [q_matmul.output[0], attention_node_name + "_kv_input"]
              counter_name = "MultiHeadAttention (cross attention with packed kv)"
          else:
              attention_node_name = self.model.create_node_name("MultiHeadAttention")
              # No bias in the original graph, use zeros.
              self.add_initializer(
                  name=attention_node_name + "_qkv_bias",
                  data_type=float_type,
                  dims=[qkv_bias_dim],
                  vals=np.zeros(qkv_bias_dim, dtype=np.float32),
              )
              attention_inputs = [
                  q_matmul.output[0],
                  k_matmul.output[0],
                  v_matmul.output[0],
                  attention_node_name + "_qkv_bias",
              ]
              counter_name = "MultiHeadAttention (cross attention)"

      attention_node = helper.make_node(
          op_type,
          inputs=attention_inputs,
          outputs=[output],
          name=attention_node_name,
      )
      attention_node.domain = "com.microsoft"
      attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
      self.increase_counter(counter_name)
      return attention_node

  def _add_packed_projection(
      self,
      input_name: str,
      weights: list[np.ndarray],
      num_heads: int,
      data_type: int,
      name_prefix: str,
      output_name: str,
  ) -> None:
      """Add a MatMul + Reshape pair that computes several projections at once, interleaved per head.

      Each weight of shape [C, N * H] (or [C, a, b] with a * b = N * H) is viewed as [C, N, H]. The
      views are concatenated on the last axis, so for every head the P = len(weights) projections
      are adjacent. The MatMul output is then reshaped with target shape [0, 0, N, P, H].

      Args:
          input_name (str): name of the tensor to project.
          weights (list[np.ndarray]): weight arrays of identical shape (Q/K/V or K/V), in packing order.
          num_heads (int): number of attention heads N; must divide the weights' output size.
          data_type (int): TensorProto data type of the packed weight initializer.
          name_prefix (str): prefix used to create a unique MatMul node name.
          output_name (str): name of the packed tensor produced by the Reshape node.
      """
      c = weights[0].shape[0]
      n = num_heads
      h = int(np.prod(weights[0].shape[1:])) // n
      p = len(weights)
      # dstack of [C, N, H] arrays gives [C, N, P * H]; flattening interleaves the P projections per head.
      packed_weight = np.dstack([w.reshape(c, n, h) for w in weights]).reshape(c, n * p * h)

      matmul_node_name = self.model.create_node_name("MatMul", name_prefix=name_prefix)
      self.add_initializer(
          name=matmul_node_name + "_weight",
          data_type=data_type,
          dims=[packed_weight.shape[0], packed_weight.shape[1]],
          vals=packed_weight,
      )
      matmul_node = helper.make_node(
          "MatMul",
          inputs=[input_name, matmul_node_name + "_weight"],
          outputs=[matmul_node_name + "_out"],
          name=matmul_node_name,
      )
      self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name

      self.add_initializer(
          name=matmul_node_name + "_reshape_shape",
          data_type=TensorProto.INT64,
          dims=[5],
          vals=[0, 0, n, p, h],
          raw=False,
      )
      reshape_node = helper.make_node(
          "Reshape",
          inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"],
          outputs=[output_name],
          name=matmul_node_name + "_reshape",
      )
      self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
      self.nodes_to_add.extend([matmul_node, reshape_node])
  328. def create_attention_node_lora(
  329. self,
  330. q_matmul_add: NodeProto,
  331. k_matmul_add: NodeProto,
  332. v_matmul_add: NodeProto,
  333. num_heads: int,
  334. hidden_size: int,
  335. input: str,
  336. output: str,
  337. ) -> NodeProto | None:
  338. """Create an Attention node.
  339. Args:
  340. q_matmul (NodeProto): MatMul node in fully connection for Q
  341. k_matmul (NodeProto): MatMul node in fully connection for K
  342. v_matmul (NodeProto): MatMul node in fully connection for V
  343. num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
  344. hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
  345. input (str): input name
  346. output (str): output name
  347. Returns:
  348. Union[NodeProto, None]: the node created or None if failed.
  349. """
  350. is_self_attention = not self.is_cross_attention
  351. q_matmul = self.model.match_parent(q_matmul_add, "MatMul", 0)
  352. k_matmul = self.model.match_parent(k_matmul_add, "MatMul", 0)
  353. v_matmul = self.model.match_parent(v_matmul_add, "MatMul", 0)
  354. q_lora_nodes = self.match_lora_path(q_matmul_add)
  355. if q_lora_nodes is None:
  356. return None
  357. (q_lora_last_node, q_lora_matmul_1) = q_lora_nodes
  358. k_lora_nodes = self.match_lora_path(k_matmul_add)
  359. if k_lora_nodes is None:
  360. return None
  361. (k_lora_last_node, k_lora_matmul_1) = k_lora_nodes
  362. v_lora_nodes = self.match_lora_path(v_matmul_add)
  363. if v_lora_nodes is None:
  364. return None
  365. (v_lora_last_node, v_lora_matmul_1) = v_lora_nodes
  366. if is_self_attention:
  367. if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
  368. logger.debug(
  369. "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
  370. q_matmul.input[0],
  371. k_matmul.input[0],
  372. v_matmul.input[0],
  373. )
  374. return None
  375. if (
  376. q_lora_matmul_1.input[0] != input
  377. or k_lora_matmul_1.input[0] != input
  378. or v_lora_matmul_1.input[0] != input
  379. ):
  380. logger.debug(
  381. "For self attention, input hidden state for LoRA q and k/v weights shall be same. Got %s, %s, %s",
  382. q_lora_matmul_1.input[0],
  383. k_lora_matmul_1.input[0],
  384. v_lora_matmul_1.input[0],
  385. )
  386. return None
  387. else:
  388. if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input):
  389. logger.debug(
  390. "For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s",
  391. q_matmul.input[0],
  392. k_matmul.input[0],
  393. v_matmul.input[0],
  394. )
  395. return None
  396. if (
  397. q_lora_matmul_1.input[0] != input
  398. or (k_lora_matmul_1.input[0] != v_lora_matmul_1.input[0])
  399. or (k_matmul.input[0] == input)
  400. ):
  401. logger.debug(
  402. (
  403. "For cross attention, input hidden state for LoRA q and k/v weights shall be different. "
  404. "Got %s, %s, %s"
  405. ),
  406. q_lora_matmul_1.input[0],
  407. k_lora_matmul_1.input[0],
  408. v_lora_matmul_1.input[0],
  409. )
  410. return None
  411. if hidden_size > 0 and (hidden_size % num_heads) != 0:
  412. logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
  413. return None
  414. q_weight = self.model.get_initializer(q_matmul.input[1])
  415. k_weight = self.model.get_initializer(k_matmul.input[1])
  416. v_weight = self.model.get_initializer(v_matmul.input[1])
  417. if not (q_weight and k_weight and v_weight):
  418. return None
  419. # Sometimes weights are stored in fp16
  420. if q_weight.data_type == 10:
  421. logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
  422. return None
  423. qw = NumpyHelper.to_array(q_weight)
  424. kw = NumpyHelper.to_array(k_weight)
  425. vw = NumpyHelper.to_array(v_weight)
  426. logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")
  427. # assert q and k have same shape as expected
  428. if is_self_attention:
  429. if qw.shape != kw.shape or qw.shape != vw.shape:
  430. return None
  431. qw_in_size = qw.shape[0]
  432. if hidden_size > 0 and hidden_size != qw_in_size:
  433. raise ValueError(
  434. f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
  435. "Please provide a correct input hidden size or pass in 0"
  436. )
  437. # All the matrices can have the same shape or q, k matrics can have the same shape with v being different
  438. # For 2d weights, the shapes would be [in_size, out_size].
  439. # For 3d weights, shape would be [in_size, a, b] where a*b = out_size
  440. qw_out_size = int(np.prod(qw.shape[1:]))
  441. if self.enable_packed_qkv:
  442. attention_node_name = self.model.create_node_name("MultiHeadAttention")
  443. c = qw_in_size
  444. n = num_heads
  445. h = qw_out_size // num_heads
  446. # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 3, H] shape
  447. qkv_weight = np.dstack([qw.reshape(c, n, h), kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(
  448. c, n * 3 * h
  449. )
  450. matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
  451. self.add_initializer(
  452. name=matmul_node_name + "_weight",
  453. data_type=TensorProto.FLOAT,
  454. dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
  455. vals=qkv_weight,
  456. )
  457. matmul_node = helper.make_node(
  458. "MatMul",
  459. inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
  460. outputs=[matmul_node_name + "_out"],
  461. name=matmul_node_name,
  462. )
  463. self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
  464. # Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
  465. # the Q/K/V weights to be changed without having to re-run the optimizer.
  466. lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"
  467. self.add_initializer(
  468. name=lora_weight_shape_tensor_name,
  469. data_type=TensorProto.INT64,
  470. dims=[4],
  471. vals=[0, 0, n, h],
  472. raw=False,
  473. )
  474. # Reshape the LoRA Q weights
  475. q_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_Q")
  476. q_lora_reshape_node = helper.make_node(
  477. "Reshape",
  478. inputs=[q_lora_last_node.output[0], lora_weight_shape_tensor_name],
  479. outputs=[q_lora_reshape_node_name + "_out"],
  480. name=q_lora_reshape_node_name,
  481. )
  482. self.node_name_to_graph_name[q_lora_reshape_node.name] = self.this_graph_name
  483. # Reshape the LoRA K weights
  484. k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K")
  485. k_lora_reshape_node = helper.make_node(
  486. "Reshape",
  487. inputs=[k_lora_last_node.output[0], lora_weight_shape_tensor_name],
  488. outputs=[k_lora_reshape_node_name + "_out"],
  489. name=k_lora_reshape_node_name,
  490. )
  491. self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name
  492. # Reshape the LoRA V weights
  493. v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V")
  494. v_lora_reshape_node = helper.make_node(
  495. "Reshape",
  496. inputs=[v_lora_last_node.output[0], lora_weight_shape_tensor_name],
  497. outputs=[v_lora_reshape_node_name + "_out"],
  498. name=v_lora_reshape_node_name,
  499. )
  500. self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name
  501. # Concat the reshaped LoRA Q/K/V weights together on the third axis
  502. qkv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_QKV")
  503. qkv_lora_concat_node = helper.make_node(
  504. "Concat",
  505. inputs=[
  506. q_lora_reshape_node.output[0],
  507. k_lora_reshape_node.output[0],
  508. v_lora_reshape_node.output[0],
  509. ],
  510. outputs=[qkv_lora_concat_node_name + "_out"],
  511. name=qkv_lora_concat_node_name,
  512. )
  513. qkv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)])
  514. self.node_name_to_graph_name[qkv_lora_concat_node.name] = self.this_graph_name
  515. # Reshape the LoRA concatenated weights to [..., n * 3 * h]
  516. reshaped_lora_weights_shape_tensor_name = qkv_lora_concat_node.name + "_reshape_shape"
  517. self.add_initializer(
  518. name=reshaped_lora_weights_shape_tensor_name,
  519. data_type=TensorProto.INT64,
  520. dims=[3],
  521. vals=[0, 0, n * 3 * h],
  522. raw=False,
  523. )
  524. qkv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_QKV")
  525. qkv_lora_reshaped_node = helper.make_node(
  526. "Reshape",
  527. inputs=[qkv_lora_concat_node.output[0], reshaped_lora_weights_shape_tensor_name],
  528. outputs=[qkv_lora_reshaped_node_name + "_out"],
  529. name=qkv_lora_reshaped_node_name,
  530. )
  531. self.node_name_to_graph_name[qkv_lora_reshaped_node.name] = self.this_graph_name
  532. # Add the LoRA Q/K/V weights to the base Q/K/V weights
  533. add_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_QKV")
  534. add_weights_node = helper.make_node(
  535. "Add",
  536. inputs=[qkv_lora_reshaped_node.output[0], matmul_node.output[0]],
  537. outputs=[add_weights_node_name + "_out"],
  538. name=add_weights_node_name,
  539. )
  540. self.node_name_to_graph_name[add_weights_node.name] = self.this_graph_name
  541. # Finally, reshape the concatenated Q/K/V result to 5D
  542. shape_tensor_name = add_weights_node_name + "_reshape_shape"
  543. self.add_initializer(
  544. name=shape_tensor_name,
  545. data_type=TensorProto.INT64,
  546. dims=[5],
  547. vals=[0, 0, n, 3, h],
  548. raw=False,
  549. )
  550. reshape_node = helper.make_node(
  551. "Reshape",
  552. inputs=[add_weights_node.output[0], shape_tensor_name],
  553. outputs=[attention_node_name + "_qkv_input"],
  554. name=add_weights_node_name + "_reshape",
  555. )
  556. self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
  557. self.nodes_to_add.extend(
  558. [
  559. matmul_node,
  560. q_lora_reshape_node,
  561. k_lora_reshape_node,
  562. v_lora_reshape_node,
  563. qkv_lora_concat_node,
  564. qkv_lora_reshaped_node,
  565. add_weights_node,
  566. reshape_node,
  567. ]
  568. )
  569. self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul, q_matmul_add, k_matmul_add, v_matmul_add])
  570. else:
  571. # TODO: Support non-packed QKV
  572. return None
  573. else: # cross attention
  574. attention_node_name = self.model.create_node_name("MultiHeadAttention")
  575. if self.enable_packed_kv:
  576. if kw.shape != vw.shape:
  577. return None
  578. kw_in_size = kw.shape[0]
  579. vw_in_size = vw.shape[0]
  580. assert kw_in_size == vw_in_size
  581. qw_out_size = qw.shape[1]
  582. kw_out_size = kw.shape[1]
  583. vw_out_size = vw.shape[1]
  584. assert qw_out_size == vw_out_size and kw_out_size == vw_out_size
  585. c = kw_in_size
  586. n = num_heads
  587. h = kw_out_size // num_heads
  588. # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape
  589. kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h)
  590. matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
  591. self.add_initializer(
  592. name=matmul_node_name + "_weight",
  593. data_type=TensorProto.FLOAT,
  594. dims=[kv_weight.shape[0], kv_weight.shape[1]],
  595. vals=kv_weight,
  596. )
  597. matmul_node = helper.make_node(
  598. "MatMul",
  599. inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
  600. outputs=[matmul_node_name + "_out"],
  601. name=matmul_node_name,
  602. )
  603. self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
  604. # Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
  605. # the Q/K/V weights to be changed without having to re-run the optimizer.
  606. kv_lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"
  607. self.add_initializer(
  608. name=kv_lora_weight_shape_tensor_name,
  609. data_type=TensorProto.INT64,
  610. dims=[4],
  611. vals=[0, 0, n, h],
  612. raw=False,
  613. )
  614. # Reshape the LoRA K weights
  615. k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K")
  616. k_lora_reshape_node = helper.make_node(
  617. "Reshape",
  618. inputs=[k_lora_last_node.output[0], kv_lora_weight_shape_tensor_name],
  619. outputs=[k_lora_reshape_node_name + "_out"],
  620. name=k_lora_reshape_node_name,
  621. )
  622. self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name
  623. # Reshape the LoRA V weights
  624. v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V")
  625. v_lora_reshape_node = helper.make_node(
  626. "Reshape",
  627. inputs=[v_lora_last_node.output[0], kv_lora_weight_shape_tensor_name],
  628. outputs=[v_lora_reshape_node_name + "_out"],
  629. name=v_lora_reshape_node_name,
  630. )
  631. self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name
  632. # Concat the reshaped LoRA K/V weights together on the third axis
  633. kv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_KV")
  634. kv_lora_concat_node = helper.make_node(
  635. "Concat",
  636. inputs=[k_lora_reshape_node.output[0], v_lora_reshape_node.output[0]],
  637. outputs=[kv_lora_concat_node_name + "_out"],
  638. name=kv_lora_concat_node_name,
  639. )
  640. kv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)])
  641. self.node_name_to_graph_name[kv_lora_concat_node.name] = self.this_graph_name
  642. # Reshape the LoRA concatenated weights to [..., n * 2 * h]
  643. reshaped_kv_lora_weights_shape_tensor_name = kv_lora_concat_node.name + "_reshape_shape"
  644. self.add_initializer(
  645. name=reshaped_kv_lora_weights_shape_tensor_name,
  646. data_type=TensorProto.INT64,
  647. dims=[3],
  648. vals=[0, 0, n * 2 * h],
  649. raw=False,
  650. )
  651. kv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_KV")
  652. kv_lora_reshaped_node = helper.make_node(
  653. "Reshape",
  654. inputs=[kv_lora_concat_node.output[0], reshaped_kv_lora_weights_shape_tensor_name],
  655. outputs=[kv_lora_reshaped_node_name + "_out"],
  656. name=kv_lora_reshaped_node_name,
  657. )
  658. self.node_name_to_graph_name[kv_lora_reshaped_node.name] = self.this_graph_name
  659. # Add the LoRA K/V weights to the base K/V weights
  660. add_kv_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_KV")
  661. add_kv_weights_node = helper.make_node(
  662. "Add",
  663. inputs=[kv_lora_reshaped_node.output[0], matmul_node.output[0]],
  664. outputs=[add_kv_weights_node_name + "_out"],
  665. name=add_kv_weights_node_name,
  666. )
  667. self.node_name_to_graph_name[add_kv_weights_node.name] = self.this_graph_name
  668. # Finally, reshape the concatenated K/V result to 5D
  669. shape_tensor_name = add_kv_weights_node_name + "_reshape_shape"
  670. self.add_initializer(
  671. name=shape_tensor_name,
  672. data_type=TensorProto.INT64,
  673. dims=[5],
  674. vals=[0, 0, n, 2, h],
  675. raw=False,
  676. )
  677. reshape_node = helper.make_node(
  678. "Reshape",
  679. inputs=[add_kv_weights_node.output[0], shape_tensor_name],
  680. outputs=[attention_node_name + "_kv_input"],
  681. name=add_kv_weights_node_name + "_reshape",
  682. )
  683. self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
  684. self.nodes_to_add.extend(
  685. [
  686. matmul_node,
  687. k_lora_reshape_node,
  688. v_lora_reshape_node,
  689. kv_lora_concat_node,
  690. kv_lora_reshaped_node,
  691. add_kv_weights_node,
  692. reshape_node,
  693. ]
  694. )
  695. self.nodes_to_remove.extend([k_matmul, v_matmul, k_matmul_add, v_matmul_add])
  696. else:
  697. # TODO: Support non-packed KV
  698. return None
  699. # No bias, use zeros
  700. qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
  701. qkv_bias_dim = 3 * hidden_size
  702. self.add_initializer(
  703. name=attention_node_name + "_qkv_bias",
  704. data_type=TensorProto.FLOAT,
  705. dims=[qkv_bias_dim],
  706. vals=qkv_bias,
  707. )
  708. if is_self_attention:
  709. if not self.enable_packed_qkv:
  710. # TODO: Support non-packed QKV
  711. return None
  712. else:
  713. attention_inputs = [attention_node_name + "_qkv_input"]
  714. else:
  715. if not self.enable_packed_kv:
  716. # TODO: Support non-packed QKV
  717. return None
  718. else:
  719. attention_inputs = [
  720. q_matmul_add.output[0],
  721. attention_node_name + "_kv_input",
  722. ]
  723. attention_node = helper.make_node(
  724. "Attention" if (is_self_attention and not self.enable_packed_qkv) else "MultiHeadAttention",
  725. inputs=attention_inputs,
  726. outputs=[output],
  727. name=attention_node_name,
  728. )
  729. attention_node.domain = "com.microsoft"
  730. attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
  731. counter_name = (
  732. "Attention (self attention)"
  733. if is_self_attention and not self.enable_packed_qkv
  734. else "MultiHeadAttention ({})".format(
  735. "self attention with packed qkv"
  736. if self.enable_packed_qkv
  737. else "cross attention with packed kv"
  738. if self.enable_packed_kv
  739. else "cross attention"
  740. )
  741. )
  742. self.increase_counter(counter_name)
  743. return attention_node
  def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
      """Fuse the attention subgraph that consumes ``normalize_node`` into one attention node.

      The A1111 fp16 pattern is tried first. Otherwise the residual ``Add`` that shares its
      input with ``normalize_node`` is located and the Q/K/V paths are matched with the
      PyTorch 1.x / 2.x matchers; when neither matches, the LoRA variants of those matchers
      are tried. On success the new node is queued in ``nodes_to_add``, the matched output
      Reshape/Transpose nodes are queued for removal, and ``prune_graph`` is set so the rest
      of the matched subgraph is pruned.

      Args:
          normalize_node: Normalization node whose output is used as the attention input.
          input_name_to_nodes: Mapping from tensor name to the nodes that consume it.
          output_name_to_node: Mapping from tensor name to its producer; only forwarded to
              ``fuse_a1111_fp16``.
      """
      if self.fuse_a1111_fp16(normalize_node, input_name_to_nodes, output_name_to_node):
          return

      node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
      # In SD 1.5, for self attention, LayerNorm has parent Reshape
      if node_before_layernorm is None and not self.is_cross_attention:
          node_before_layernorm = self.model.match_parent(normalize_node, "Reshape", 0)
      if node_before_layernorm is None:
          return

      root_input = node_before_layernorm.output[0]
      # SkipLayerNormalization fusion is not applied yet, so the residual connection is a plain Add.
      skip_add = next(
          (node for node in input_name_to_nodes[root_input] if node.op_type == "Add"),
          None,
      )
      if skip_add is None:
          return

      match_qkv = self.match_qkv_torch1(root_input, skip_add) or self.match_qkv_torch2(root_input, skip_add)
      is_lora = match_qkv is None
      if is_lora:
          # Check if we have a LoRA pattern
          match_qkv = self.match_qkv_torch1_lora(root_input, skip_add) or self.match_qkv_torch2_lora(
              root_input, skip_add
          )
          if match_qkv is None:
              return

      # The last three entries are MatMul nodes for the plain pattern and Add nodes for the
      # LoRA pattern; both node-creation helpers take them in the same positions.
      is_torch2, reshape_qkv, transpose_qkv, reshape_q, q_input_node, k_input_node, v_input_node = match_qkv
      attention_last_node = reshape_qkv

      # Detected once and shared by both patterns (the original re-ran this identical
      # detection and check a second time after the branch, which could never change the outcome).
      q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
      if q_num_heads <= 0:
          logger.debug("fuse_attention: failed to detect num_heads")
          return

      # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
      create_node = self.create_attention_node_lora if is_lora else self.create_attention_node
      new_node = create_node(
          q_input_node,
          k_input_node,
          v_input_node,
          q_num_heads,
          q_hidden_size,
          input=normalize_node.output[0],
          output=attention_last_node.output[0],
      )
      if new_node is None:
          return

      self.nodes_to_add.append(new_node)
      self.node_name_to_graph_name[new_node.name] = self.this_graph_name
      self.nodes_to_remove.extend([attention_last_node, transpose_qkv])

      # Use prune graph to remove nodes since they are shared by all attention nodes.
      self.prune_graph = True
  def match_qkv_torch1(self, root_input, skip_add):
      """Locate the attention subgraph produced by a PyTorch 1.x export.

      Starting from the residual ``skip_add``, walk up through the output projection to the
      MatMul that applies attention probabilities to V, then follow its inputs back to the
      Q, K and V MatMul nodes.

      Returns:
          ``(False, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v)``,
          where the leading ``False`` marks a non-PyTorch-2 pattern, or ``None`` if any part
          of the pattern is missing.
      """
      graph = self.model
      # root_input feeds one side of the residual Add; the attention output feeds the other.
      branch_index = int(skip_add.input[0] == root_input)

      output_path = graph.match_parent_path(
          skip_add,
          ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
          [branch_index, None, None, 0, 0, 0],
      )
      if output_path is None:
          return None
      reshape_qkv = output_path[2]
      transpose_qkv = output_path[3]
      matmul_qkv = output_path[5]

      # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
      value_path = graph.match_parent_path(
          matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]
      )
      if value_path is None:
          logger.debug("fuse_attention: failed to match v path")
          return None
      matmul_v = value_path[-1]

      # Softmax is fed either by Mul(MatMul) or by Add(Mul(MatMul)).
      score_path = graph.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
      if score_path is None:
          score_path = graph.match_parent_path(
              matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0]
          )
      if score_path is None:
          logger.debug("fuse_attention: failed to match qk path")
          return None
      matmul_qk = score_path[-1]

      query_path = graph.match_parent_path(
          matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0]
      )
      if query_path is None:
          logger.debug("fuse_attention: failed to match q path")
          return None
      reshape_q = query_path[2]
      matmul_q = query_path[3]

      key_path = graph.match_parent_path(
          matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
      )
      if key_path is None:
          logger.debug("fuse_attention: failed to match k path")
          return None
      matmul_k = key_path[-1]

      return False, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
  def match_qkv_torch2(self, root_input, skip_add):
      """Locate the attention subgraph produced by a PyTorch 2.x export.

      Same walk as the PyTorch 1.x matcher, but with the 2.x node layout: Q and K are each
      scaled by a Mul before the score MatMul. The Q scale is checked to derive from the
      shape of ``reshape_q`` (Sqrt/Div/Sqrt/Cast/Slice/Shape/Transpose chain).

      Returns:
          ``(True, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v)``,
          where the leading ``True`` marks a PyTorch 2 pattern, or ``None`` if any part of
          the pattern is missing.
      """
      graph = self.model
      # root_input feeds one side of the residual Add; the attention output feeds the other.
      branch_index = int(skip_add.input[0] == root_input)

      output_path = graph.match_parent_path(
          skip_add,
          ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
          [branch_index, None, None, 0, 0],
      )
      if output_path is None:
          return None
      reshape_qkv = output_path[2]
      transpose_qkv = output_path[3]
      matmul_qkv = output_path[4]

      value_path = graph.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "MatMul"], [1, 0, 0])
      if value_path is None:
          logger.debug("fuse_attention: failed to match v path")
          return None
      matmul_v = value_path[-1]

      score_path = graph.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
      if score_path is None:
          logger.debug("fuse_attention: failed to match qk path")
          return None
      matmul_qk = score_path[-1]

      query_path = graph.match_parent_path(
          matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [0, None, 0, 0]
      )
      if query_path is None:
          logger.debug("fuse_attention: failed to match q path")
          return None
      mul_q = query_path[0]
      reshape_q = query_path[2]
      matmul_q = query_path[3]

      key_path = graph.match_parent_path(
          matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [1, None, 0, 0]
      )
      if key_path is None:
          logger.debug("fuse_attention: failed to match k path")
          return None
      matmul_k = key_path[-1]

      # The scalar for Q and K is sqrt(1.0/sqrt(head_size)).
      scale_path = graph.match_parent_path(
          mul_q,
          ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
          [None, 0, 1, 0, 0, 0, 0, 0],
      )
      if scale_path is None or scale_path[-1] != reshape_q:
          logger.debug("fuse_attention: failed to match mul_q path")
          return None

      return True, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
  def match_qkv_torch1_lora(self, root_input, skip_add):
      """Locate a PyTorch 1.x attention subgraph whose projections carry LoRA branches.

      Mirrors the plain PyTorch 1.x matcher, except that the Q, K and V paths end at an
      ``Add`` node instead of a ``MatMul``, and the output projection has one extra ``Add``.

      Returns:
          ``(False, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k,
          matmul_add_v)`` or ``None`` if any part of the pattern is missing.
      """
      graph = self.model
      # root_input feeds one side of the residual Add; the attention output feeds the other.
      branch_index = int(skip_add.input[0] == root_input)

      output_path = graph.match_parent_path(
          skip_add,
          ["Add", "Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
          [branch_index, 0, None, None, 0, 0, 0],
      )
      if output_path is None:
          return None
      reshape_qkv = output_path[3]
      transpose_qkv = output_path[4]
      matmul_qkv = output_path[6]

      # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
      value_path = graph.match_parent_path(
          matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0]
      )
      if value_path is None:
          logger.debug("fuse_attention: failed to match LoRA v path")
          return None
      matmul_add_v = value_path[-1]

      # Softmax is fed either by Mul(MatMul) or by Add(Mul(MatMul)).
      score_path = graph.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
      if score_path is None:
          score_path = graph.match_parent_path(
              matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0]
          )
      if score_path is None:
          logger.debug("fuse_attention: failed to match LoRA qk path")
          return None
      matmul_qk = score_path[-1]

      query_path = graph.match_parent_path(
          matmul_qk, ["Reshape", "Transpose", "Reshape", "Add"], [0, 0, 0, 0]
      )
      if query_path is None:
          logger.debug("fuse_attention: failed to match LoRA q path")
          return None
      reshape_q = query_path[2]
      matmul_add_q = query_path[3]

      key_path = graph.match_parent_path(
          matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0, 0]
      )
      if key_path is None:
          logger.debug("fuse_attention: failed to match LoRA k path")
          return None
      matmul_add_k = key_path[-1]

      return False, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v
  def match_qkv_torch2_lora(self, root_input, skip_add):
      """Locate a PyTorch 2.x attention subgraph whose projections carry LoRA branches.

      Mirrors the plain PyTorch 2.x matcher, except that the Q, K and V paths end at an
      ``Add`` node instead of a ``MatMul``, and the output projection has one extra ``Add``.
      The Q scale must still derive from the shape of ``reshape_q``.

      Returns:
          ``(True, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k,
          matmul_add_v)`` or ``None`` if any part of the pattern is missing.
      """
      graph = self.model
      # root_input feeds one side of the residual Add; the attention output feeds the other.
      branch_index = int(skip_add.input[0] == root_input)

      output_path = graph.match_parent_path(
          skip_add,
          ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
          [branch_index, 0, None, None, 0, 0],
      )
      if output_path is None:
          return None
      reshape_qkv = output_path[3]
      transpose_qkv = output_path[4]
      matmul_qkv = output_path[5]

      value_path = graph.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add"], [1, 0, 0])
      if value_path is None:
          logger.debug("fuse_attention: failed to match LoRA v path")
          return None
      matmul_add_v = value_path[-1]

      score_path = graph.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
      if score_path is None:
          logger.debug("fuse_attention: failed to match LoRA qk path")
          return None
      matmul_qk = score_path[-1]

      query_path = graph.match_parent_path(
          matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [0, None, 0, 0]
      )
      if query_path is None:
          logger.debug("fuse_attention: failed to match LoRA q path")
          return None
      mul_q = query_path[0]
      reshape_q = query_path[2]
      matmul_add_q = query_path[3]

      key_path = graph.match_parent_path(
          matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [1, None, 0, 0]
      )
      if key_path is None:
          logger.debug("fuse_attention: failed to match LoRA k path")
          return None
      matmul_add_k = key_path[-1]

      # The scalar for Q and K is sqrt(1.0/sqrt(head_size)).
      scale_path = graph.match_parent_path(
          mul_q,
          ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
          [None, 0, 1, 0, 0, 0, 0, 0],
      )
      if scale_path is None or scale_path[-1] != reshape_q:
          logger.debug("fuse_attention: failed to match LoRA mul_q path")
          return None

      return True, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v
  def match_lora_path(
      self,
      add_node: NodeProto,
  ):
      """Find the LoRA branch that feeds input 1 of ``add_node``.

      Three layouts are tried, in this order (producer -> consumer):
        * MatMul -> MatMul -> Add
        * MatMul -> MatMul -> Mul -> Add
        * MatMul -> MatMul -> Mul -> Mul -> Add

      Returns:
          A 2-tuple ``(last_node, first_matmul)``. ``last_node`` is the node directly feeding
          ``add_node`` (the second MatMul, or the outermost Mul). ``first_matmul`` is the
          first MatMul of the branch. Returns ``None`` when no layout matches.
      """
      # (parent op types walking upward from add_node, input index taken at each hop)
      candidate_layouts = (
          (["MatMul", "MatMul"], [1, 0]),
          (["Mul", "MatMul", "MatMul"], [1, 0, 0]),
          (["Mul", "Mul", "MatMul", "MatMul"], [1, 0, 0, 0]),
      )
      for op_types, input_indices in candidate_layouts:
          matched = self.model.match_parent_path(add_node, op_types, input_indices)
          if matched is not None:
              return (matched[0], matched[-1])
      return None
  def fuse_a1111_fp16(self, normalize_node, input_name_to_nodes, output_name_to_node):
      """Fuse attention of an fp16 UNet exported by the A1111 (stable diffusion webui) extension.

      ``normalize_node`` must read a ``Cast`` whose parent is an ``Add`` or a ``Reshape``. The
      residual ``Add`` consuming that parent's output anchors the Einsum-based Q/K/V match.
      The Q, K and V MatMuls must each read a ``Cast`` laid out as follows:
        * self attention: one shared Cast for Q, K and V;
        * cross attention: K and V share a Cast that differs from Q's.
      The Q Cast must read ``normalize_node``'s output.

      Args:
          normalize_node: Normalization node at the start of the attention block.
          input_name_to_nodes: Mapping from tensor name to the nodes that consume it.
          output_name_to_node: Unused here; kept for signature parity with ``fuse``.

      Returns:
          ``True`` when the attention node was created and queued, otherwise ``False``.
      """
      for parent_type in ("Add", "Reshape"):
          entry_path = self.model.match_parent_path(normalize_node, ["Cast", parent_type], [0, 0])
          if entry_path is not None:
              break
      else:
          return False

      root_input = entry_path[1].output[0]
      # SkipLayerNormalization fusion is not applied yet, so the residual connection is a plain Add.
      skip_add = next(
          (child for child in input_name_to_nodes[root_input] if child.op_type == "Add"),
          None,
      )
      if skip_add is None:
          return False

      match_qkv = self.match_qkv_a1111(root_input, skip_add)
      if match_qkv is None:
          return False
      reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v = match_qkv

      cast_q, cast_k, cast_v = (
          self.model.match_parent(projection, "Cast", 0) for projection in (matmul_q, matmul_k, matmul_v)
      )
      if cast_q is None or cast_k is None:
          return False
      if self.is_cross_attention:
          cast_layout_ok = cast_q != cast_k
      else:
          cast_layout_ok = cast_q == cast_k
      if not (cast_layout_ok and cast_k == cast_v):
          return False
      if cast_q.input[0] != normalize_node.output[0]:
          return False

      # get_num_heads is tried with True first, then with False when the first attempt is falsy.
      q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False)
      if q_num_heads <= 0:
          logger.debug("fuse_attention: failed to detect num_heads")
          return False
      q_hidden_size = self.get_hidden_size(normalize_node)

      # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
      new_node = self.create_attention_node(
          matmul_q,
          matmul_k,
          matmul_v,
          q_num_heads,
          q_hidden_size,
          input=matmul_q.input[0],
          output=reshape_qkv.output[0],
      )
      if new_node is None:
          return False

      self.nodes_to_add.append(new_node)
      self.node_name_to_graph_name[new_node.name] = self.this_graph_name
      self.nodes_to_remove.extend([reshape_qkv, transpose_qkv])

      # Use prune graph to remove nodes since they are shared by all attention nodes.
      self.prune_graph = True
      return True
  def match_qkv_a1111(self, root_input, skip_add):
      """Locate the Einsum-based attention subgraph exported by the A1111 extension.

      From the residual ``skip_add``, walk up through the output projection to the Einsum
      that applies attention probabilities to V. Then follow the Cast/Cast/Softmax/Mul chain
      to the score Einsum, and trace its two inputs back to the Q and K MatMul nodes.

      Returns:
          ``(reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v)`` or
          ``None`` if any part of the pattern is missing.
      """
      graph = self.model
      # root_input feeds one side of the residual Add; the attention output feeds the other.
      branch_index = int(skip_add.input[0] == root_input)

      output_path = graph.match_parent_path(
          skip_add,
          ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "Einsum"],
          [branch_index, None, None, 0, 0, 0],
      )
      if output_path is None:
          return None
      reshape_qkv = output_path[2]
      transpose_qkv = output_path[3]
      einsum_qkv = output_path[5]

      value_path = graph.match_parent_path(
          einsum_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]
      )
      if value_path is None:
          logger.debug("fuse_attention: failed to match v path")
          return None
      matmul_v = value_path[-1]

      score_path = graph.match_parent_path(
          einsum_qkv, ["Cast", "Cast", "Softmax", "Mul", "Einsum"], [0, 0, 0, 0, None]
      )
      if score_path is None:
          logger.debug("fuse_attention: failed to match qk path")
          return None
      einsum_qk = score_path[-1]

      query_path = graph.match_parent_path(
          einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0]
      )
      if query_path is None:
          logger.debug("fuse_attention: failed to match q path")
          return None
      reshape_q = query_path[2]
      matmul_q = query_path[3]

      key_path = graph.match_parent_path(
          einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]
      )
      if key_path is None:
          logger.debug("fuse_attention: failed to match k path")
          return None
      matmul_k = key_path[-1]

      return reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v