fusion_options.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from argparse import ArgumentParser
  6. from enum import Enum
  7. class AttentionMaskFormat:
  8. # Build 1D mask indice (sequence length). It requires right side padding! Recommended for BERT model to get best performance.
  9. MaskIndexEnd = 0
  10. # For experiment only. Do not use it in production.
  11. MaskIndexEndAndStart = 1
  12. # Raw attention mask with 0 means padding (or no attention) and 1 otherwise.
  13. AttentionMask = 2
  14. # No attention mask
  15. NoMask = 3
  16. class AttentionOpType(Enum):
  17. Attention = "Attention"
  18. MultiHeadAttention = "MultiHeadAttention"
  19. GroupQueryAttention = "GroupQueryAttention"
  20. PagedAttention = "PagedAttention"
  21. def __str__(self):
  22. return self.value
  23. # Override __eq__ to return string comparison
  24. def __hash__(self):
  25. return hash(self.value)
  26. def __eq__(self, other):
  27. return other.value == self.value
  28. class FusionOptions:
  29. """Options of fusion in graph optimization"""
  30. def __init__(self, model_type):
  31. self.enable_gelu = True
  32. self.enable_layer_norm = True
  33. self.enable_attention = True
  34. self.enable_rotary_embeddings = True
  35. # Use MultiHeadAttention instead of Attention operator. The difference:
  36. # (1) Attention has merged weights for Q/K/V projection, which might be faster in some cases since 3 MatMul is
  37. # merged into one.
  38. # (2) Attention could only handle self attention; MultiHeadAttention could handle both self and cross attention.
  39. self.use_multi_head_attention = False
  40. self.disable_multi_head_attention_bias = False
  41. self.enable_skip_layer_norm = True
  42. self.enable_embed_layer_norm = True
  43. self.enable_bias_skip_layer_norm = True
  44. self.enable_bias_gelu = True
  45. self.enable_gelu_approximation = False
  46. self.enable_qordered_matmul = True
  47. self.enable_shape_inference = True
  48. self.enable_gemm_fast_gelu = False
  49. self.group_norm_channels_last = True
  50. if model_type == "clip":
  51. self.enable_embed_layer_norm = False
  52. # Set default to sequence length for BERT model to use fused attention to speed up.
  53. # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd.
  54. self.attention_mask_format = AttentionMaskFormat.AttentionMask
  55. if model_type == "bert":
  56. self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd
  57. elif model_type == "vit":
  58. self.attention_mask_format = AttentionMaskFormat.NoMask
  59. self.attention_op_type = None
  60. # options for stable diffusion
  61. if model_type in ["unet", "vae", "clip"]:
  62. self.enable_nhwc_conv = True
  63. self.enable_group_norm = True
  64. self.enable_skip_group_norm = True
  65. self.enable_bias_splitgelu = True
  66. self.enable_packed_qkv = True
  67. self.enable_packed_kv = True
  68. self.enable_bias_add = True
  69. def use_raw_attention_mask(self, use_raw_mask=True):
  70. if use_raw_mask:
  71. self.attention_mask_format = AttentionMaskFormat.AttentionMask
  72. else:
  73. self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd
  74. def disable_attention_mask(self):
  75. self.attention_mask_format = AttentionMaskFormat.NoMask
  76. def set_attention_op_type(self, attn_op_type: AttentionOpType):
  77. self.attention_op_type = attn_op_type
  78. @staticmethod
  79. def parse(args):
  80. options = FusionOptions(args.model_type)
  81. if args.disable_gelu:
  82. options.enable_gelu = False
  83. if args.disable_layer_norm:
  84. options.enable_layer_norm = False
  85. if args.disable_rotary_embeddings:
  86. options.enable_rotary_embeddings = False
  87. if args.disable_attention:
  88. options.enable_attention = False
  89. if args.use_multi_head_attention:
  90. options.use_multi_head_attention = True
  91. if args.disable_skip_layer_norm:
  92. options.enable_skip_layer_norm = False
  93. if args.disable_embed_layer_norm:
  94. options.enable_embed_layer_norm = False
  95. if args.disable_bias_skip_layer_norm:
  96. options.enable_bias_skip_layer_norm = False
  97. if args.disable_bias_gelu:
  98. options.enable_bias_gelu = False
  99. if args.enable_gelu_approximation:
  100. options.enable_gelu_approximation = True
  101. if args.disable_shape_inference:
  102. options.enable_shape_inference = False
  103. if args.enable_gemm_fast_gelu:
  104. options.enable_gemm_fast_gelu = True
  105. if args.use_mask_index:
  106. options.use_raw_attention_mask(False)
  107. if args.use_raw_attention_mask:
  108. options.use_raw_attention_mask(True)
  109. if args.no_attention_mask:
  110. options.disable_attention_mask()
  111. if args.model_type in ["unet", "vae", "clip"]:
  112. if args.use_group_norm_channels_first:
  113. options.group_norm_channels_last = False
  114. if args.disable_nhwc_conv:
  115. options.enable_nhwc_conv = False
  116. if args.disable_group_norm:
  117. options.enable_group_norm = False
  118. if args.disable_skip_group_norm:
  119. options.enable_skip_group_norm = False
  120. if args.disable_bias_splitgelu:
  121. options.enable_bias_splitgelu = False
  122. if args.disable_packed_qkv:
  123. options.enable_packed_qkv = False
  124. if args.disable_packed_kv:
  125. options.enable_packed_kv = False
  126. if args.disable_bias_add:
  127. options.enable_bias_add = False
  128. return options
  129. @staticmethod
  130. def add_arguments(parser: ArgumentParser):
  131. parser.add_argument(
  132. "--disable_attention",
  133. required=False,
  134. action="store_true",
  135. help="disable Attention fusion",
  136. )
  137. parser.set_defaults(disable_attention=False)
  138. parser.add_argument(
  139. "--disable_skip_layer_norm",
  140. required=False,
  141. action="store_true",
  142. help="disable SkipLayerNormalization fusion",
  143. )
  144. parser.set_defaults(disable_skip_layer_norm=False)
  145. parser.add_argument(
  146. "--disable_embed_layer_norm",
  147. required=False,
  148. action="store_true",
  149. help="disable EmbedLayerNormalization fusion",
  150. )
  151. parser.set_defaults(disable_embed_layer_norm=False)
  152. parser.add_argument(
  153. "--disable_bias_skip_layer_norm",
  154. required=False,
  155. action="store_true",
  156. help="disable Add Bias and SkipLayerNormalization fusion",
  157. )
  158. parser.set_defaults(disable_bias_skip_layer_norm=False)
  159. parser.add_argument(
  160. "--disable_bias_gelu",
  161. required=False,
  162. action="store_true",
  163. help="disable Add Bias and Gelu/FastGelu fusion",
  164. )
  165. parser.set_defaults(disable_bias_gelu=False)
  166. parser.add_argument(
  167. "--disable_layer_norm",
  168. required=False,
  169. action="store_true",
  170. help="disable LayerNormalization fusion",
  171. )
  172. parser.set_defaults(disable_layer_norm=False)
  173. parser.add_argument(
  174. "--disable_gelu",
  175. required=False,
  176. action="store_true",
  177. help="disable Gelu fusion",
  178. )
  179. parser.set_defaults(disable_gelu=False)
  180. parser.add_argument(
  181. "--enable_gelu_approximation",
  182. required=False,
  183. action="store_true",
  184. help="enable Gelu/BiasGelu to FastGelu conversion",
  185. )
  186. parser.set_defaults(enable_gelu_approximation=False)
  187. parser.add_argument(
  188. "--disable_shape_inference",
  189. required=False,
  190. action="store_true",
  191. help="disable symbolic shape inference",
  192. )
  193. parser.set_defaults(disable_shape_inference=False)
  194. parser.add_argument(
  195. "--enable_gemm_fast_gelu",
  196. required=False,
  197. action="store_true",
  198. help="enable GemmfastGelu fusion",
  199. )
  200. parser.set_defaults(enable_gemm_fast_gelu=False)
  201. parser.add_argument(
  202. "--use_mask_index",
  203. required=False,
  204. action="store_true",
  205. help="use mask index to activate fused attention to speed up. It requires right-side padding!",
  206. )
  207. parser.set_defaults(use_mask_index=False)
  208. parser.add_argument(
  209. "--use_raw_attention_mask",
  210. required=False,
  211. action="store_true",
  212. help="use raw attention mask. Use this option if your input is not right-side padding. This might deactivate fused attention and get worse performance.",
  213. )
  214. parser.set_defaults(use_raw_attention_mask=False)
  215. parser.add_argument(
  216. "--no_attention_mask",
  217. required=False,
  218. action="store_true",
  219. help="no attention mask. Only works for model_type=bert",
  220. )
  221. parser.set_defaults(no_attention_mask=False)
  222. parser.add_argument(
  223. "--use_multi_head_attention",
  224. required=False,
  225. action="store_true",
  226. help="Use MultiHeadAttention instead of Attention operator for testing purpose. "
  227. "Note that MultiHeadAttention might be slower than Attention when qkv are not packed. ",
  228. )
  229. parser.set_defaults(use_multi_head_attention=False)
  230. parser.add_argument(
  231. "--disable_group_norm",
  232. required=False,
  233. action="store_true",
  234. help="not fuse GroupNorm. Only works for model_type=unet or vae",
  235. )
  236. parser.set_defaults(disable_group_norm=False)
  237. parser.add_argument(
  238. "--disable_skip_group_norm",
  239. required=False,
  240. action="store_true",
  241. help="not fuse Add + GroupNorm to SkipGroupNorm. Only works for model_type=unet or vae",
  242. )
  243. parser.set_defaults(disable_skip_group_norm=False)
  244. parser.add_argument(
  245. "--disable_packed_kv",
  246. required=False,
  247. action="store_true",
  248. help="not use packed kv for cross attention in MultiHeadAttention. Only works for model_type=unet",
  249. )
  250. parser.set_defaults(disable_packed_kv=False)
  251. parser.add_argument(
  252. "--disable_packed_qkv",
  253. required=False,
  254. action="store_true",
  255. help="not use packed qkv for self attention in MultiHeadAttention. Only works for model_type=unet",
  256. )
  257. parser.set_defaults(disable_packed_qkv=False)
  258. parser.add_argument(
  259. "--disable_bias_add",
  260. required=False,
  261. action="store_true",
  262. help="not fuse BiasAdd. Only works for model_type=unet",
  263. )
  264. parser.set_defaults(disable_bias_add=False)
  265. parser.add_argument(
  266. "--disable_bias_splitgelu",
  267. required=False,
  268. action="store_true",
  269. help="not fuse BiasSplitGelu. Only works for model_type=unet",
  270. )
  271. parser.set_defaults(disable_bias_splitgelu=False)
  272. parser.add_argument(
  273. "--disable_nhwc_conv",
  274. required=False,
  275. action="store_true",
  276. help="Do not use NhwcConv. Only works for model_type=unet or vae",
  277. )
  278. parser.set_defaults(disable_nhwc_conv=False)
  279. parser.add_argument(
  280. "--use_group_norm_channels_first",
  281. required=False,
  282. action="store_true",
  283. help="Use channels_first (NCHW) instead of channels_last (NHWC) for GroupNorm. Only works for model_type=unet or vae",
  284. )
  285. parser.set_defaults(use_group_norm_channels_first=False)
  286. parser.add_argument(
  287. "--disable_rotary_embeddings",
  288. required=False,
  289. action="store_true",
  290. help="Do not fuse rotary embeddings into RotaryEmbedding op",
  291. )