onnx_model_clip.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_attention_clip import FusionAttentionClip
  7. from onnx import ModelProto
  8. from onnx_model_bert import BertOnnxModel
  9. logger = getLogger(__name__)
  10. class ClipOnnxModel(BertOnnxModel):
  11. def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
  12. super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
  13. self.clip_attention_fusion = FusionAttentionClip(self, self.hidden_size, self.num_heads)
  14. def get_fused_operator_statistics(self):
  15. """
  16. Returns node count of fused operators.
  17. """
  18. op_count = {}
  19. ops = [
  20. "Attention",
  21. "FastGelu",
  22. "Gelu",
  23. "LayerNormalization",
  24. "QuickGelu",
  25. "BiasGelu",
  26. "SkipLayerNormalization",
  27. ]
  28. for op in ops:
  29. nodes = self.get_nodes_by_op_type(op)
  30. op_count[op] = len(nodes)
  31. logger.info(f"Optimized operators:{op_count}")
  32. return op_count
  33. def fuse_attention(self):
  34. self.clip_attention_fusion.apply()