onnx_model_sam2.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. from fusion_attention_sam2 import FusionMultiHeadAttentionSam2
  7. from fusion_layernorm import FusionLayerNormalizationNCHW
  8. from fusion_options import FusionOptions
  9. from import_utils import is_installed
  10. from onnx import ModelProto
  11. from onnx_model_bert import BertOnnxModel
  12. logger = logging.getLogger(__name__)
  13. class Sam2OnnxModel(BertOnnxModel):
  14. def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
  15. """Initialize SAM2 ONNX Model.
  16. Args:
  17. model (ModelProto): the ONNX model
  18. num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
  19. hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
  20. """
  21. assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
  22. super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
  23. def postprocess(self):
  24. self.prune_graph()
  25. self.remove_unused_constant()
  26. def fuse_layer_norm(self):
  27. super().fuse_layer_norm()
  28. fusion = FusionLayerNormalizationNCHW(self)
  29. fusion.apply()
  30. def fuse_multi_head_attention(self, options: FusionOptions | None = None):
  31. mha_fusion = FusionMultiHeadAttentionSam2(self, self.hidden_size, self.num_heads)
  32. mha_fusion.apply()
  33. def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False):
  34. if is_installed("tqdm"):
  35. import tqdm # noqa: PLC0415
  36. from tqdm.contrib.logging import logging_redirect_tqdm # noqa: PLC0415
  37. with logging_redirect_tqdm():
  38. steps = 12
  39. progress_bar = tqdm.tqdm(range(steps), initial=0, desc="sam2 fusion")
  40. self._optimize(options, progress_bar)
  41. else:
  42. logger.info("tqdm is not installed. Run optimization without progress bar")
  43. self._optimize(options, None)
  44. def _optimize(self, options: FusionOptions | None = None, progress_bar=None):
  45. if (options is not None) and not options.enable_shape_inference:
  46. self.disable_shape_inference()
  47. self.utils.remove_identity_nodes()
  48. if progress_bar:
  49. progress_bar.update(1)
  50. # Remove cast nodes that having same data type of input and output based on symbolic shape inference.
  51. self.utils.remove_useless_cast_nodes()
  52. if progress_bar:
  53. progress_bar.update(1)
  54. if (options is None) or options.enable_layer_norm:
  55. self.fuse_layer_norm()
  56. if progress_bar:
  57. progress_bar.update(1)
  58. if (options is None) or options.enable_gelu:
  59. self.fuse_gelu()
  60. if progress_bar:
  61. progress_bar.update(1)
  62. self.fuse_reshape()
  63. if progress_bar:
  64. progress_bar.update(1)
  65. if (options is None) or options.enable_attention:
  66. self.fuse_multi_head_attention(options)
  67. if progress_bar:
  68. progress_bar.update(1)
  69. if (options is None) or options.enable_skip_layer_norm:
  70. self.fuse_skip_layer_norm()
  71. if progress_bar:
  72. progress_bar.update(1)
  73. self.fuse_shape()
  74. if progress_bar:
  75. progress_bar.update(1)
  76. # Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
  77. self.utils.remove_useless_reshape_nodes()
  78. if progress_bar:
  79. progress_bar.update(1)
  80. if (options is None) or options.enable_bias_skip_layer_norm:
  81. # Fuse SkipLayerNormalization and Add Bias before it.
  82. self.fuse_add_bias_skip_layer_norm()
  83. if progress_bar:
  84. progress_bar.update(1)
  85. if options is not None and options.enable_gelu_approximation:
  86. self.gelu_approximation()
  87. if progress_bar:
  88. progress_bar.update(1)
  89. self.postprocess()
  90. if progress_bar:
  91. progress_bar.update(1)
  92. logger.info(f"opset version: {self.get_opset_version()}")
  93. def get_fused_operator_statistics(self):
  94. """
  95. Returns node count of fused operators.
  96. """
  97. op_count = {}
  98. ops = [
  99. "MultiHeadAttention",
  100. "LayerNormalization",
  101. "SkipLayerNormalization",
  102. ]
  103. for op in ops:
  104. nodes = self.get_nodes_by_op_type(op)
  105. op_count[op] = len(nodes)
  106. logger.info(f"Optimized operators:{op_count}")
  107. return op_count