onnx_model_mmdit.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. from fusion_layernorm import FusionLayerNormalization
  7. from fusion_mha_mmdit import FusionMultiHeadAttentionMMDit
  8. from fusion_options import FusionOptions
  9. from import_utils import is_installed
  10. from onnx import ModelProto
  11. from onnx_model_bert import BertOnnxModel
  12. logger = logging.getLogger(__name__)
  13. class MmditOnnxModel(BertOnnxModel):
  14. def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
  15. """Initialize Multimodal Diffusion Transformer (MMDiT) ONNX Model.
  16. Args:
  17. model (ModelProto): the ONNX model
  18. num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
  19. hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
  20. """
  21. assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
  22. super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
  23. def postprocess(self):
  24. self.prune_graph()
  25. self.remove_unused_constant()
  26. def fuse_layer_norm(self):
  27. layernorm_support_broadcast = True
  28. logger.warning(
  29. "The optimized model requires LayerNormalization with broadcast support. "
  30. "Please use onnxruntime-gpu>=1.21 for inference."
  31. )
  32. fusion = FusionLayerNormalization(
  33. self, check_constant_and_dimension=not layernorm_support_broadcast, force=True
  34. )
  35. fusion.apply()
  36. def fuse_multi_head_attention(self):
  37. fusion = FusionMultiHeadAttentionMMDit(self)
  38. fusion.apply()
  39. def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False):
  40. assert not add_dynamic_axes
  41. if is_installed("tqdm"):
  42. import tqdm # noqa: PLC0415
  43. from tqdm.contrib.logging import logging_redirect_tqdm # noqa: PLC0415
  44. with logging_redirect_tqdm():
  45. steps = 5
  46. progress_bar = tqdm.tqdm(range(steps), initial=0, desc="fusion")
  47. self._optimize(options, progress_bar)
  48. else:
  49. logger.info("tqdm is not installed. Run optimization without progress bar")
  50. self._optimize(options, None)
  51. def _optimize(self, options: FusionOptions | None = None, progress_bar=None):
  52. if (options is not None) and not options.enable_shape_inference:
  53. self.disable_shape_inference()
  54. # Remove cast nodes that having same data type of input and output based on symbolic shape inference.
  55. self.utils.remove_useless_cast_nodes()
  56. if progress_bar:
  57. progress_bar.update(1)
  58. if (options is None) or options.enable_layer_norm:
  59. self.fuse_layer_norm()
  60. self.fuse_simplified_layer_norm()
  61. if progress_bar:
  62. progress_bar.update(1)
  63. if (options is None) or options.enable_gelu:
  64. self.fuse_gelu()
  65. if progress_bar:
  66. progress_bar.update(1)
  67. if (options is None) or options.enable_attention:
  68. self.fuse_multi_head_attention()
  69. if progress_bar:
  70. progress_bar.update(1)
  71. self.postprocess()
  72. if progress_bar:
  73. progress_bar.update(1)
  74. logger.info(f"opset version: {self.get_opset_version()}")
  75. def get_fused_operator_statistics(self):
  76. """
  77. Returns node count of fused operators.
  78. """
  79. op_count = {}
  80. ops = [
  81. "FastGelu",
  82. "MultiHeadAttention",
  83. "LayerNormalization",
  84. "SimplifiedLayerNormalization",
  85. ]
  86. for op in ops:
  87. nodes = self.get_nodes_by_op_type(op)
  88. op_count[op] = len(nodes)
  89. logger.info(f"Optimized operators:{op_count}")
  90. return op_count