| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- from argparse import ArgumentParser
- from enum import Enum
class AttentionMaskFormat:
    """Integer codes selecting how the attention mask is represented.

    These are plain ``int`` class attributes (not an Enum); FusionOptions stores one of
    them in ``attention_mask_format``.
    """

    # 1D mask indices (sequence length). Requires right-side padding.
    # Recommended for BERT models to get the best performance.
    MaskIndexEnd = 0

    # Experimental only -- not for production use.
    MaskIndexEndAndStart = 1

    # Raw attention mask: 0 marks padding (no attention), 1 marks tokens to attend to.
    AttentionMask = 2

    # No attention mask at all.
    NoMask = 3
class AttentionOpType(Enum):
    """Attention operator types selectable via ``FusionOptions.set_attention_op_type``.

    Each member's value is the operator name, which is also what ``str()`` returns.
    """

    Attention = "Attention"
    MultiHeadAttention = "MultiHeadAttention"
    GroupQueryAttention = "GroupQueryAttention"
    PagedAttention = "PagedAttention"

    def __str__(self):
        return self.value

    # Equality and hashing are defined on ``value`` (the operator-name string) instead of
    # member identity, so any object exposing an equal ``.value`` compares equal.
    # NOTE(review): presumably this tolerates duplicate copies of this enum class (e.g.
    # the module imported under two names) -- confirm before changing.
    def __hash__(self):
        return hash(self.value)

    def __eq__(self, other):
        # Operands without a ``value`` attribute (None, plain strings, ...) previously
        # raised AttributeError here, even for ``op == None`` or ``None == op``.
        # Returning NotImplemented lets Python fall back to its default identity
        # comparison, so such comparisons simply evaluate to False.
        try:
            other_value = other.value
        except AttributeError:
            return NotImplemented
        return other_value == self.value
