onnx_model_vae.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_attention_vae import FusionAttentionVae
  7. from fusion_options import FusionOptions
  8. from onnx import ModelProto
  9. from onnx_model_unet import UnetOnnxModel
  10. logger = getLogger(__name__)
  11. class VaeOnnxModel(UnetOnnxModel):
  12. def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
  13. assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
  14. super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
  15. def fuse_multi_head_attention(self, options: FusionOptions | None = None):
  16. # Self Attention
  17. self_attention_fusion = FusionAttentionVae(self, self.hidden_size, self.num_heads)
  18. self_attention_fusion.apply()
  19. def get_fused_operator_statistics(self):
  20. """
  21. Returns node count of fused operators.
  22. """
  23. op_count = {}
  24. ops = [
  25. "Attention",
  26. "GroupNorm",
  27. "SkipGroupNorm",
  28. "NhwcConv",
  29. ]
  30. for op in ops:
  31. nodes = self.get_nodes_by_op_type(op)
  32. op_count[op] = len(nodes)
  33. logger.info(f"Optimized operators:{op_count}")
  34. return op_count