| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License. See License.txt in the project root for
- # license information.
- # --------------------------------------------------------------------------
- from __future__ import annotations
- import onnx
- from ..onnx_model import ONNXModel
- from .fusion import Fusion
class FusionGelu(Fusion):
    """Replace Erf-based Gelu subgraphs with a single ``Gelu`` node in the ``com.microsoft`` domain.

    The matched subgraphs all compute ``0.5 * x * (1 + erf(x / sqrt(2)))``; they differ only in
    how the exporter ordered the multiplications and whether the scaling by ``1 / sqrt(2)`` is
    expressed as a ``Div`` or a ``Mul``. Each ``fuse_N`` method recognizes one family of layouts.
    """

    def __init__(self, model: ONNXModel):
        # Arguments are forwarded to the Fusion base class unchanged; presumably the fused op
        # type ("Gelu") and the op type that anchors the search ("Erf") -- confirm in Fusion.
        super().__init__(model, "Gelu", "Erf")

    def fuse(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """
        Interface function that tries to fuse a node sequence containing an Erf node into a single
        Gelu node.

        The three pattern matchers are tried in order and evaluation stops at the first one that
        succeeds. On success the model's opset import for ``com.microsoft`` is set to version 1,
        since the fused node lives in that domain.

        Args:
            erf_node: the Erf node used as the anchor of the pattern search.
            input_name_to_nodes: maps a tensor name to the nodes that consume it.
            output_name_to_node: maps a tensor name to the node that produces it.
        """
        if (
            self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node)
            or self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node)
            or self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node)
        ):
            self.model.set_opset_import("com.microsoft", 1)

    @staticmethod
    def _single_consumer(
        output_name: str,
        op_type: str,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
    ) -> onnx.NodeProto | None:
        """Return the only consumer of ``output_name`` if it has type ``op_type``, else ``None``.

        ``None`` is also returned when the tensor has no recorded consumers at all (for example
        when it is not an input of any node), so callers never index the map with a missing key.
        """
        consumers = input_name_to_nodes.get(output_name)
        if consumers is None or len(consumers) != 1:
            return None
        only_consumer = consumers[0]
        return only_consumer if only_consumer.op_type == op_type else None

    def _replace_with_gelu(
        self,
        subgraph_nodes: list[onnx.NodeProto],
        subgraph_input: str,
        subgraph_output: str,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """Schedule ``subgraph_nodes`` for removal and a Gelu node for insertion.

        Nothing is scheduled and ``False`` is returned when ``is_safe_to_fuse_nodes`` rejects the
        subgraph. Otherwise the new node reads ``subgraph_input``, writes ``subgraph_output``
        (reusing the original tensor name so downstream consumers stay connected) and ``True`` is
        returned.
        """
        if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
            return False

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = onnx.helper.make_node(
            "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output]
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        return True

    def fuse_1(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        This pattern is from PyTorch model
        Fuse Gelu with Erf into one node:
        Pattern 1:
                   +-------Mul(0.5)---------------------+
                   |                                    |
                   |                                    v
                [root] --> Div -----> Erf  --> Add --> Mul -->
                          (B=1.4142...)        (1)

        Pattern 2:
                   +------------------------------------+
                   |                                    |
                   |                                    v
                [root] --> Div -----> Erf  --> Add --> Mul -->Mul -->
                          (B=1.4142...)        (1)            (0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Returns:
            True when a Gelu node was scheduled to replace the matched subgraph, False otherwise.
        """
        add_after_erf = self._single_consumer(erf_node.output[0], "Add", input_name_to_nodes)
        if add_after_erf is None:
            return False
        if not self.has_constant_input(add_after_erf, 1):
            return False

        mul_after_erf = self._single_consumer(add_after_erf.output[0], "Mul", input_name_to_nodes)
        if mul_after_erf is None:
            return False

        div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return False
        # The divisor (second input of Div) must be the constant sqrt(2).
        if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
            return False

        subgraph_input = div.input[0]

        # Index of the Mul input that is NOT fed by the Add node.
        another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
        if subgraph_input == mul_after_erf.input[another]:  # pattern 2
            # The root is multiplied in directly, so the 0.5 scaling must follow this Mul.
            # A Mul output without any consumer simply means "no match" rather than a KeyError.
            mul_half = self._single_consumer(mul_after_erf.output[0], "Mul", input_name_to_nodes)
            if mul_half is None:
                return False
            if not self.has_constant_input(mul_half, 0.5):
                return False
            subgraph_output = mul_half.output[0]
        else:  # pattern 1
            # The other Mul input must come from Mul(root, 0.5).
            mul_half = self.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
            if mul_half is None:
                return False
            if not self.has_constant_input(mul_half, 0.5):
                return False
            if subgraph_input not in mul_half.input:
                return False
            subgraph_output = mul_after_erf.output[0]

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
        return self._replace_with_gelu(
            subgraph_nodes, subgraph_input, subgraph_output, input_name_to_nodes, output_name_to_node
        )

    def fuse_2(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        This pattern is from Keras model
        Fuse Gelu with Erf into one node:
                   +------------------------------------------+
                   |                                          |
                   |                                          v
                [root] --> Div -----> Erf  --> Add --> Mul -->Mul
                          (B=1.4142...)       (A=1)   (A=0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        The divisor of Div may also be produced by a Sqrt node that has a constant input of 2.0
        instead of being the constant 1.4142... itself; that Sqrt node is then removed as well.

        Returns:
            True when a Gelu node was scheduled to replace the matched subgraph, False otherwise.
        """
        add_after_erf = self._single_consumer(erf_node.output[0], "Add", input_name_to_nodes)
        if add_after_erf is None:
            return False
        if not self.has_constant_input(add_after_erf, 1):
            return False

        mul_after_erf = self._single_consumer(add_after_erf.output[0], "Mul", input_name_to_nodes)
        if mul_after_erf is None:
            return False
        if not self.has_constant_input(mul_after_erf, 0.5):
            return False

        mul = self._single_consumer(mul_after_erf.output[0], "Mul", input_name_to_nodes)
        if mul is None:
            return False

        div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return False

        sqrt_node = None
        if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
            # Divisor is not the literal constant; accept Sqrt(2.0) as the divisor instead.
            sqrt_node = self.match_parent(div, "Sqrt", 1, output_name_to_node)
            if sqrt_node is None:
                return False
            if not self.has_constant_input(sqrt_node, 2.0):
                return False

        subgraph_input = div.input[0]
        # The final Mul must multiply the scaled erf term by the root input.
        if subgraph_input not in mul.input:
            return False

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
        if sqrt_node is not None:
            subgraph_nodes.append(sqrt_node)

        return self._replace_with_gelu(
            subgraph_nodes, subgraph_input, mul.output[0], input_name_to_nodes, output_name_to_node
        )

    def fuse_3(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        This pattern is from TensorFlow model
        Fuse Gelu with Erf into one node:
                   +----------------------------------------------+
                   |                                              |
                   |                                              v
                [root] --> Mul -----> Erf  --> Add --> Mul -->Mul
                   (A=0.7071067690849304)     (B=1)   (B=0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Returns:
            True when a Gelu node was scheduled to replace the matched subgraph, False otherwise.
        """
        add_after_erf = self._single_consumer(erf_node.output[0], "Add", input_name_to_nodes)
        if add_after_erf is None:
            return False
        if not self.has_constant_input(add_after_erf, 1):
            return False

        mul_half = self._single_consumer(add_after_erf.output[0], "Mul", input_name_to_nodes)
        if mul_half is None:
            return False
        if not self.has_constant_input(mul_half, 0.5):
            return False

        first_mul = self.match_parent(erf_node, "Mul", 0, output_name_to_node)
        if first_mul is None:
            return False

        # Index of the 1/sqrt(2) constant among the Mul inputs; a negative index means not found.
        i = self.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
        if i < 0:
            return False

        # The other input of the two-input Mul is the root of the subgraph.
        root_input_index = 1 - i
        subgraph_input = first_mul.input[root_input_index]

        last_mul = self._single_consumer(mul_half.output[0], "Mul", input_name_to_nodes)
        if last_mul is None:
            return False
        # The final Mul must multiply the scaled erf term by the root input.
        if not (last_mul.input[0] == subgraph_input or last_mul.input[1] == subgraph_input):
            return False

        subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
        return self._replace_with_gelu(
            subgraph_nodes, subgraph_input, last_mul.output[0], input_name_to_nodes, output_name_to_node
        )
|