- class FusionOptions:
- """Options of fusion in graph optimization"""
- def __init__(self, model_type):
- self.enable_gelu = True
- self.enable_layer_norm = True
- self.enable_attention = True
- self.enable_rotary_embeddings = True
- # Use MultiHeadAttention instead of Attention operator. The difference:
- # (1) Attention has merged weights for Q/K/V projection, which might be faster in some cases since 3 MatMul is
- # merged into one.
- # (2) Attention could only handle self attention; MultiHeadAttention could handle both self and cross attention.
- self.use_multi_head_attention = False
- self.disable_multi_head_attention_bias = False
- self.enable_skip_layer_norm = True
- self.enable_embed_layer_norm = True
- self.enable_bias_skip_layer_norm = True
- self.enable_bias_gelu = True
- self.enable_gelu_approximation = False
- self.enable_qordered_matmul = True
- self.enable_shape_inference = True
- self.enable_gemm_fast_gelu = False
- self.group_norm_channels_last = True
- if model_type == "clip":
- self.enable_embed_layer_norm = False
- # Set default to sequence length for BERT model to use fused attention to speed up.
- # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd.
- self.attention_mask_format = AttentionMaskFormat.AttentionMask
- if model_type == "bert":
- self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd
- elif model_type == "vit":
- self.attention_mask_format = AttentionMaskFormat.NoMask
- self.attention_op_type = None
- # options for stable diffusion
- if model_type in ["unet", "vae", "clip"]:
- self.enable_nhwc_conv = True
- self.enable_group_norm = True
- self.enable_skip_group_norm = True
- self.enable_bias_splitgelu = True
- self.enable_packed_qkv = True
- self.enable_packed_kv = True
- self.enable_bias_add = True
- def use_raw_attention_mask(self, use_raw_mask=True):
- if use_raw_mask:
- self.attention_mask_format = AttentionMaskFormat.AttentionMask
- else:
- self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd
- def disable_attention_mask(self):
- self.attention_mask_format = AttentionMaskFormat.NoMask
- def set_attention_op_type(self, attn_op_type: AttentionOpType):
- self.attention_op_type = attn_op_type
- @staticmethod
- def parse(args):
- options = FusionOptions(args.model_type)
- if args.disable_gelu:
- options.enable_gelu = False
- if args.disable_layer_norm:
- options.enable_layer_norm = False
- if args.disable_rotary_embeddings:
- options.enable_rotary_embeddings = False
- if args.disable_attention:
- options.enable_attention = False
- if args.use_multi_head_attention:
- options.use_multi_head_attention = True
- if args.disable_skip_layer_norm:
- options.enable_skip_layer_norm = False
- if args.disable_embed_layer_norm:
- options.enable_embed_layer_norm = False
- if args.disable_bias_skip_layer_norm:
- options.enable_bias_skip_layer_norm = False
- if args.disable_bias_gelu:
- options.enable_bias_gelu = False
- if args.enable_gelu_approximation:
- options.enable_gelu_approximation = True
- if args.disable_shape_inference:
- options.enable_shape_inference = False
- if args.enable_gemm_fast_gelu:
- options.enable_gemm_fast_gelu = True
- if args.use_mask_index:
- options.use_raw_attention_mask(False)
- if args.use_raw_attention_mask:
- options.use_raw_attention_mask(True)
- if args.no_attention_mask:
- options.disable_attention_mask()
- if args.model_type in ["unet", "vae", "clip"]:
- if args.use_group_norm_channels_first:
- options.group_norm_channels_last = False
- if args.disable_nhwc_conv:
- options.enable_nhwc_conv = False
- if args.disable_group_norm:
- options.enable_group_norm = False
- if args.disable_skip_group_norm:
- options.enable_skip_group_norm = False
- if args.disable_bias_splitgelu:
- options.enable_bias_splitgelu = False
- if args.disable_packed_qkv:
- options.enable_packed_qkv = False
- if args.disable_packed_kv:
- options.enable_packed_kv = False
- if args.disable_bias_add:
- options.enable_bias_add = False
- return options
- @staticmethod
- def add_arguments(parser: ArgumentParser):
- parser.add_argument(
- "--disable_attention",
- required=False,
- action="store_true",
- help="disable Attention fusion",
- )
- parser.set_defaults(disable_attention=False)
- parser.add_argument(
- "--disable_skip_layer_norm",
- required=False,
- action="store_true",
- help="disable SkipLayerNormalization fusion",
- )
- parser.set_defaults(disable_skip_layer_norm=False)
- parser.add_argument(
- "--disable_embed_layer_norm",
- required=False,
- action="store_true",
- help="disable EmbedLayerNormalization fusion",
- )
- parser.set_defaults(disable_embed_layer_norm=False)
- parser.add_argument(
- "--disable_bias_skip_layer_norm",
- required=False,
- action="store_true",
- help="disable Add Bias and SkipLayerNormalization fusion",
- )
- parser.set_defaults(disable_bias_skip_layer_norm=False)
- parser.add_argument(
- "--disable_bias_gelu",
- required=False,
- action="store_true",
- help="disable Add Bias and Gelu/FastGelu fusion",
- )
- parser.set_defaults(disable_bias_gelu=False)
- parser.add_argument(
- "--disable_layer_norm",
- required=False,
- action="store_true",
- help="disable LayerNormalization fusion",
- )
- parser.set_defaults(disable_layer_norm=False)
- parser.add_argument(
- "--disable_gelu",
- required=False,
- action="store_true",
- help="disable Gelu fusion",
- )
- parser.set_defaults(disable_gelu=False)
- parser.add_argument(
- "--enable_gelu_approximation",
- required=False,
- action="store_true",
- help="enable Gelu/BiasGelu to FastGelu conversion",
- )
- parser.set_defaults(enable_gelu_approximation=False)
- parser.add_argument(
- "--disable_shape_inference",
- required=False,
- action="store_true",
- help="disable symbolic shape inference",
- )
- parser.set_defaults(disable_shape_inference=False)
- parser.add_argument(
- "--enable_gemm_fast_gelu",
- required=False,
- action="store_true",
- help="enable GemmfastGelu fusion",
- )
- parser.set_defaults(enable_gemm_fast_gelu=False)
- parser.add_argument(
- "--use_mask_index",
- required=False,
- action="store_true",
- help="use mask index to activate fused attention to speed up. It requires right-side padding!",
- )
- parser.set_defaults(use_mask_index=False)
- parser.add_argument(
- "--use_raw_attention_mask",
- required=False,
- action="store_true",
- help="use raw attention mask. Use this option if your input is not right-side padding. This might deactivate fused attention and get worse performance.",
- )
- parser.set_defaults(use_raw_attention_mask=False)
- parser.add_argument(
- "--no_attention_mask",
- required=False,
- action="store_true",
- help="no attention mask. Only works for model_type=bert",
- )
- parser.set_defaults(no_attention_mask=False)
- parser.add_argument(
- "--use_multi_head_attention",
- required=False,
- action="store_true",
- help="Use MultiHeadAttention instead of Attention operator for testing purpose. "
- "Note that MultiHeadAttention might be slower than Attention when qkv are not packed. ",
- )
- parser.set_defaults(use_multi_head_attention=False)
- parser.add_argument(
- "--disable_group_norm",
- required=False,
- action="store_true",
- help="not fuse GroupNorm. Only works for model_type=unet or vae",
- )
- parser.set_defaults(disable_group_norm=False)
- parser.add_argument(
- "--disable_skip_group_norm",
- required=False,
- action="store_true",
- help="not fuse Add + GroupNorm to SkipGroupNorm. Only works for model_type=unet or vae",
- )
- parser.set_defaults(disable_skip_group_norm=False)
- parser.add_argument(
- "--disable_packed_kv",
- required=False,
- action="store_true",
- help="not use packed kv for cross attention in MultiHeadAttention. Only works for model_type=unet",
- )
- parser.set_defaults(disable_packed_kv=False)
- parser.add_argument(
- "--disable_packed_qkv",
- required=False,
- action="store_true",
- help="not use packed qkv for self attention in MultiHeadAttention. Only works for model_type=unet",
- )
- parser.set_defaults(disable_packed_qkv=False)
- parser.add_argument(
- "--disable_bias_add",
- required=False,
- action="store_true",
- help="not fuse BiasAdd. Only works for model_type=unet",
- )
- parser.set_defaults(disable_bias_add=False)
- parser.add_argument(
- "--disable_bias_splitgelu",
- required=False,
- action="store_true",
- help="not fuse BiasSplitGelu. Only works for model_type=unet",
- )
- parser.set_defaults(disable_bias_splitgelu=False)
- parser.add_argument(
- "--disable_nhwc_conv",
- required=False,
- action="store_true",
- help="Do not use NhwcConv. Only works for model_type=unet or vae",
- )
- parser.set_defaults(disable_nhwc_conv=False)
- parser.add_argument(
- "--use_group_norm_channels_first",
- required=False,
- action="store_true",
- help="Use channels_first (NCHW) instead of channels_last (NHWC) for GroupNorm. Only works for model_type=unet or vae",
- )
- parser.set_defaults(use_group_norm_channels_first=False)
- parser.add_argument(
- "--disable_rotary_embeddings",
- required=False,
- action="store_true",
- help="Do not fuse rotary embeddings into RotaryEmbedding op",
- )
|