onnx_model_unet.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. from fusion_attention_unet import FusionAttentionUnet
  7. from fusion_bias_add import FusionBiasAdd
  8. from fusion_biassplitgelu import FusionBiasSplitGelu
  9. from fusion_group_norm import FusionGroupNorm
  10. from fusion_nhwc_conv import FusionNhwcConv
  11. from fusion_options import FusionOptions
  12. from fusion_skip_group_norm import FusionSkipGroupNorm
  13. from fusion_transpose import FusionInsertTranspose, FusionTranspose
  14. from import_utils import is_installed
  15. from onnx import ModelProto
  16. from onnx_model import OnnxModel
  17. from onnx_model_bert import BertOnnxModel
  18. logger = logging.getLogger(__name__)
  19. class UnetOnnxModel(BertOnnxModel):
  20. def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
  21. """Initialize UNet ONNX Model.
  22. Args:
  23. model (ModelProto): the ONNX model
  24. num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
  25. hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
  26. """
  27. assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
  28. super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
  29. def preprocess(self):
  30. self.remove_useless_div()
  31. def postprocess(self):
  32. self.prune_graph()
  33. self.remove_unused_constant()
  34. def remove_useless_div(self):
  35. """Remove Div by 1"""
  36. div_nodes = [node for node in self.nodes() if node.op_type == "Div"]
  37. nodes_to_remove = []
  38. for div in div_nodes:
  39. if self.find_constant_input(div, 1.0) == 1:
  40. nodes_to_remove.append(div)
  41. for node in nodes_to_remove:
  42. self.replace_input_of_all_nodes(node.output[0], node.input[0])
  43. if nodes_to_remove:
  44. self.remove_nodes(nodes_to_remove)
  45. logger.info("Removed %d Div nodes", len(nodes_to_remove))
  46. def convert_conv_to_nhwc(self):
  47. # Transpose weights in offline might help since ORT does not apply constant-folding on Transpose nodes.
  48. conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=True)
  49. conv_to_nhwc_conv.apply()
  50. def merge_adjacent_transpose(self):
  51. fusion_transpose = FusionTranspose(self)
  52. fusion_transpose.apply()
  53. remove_count = 0
  54. nodes = self.get_nodes_by_op_type("Transpose")
  55. for node in nodes:
  56. permutation = OnnxModel.get_node_attribute(node, "perm")
  57. assert isinstance(permutation, list)
  58. if permutation != list(range(len(permutation))):
  59. continue
  60. assert not (
  61. self.find_graph_output(node.output[0])
  62. or self.find_graph_input(node.input[0])
  63. or self.find_graph_output(node.input[0])
  64. )
  65. # Let all children nodes skip current Transpose node and link to its parent
  66. # Note that we cannot update parent node output since parent node might have more than one children.
  67. self.replace_input_of_all_nodes(node.output[0], node.input[0])
  68. self.remove_node(node)
  69. remove_count += 1
  70. total = len(fusion_transpose.nodes_to_remove) + remove_count
  71. if total:
  72. logger.info("Removed %d Transpose nodes", total)
  73. def fuse_multi_head_attention(self, options: FusionOptions | None = None):
  74. # Self Attention
  75. enable_packed_qkv = (options is None) or options.enable_packed_qkv
  76. self_attention_fusion = FusionAttentionUnet(
  77. self,
  78. self.hidden_size,
  79. self.num_heads,
  80. is_cross_attention=False,
  81. enable_packed_qkv=enable_packed_qkv,
  82. enable_packed_kv=False,
  83. )
  84. self_attention_fusion.apply()
  85. # Cross Attention
  86. enable_packed_kv = (options is None) or options.enable_packed_kv
  87. cross_attention_fusion = FusionAttentionUnet(
  88. self,
  89. self.hidden_size,
  90. self.num_heads,
  91. is_cross_attention=True,
  92. enable_packed_qkv=False,
  93. enable_packed_kv=enable_packed_kv,
  94. )
  95. cross_attention_fusion.apply()
  96. def fuse_bias_add(self):
  97. fusion = FusionBiasAdd(self)
  98. fusion.apply()
  99. def optimize(self, options: FusionOptions | None = None):
  100. if is_installed("tqdm"):
  101. import tqdm # noqa: PLC0415
  102. from tqdm.contrib.logging import logging_redirect_tqdm # noqa: PLC0415
  103. with logging_redirect_tqdm():
  104. steps = 18
  105. progress_bar = tqdm.tqdm(range(steps), initial=0, desc="fusion")
  106. self._optimize(options, progress_bar)
  107. else:
  108. logger.info("tqdm is not installed. Run optimization without progress bar")
  109. self._optimize(options, None)
  110. def _optimize(self, options: FusionOptions | None = None, progress_bar=None):
  111. if (options is not None) and not options.enable_shape_inference:
  112. self.disable_shape_inference()
  113. self.utils.remove_identity_nodes()
  114. if progress_bar:
  115. progress_bar.update(1)
  116. # Remove cast nodes that having same data type of input and output based on symbolic shape inference.
  117. self.utils.remove_useless_cast_nodes()
  118. if progress_bar:
  119. progress_bar.update(1)
  120. if (options is None) or options.enable_layer_norm:
  121. self.fuse_layer_norm()
  122. if progress_bar:
  123. progress_bar.update(1)
  124. if (options is None) or options.enable_gelu:
  125. self.fuse_gelu()
  126. if progress_bar:
  127. progress_bar.update(1)
  128. self.preprocess()
  129. if progress_bar:
  130. progress_bar.update(1)
  131. self.fuse_reshape()
  132. if progress_bar:
  133. progress_bar.update(1)
  134. if (options is None) or options.enable_group_norm:
  135. channels_last = (options is None) or options.group_norm_channels_last
  136. group_norm_fusion = FusionGroupNorm(self, channels_last)
  137. group_norm_fusion.apply()
  138. insert_transpose_fusion = FusionInsertTranspose(self)
  139. insert_transpose_fusion.apply()
  140. if progress_bar:
  141. progress_bar.update(1)
  142. if (options is None) or options.enable_bias_splitgelu:
  143. bias_split_gelu_fusion = FusionBiasSplitGelu(self)
  144. bias_split_gelu_fusion.apply()
  145. if progress_bar:
  146. progress_bar.update(1)
  147. if (options is None) or options.enable_attention:
  148. # self.save_model_to_file("before_mha.onnx")
  149. self.fuse_multi_head_attention(options)
  150. if progress_bar:
  151. progress_bar.update(1)
  152. if (options is None) or options.enable_skip_layer_norm:
  153. self.fuse_skip_layer_norm()
  154. if progress_bar:
  155. progress_bar.update(1)
  156. self.fuse_shape()
  157. if progress_bar:
  158. progress_bar.update(1)
  159. # Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
  160. self.utils.remove_useless_reshape_nodes()
  161. if progress_bar:
  162. progress_bar.update(1)
  163. if (options is None) or options.enable_skip_group_norm:
  164. skip_group_norm_fusion = FusionSkipGroupNorm(self)
  165. skip_group_norm_fusion.apply()
  166. if progress_bar:
  167. progress_bar.update(1)
  168. if (options is None) or options.enable_bias_skip_layer_norm:
  169. # Fuse SkipLayerNormalization and Add Bias before it.
  170. self.fuse_add_bias_skip_layer_norm()
  171. if progress_bar:
  172. progress_bar.update(1)
  173. if options is not None and options.enable_gelu_approximation:
  174. self.gelu_approximation()
  175. if progress_bar:
  176. progress_bar.update(1)
  177. if options is None or options.enable_nhwc_conv:
  178. self.convert_conv_to_nhwc()
  179. self.merge_adjacent_transpose()
  180. if progress_bar:
  181. progress_bar.update(1)
  182. if options is not None and options.enable_bias_add:
  183. self.fuse_bias_add()
  184. if progress_bar:
  185. progress_bar.update(1)
  186. self.postprocess()
  187. if progress_bar:
  188. progress_bar.update(1)
  189. logger.info(f"opset version: {self.get_opset_version()}")
  190. def get_fused_operator_statistics(self):
  191. """
  192. Returns node count of fused operators.
  193. """
  194. op_count = {}
  195. ops = [
  196. "Attention",
  197. "MultiHeadAttention",
  198. "LayerNormalization",
  199. "SkipLayerNormalization",
  200. "BiasSplitGelu",
  201. "GroupNorm",
  202. "SkipGroupNorm",
  203. "NhwcConv",
  204. "BiasAdd",
  205. ]
  206. for op in ops:
  207. nodes = self.get_nodes_by_op_type(op)
  208. op_count[op] = len(nodes)
  209. logger.info(f"Optimized operators:{op_count}")
  210. return op_count