fusion_base.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
  5. from collections import defaultdict
  6. from collections.abc import Sequence
  7. from logging import getLogger
  8. from typing import Any
  9. import numpy as np
  10. from onnx import NodeProto, TensorProto, helper
  11. from onnx_model import OnnxModel
  12. logger = getLogger(__name__)
  13. class Fusion:
  14. """
  15. Base class for Graph Fusion
  16. """
  17. def __init__(
  18. self,
  19. model: OnnxModel,
  20. fused_op_type: str,
  21. search_op_types: str | list[str],
  22. description: str = "",
  23. ):
  24. self.search_op_types: list[str] = [search_op_types] if isinstance(search_op_types, str) else search_op_types
  25. self.fused_op_type: str = fused_op_type
  26. self.description: str = f"{fused_op_type}({description})" if description else fused_op_type
  27. self.model: OnnxModel = model
  28. self.nodes_to_remove: list = []
  29. self.nodes_to_add: list = []
  30. self.prune_graph: bool = False
  31. self.node_name_to_graph_name: dict = {}
  32. self.this_graph_name: str | None = None
  33. # It is optional that subclass updates fused_count since we will also check nodes_to_add to get counter.
  34. self.fused_count: defaultdict = defaultdict(int)
  35. def increase_counter(self, fused_op_name: str):
  36. """
  37. Increase counter of a fused operator.
  38. """
  39. self.fused_count[fused_op_name] += 1
  40. def fuse(
  41. self,
  42. node: NodeProto,
  43. input_name_to_nodes: dict[str, list[NodeProto]],
  44. output_name_to_node: dict[str, NodeProto],
  45. ):
  46. """Interface for fusion that starts from a node"""
  47. raise NotImplementedError
  48. def apply(self):
  49. """
  50. Apply graph fusion on the whole model graph.
  51. It searched nodes of given operators, and start fusion on each of those nodes.
  52. """
  53. logger.debug(f"start {self.description} fusion...")
  54. input_name_to_nodes = self.model.input_name_to_nodes()
  55. output_name_to_node = self.model.output_name_to_node()
  56. # This assumes that two search ops will not be fused at same time!
  57. for search_op_type in self.search_op_types:
  58. for node in self.model.get_nodes_by_op_type(search_op_type):
  59. graph = self.model.get_graph_by_node(node)
  60. if graph is None:
  61. raise Exception("Can not find node in any graph")
  62. self.this_graph_name = graph.name
  63. self.fuse(node, input_name_to_nodes, output_name_to_node)
  64. op_list = [node.op_type for node in self.nodes_to_add]
  65. if self.fused_count:
  66. for key, value in self.fused_count.items():
  67. if value:
  68. logger.info(f"Fused {key}: {value}")
  69. else:
  70. count = op_list.count(self.fused_op_type)
  71. if count > 0:
  72. logger.info(f"Fused {self.description}: {count}")
  73. self.model.remove_nodes(self.nodes_to_remove)
  74. self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)
  75. if self.prune_graph:
  76. self.model.prune_graph()
  77. elif self.nodes_to_remove or self.nodes_to_add:
  78. self.model.update_graph()
  79. def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True):
  80. if raw:
  81. if not isinstance(vals, np.ndarray):
  82. np_type = helper.tensor_dtype_to_np_dtype(data_type)
  83. bytes = np.array(vals, dtype=np_type).tobytes()
  84. else:
  85. bytes = vals.tobytes()
  86. tensor = helper.make_tensor(
  87. name=name,
  88. data_type=data_type,
  89. dims=dims,
  90. vals=bytes,
  91. raw=True,
  92. )
  93. else:
  94. tensor = helper.make_tensor(
  95. name=name,
  96. data_type=data_type,
  97. dims=dims,
  98. vals=vals,
  99. raw=False,
  100. )
  101. self.model.add_initializer(tensor, self.this_graph_name)
  102. return tensor
  103. def remove_initializer(self, tensor: TensorProto):
  104. self.model.remove_initializer(tensor)
  105. def add_nodes_to_remove(self, nodes: list[NodeProto]):
  106. # Some nodes are shared between paths (e.g. rotary embedding nodes in the Q and K paths).
  107. # When path A is fused, its shared nodes are added to `self.nodes_to_remove`. But when path B
  108. # is fused, its shared nodes are also added to `self.nodes_to_remove`. When the nodes are
  109. # iteratively removed from `self.nodes_to_remove`, path A's shared nodes are removed first.
  110. # Since path A's shared nodes are removed, path B's shared nodes are not removed because they
  111. # were previously removed for path A. This causes an error to print in remove_node that a node
  112. # has failed to be removed.
  113. #
  114. # To avoid this error, we pre-emptively check if the shared nodes are already in `self.nodes_to_remove`.
  115. # We could alternatively convert `self.nodes_to_remove` to a set to avoid this issue, but there could
  116. # be scenarios where the nodes need to be removed in a specific order and converting to a set would
  117. # lose this order.
  118. for node in nodes:
  119. if node not in self.nodes_to_remove:
  120. self.nodes_to_remove.append(node)
  121. def add_nodes_to_remove_with_nodes_to_keep(self, nodes: list[NodeProto], nodes_to_keep: list[NodeProto]):
  122. for node in nodes:
  123. if node not in self.nodes_to_remove and node not in nodes_to_keep:
  124. self.nodes_to_remove.append(node)