| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- from logging import getLogger
- from fusion_base import Fusion
- from onnx import helper
- from onnx_model import OnnxModel
# Module-level logger, named after this module via __name__.
logger = getLogger(__name__)
class FusionFastGelu(Fusion):
    """Fuse tanh-approximated Gelu subgraphs into a single ``FastGelu`` node.

    Every ``fuse_N`` method recognizes one graph layout of the tanh approximation,
    starting from a ``Tanh`` node and walking both towards its consumers and
    towards its producers. On a full match the matched nodes are queued for
    removal and one ``FastGelu`` node (domain ``com.microsoft``) taking the
    subgraph's root input is queued for insertion.

    The matching relies on ``self.model`` helpers as they are used here:
    ``match_parent`` yields ``None`` when no parent matches,
    ``find_constant_input`` yields a negative index when the constant is absent.
    """

    def __init__(self, model: OnnxModel):
        """Set up the fusion for ``model``.

        The base class receives ``"FastGelu"`` and ``"Tanh"``; presumably these are
        the fused op type and the op type used as search anchor (confirm against
        ``Fusion``).
        """
        super().__init__(model, "FastGelu", "Tanh")

    def fuse(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict):
        """Try each known pattern in order and stop at the first one that fuses.

        Args:
            tanh_node: The ``Tanh`` node the patterns are matched around.
            input_name_to_nodes: Maps a tensor name to the list of nodes consuming it.
            output_name_to_node: Maps a tensor name to the node producing it.
        """
        for try_pattern in (self.fuse_1, self.fuse_2, self.fuse_3, self.fuse_4):
            if try_pattern(tanh_node, input_name_to_nodes, output_name_to_node):
                return

    # ------------------------------------------------------------------
    # Private helpers shared by the pattern matchers.
    # ------------------------------------------------------------------
    @staticmethod
    def _other_input_index(constant_index: int) -> int:
        """Given the input index holding a constant, return the index of the other input."""
        return 0 if constant_index == 1 else 1

    @staticmethod
    def _non_root_input_index(node, root_input: str) -> int | None:
        """Return the index of the input of ``node`` that is NOT ``root_input``.

        Input 0 is checked first. Returns ``None`` when neither of the first two
        inputs is ``root_input``.
        """
        if node.input[0] == root_input:
            return 1
        if node.input[1] == root_input:
            return 0
        return None

    @staticmethod
    def _only_child(node, op_type: str, input_name_to_nodes: dict):
        """Return the sole consumer of ``node.output[0]`` if its op type is ``op_type``.

        Returns ``None`` when the output has no recorded consumer, has more than
        one consumer, or the single consumer has a different op type.
        """
        consumers = input_name_to_nodes.get(node.output[0])
        if consumers is None or len(consumers) != 1:
            return None
        child = consumers[0]
        return child if child.op_type == op_type else None

    def _match_after_tanh(self, tanh_node, input_name_to_nodes: dict) -> tuple | None:
        """Match ``Tanh --> Add(1.0) --> Mul`` where each link is the only consumer.

        Returns:
            ``(add_after_tanh, mul_after_add)`` on success, otherwise ``None``.
        """
        add_after_tanh = self._only_child(tanh_node, "Add", input_name_to_nodes)
        if add_after_tanh is None:
            return None
        if not self.model.has_constant_input(add_after_tanh, 1.0):
            return None

        mul_after_add = self._only_child(add_after_tanh, "Mul", input_name_to_nodes)
        if mul_after_add is None:
            return None
        return add_after_tanh, mul_after_add

    def _match_pow_branch(self, tanh_node, root_node, output_name_to_node: dict) -> tuple | None:
        """Match ``Pow(Y=3) --> Mul(0.0447) --> Add --> Mul(0.7978) --> Tanh`` upstream of Tanh.

        Used by the patterns that cube the input with a ``Pow`` node.

        Args:
            tanh_node: The anchoring ``Tanh`` node.
            root_node: Producer of the root input, or ``None`` when the root is a
                graph input. When set, it is excluded while searching for the
                ``Mul`` parent of the ``Add`` so the root branch is not picked.
            output_name_to_node: Maps a tensor name to the node producing it.

        Returns:
            ``(mul_before_tanh, add_before_tanh, mul_after_pow, pow_node)`` on
            success, otherwise ``None``.
        """
        mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
        if mul_before_tanh is None:
            return None
        scale_index = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
        if scale_index < 0:
            return None

        add_before_tanh = self.model.match_parent(
            mul_before_tanh,
            "Add",
            self._other_input_index(scale_index),
            output_name_to_node,
        )
        if add_before_tanh is None:
            return None

        mul_after_pow = self.model.match_parent(
            add_before_tanh,
            "Mul",
            None,
            output_name_to_node,
            exclude=[root_node] if root_node else [],
        )
        if mul_after_pow is None:
            return None
        coeff_index = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
        if coeff_index < 0:
            return None

        pow_node = self.model.match_parent(
            mul_after_pow,
            "Pow",
            self._other_input_index(coeff_index),
            output_name_to_node,
        )
        if pow_node is None:
            return None
        if not self.model.has_constant_input(pow_node, 3.0):
            return None

        return mul_before_tanh, add_before_tanh, mul_after_pow, pow_node

    def _emit_fast_gelu(
        self,
        subgraph_nodes: list,
        root_input: str,
        last_node,
        input_name_to_nodes: dict,
        output_name_to_node: dict,
    ) -> bool | None:
        """Queue the replacement of ``subgraph_nodes`` by one ``FastGelu`` node.

        Nothing is changed unless ``self.model.is_safe_to_fuse_nodes`` accepts the
        subgraph with ``last_node.output[0]`` as its only kept output. The fused
        node reuses the outputs of ``last_node`` so downstream consumers stay wired.
        The fusion is recorded through ``self.increase_counter("FastGelu")`` for
        every pattern, so all patterns are counted the same way.

        Returns:
            ``True`` when the fusion was queued, ``None`` when it was rejected.
        """
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            [last_node.output[0]],
            input_name_to_nodes,
            output_name_to_node,
        ):
            return None

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = helper.make_node(
            "FastGelu",
            inputs=[root_input],
            outputs=last_node.output,
            name=self.model.create_node_name("FastGelu"),
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
        self.increase_counter("FastGelu")
        return True

    # ------------------------------------------------------------------
    # Pattern matchers.
    # ------------------------------------------------------------------
    def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> bool | None:
        """
        Fuse Gelu with tanh into one node:

              +----------------------------+
              |                            |
              |                            v
            [root] --> Pow --> Mul -----> Add --> Mul --> Tanh --> Add --> Mul
              |      (Y=3) (B=0.0447...)     (B=0.7978...)        (B=1)     ^
              |                                                             |
              +------> Mul(B=0.5) ------------------------------------------+

        Note that constant input for Add and Mul could be first or second input:
        like either A=0.5 or B=0.5 is fine.

        Returns:
            ``True`` when the subgraph was fused, ``None`` when it did not match.
        """
        tail = self._match_after_tanh(tanh_node, input_name_to_nodes)
        if tail is None:
            return None
        add_after_tanh, mul_after_tanh = tail

        mul_half = self.model.match_parent(mul_after_tanh, "Mul", None, output_name_to_node)
        if mul_half is None:
            return None
        half_index = self.model.find_constant_input(mul_half, 0.5)
        if half_index < 0:
            return None

        root_index = self._other_input_index(half_index)
        root_input = mul_half.input[root_index]
        # root_node could be None when root_input is graph input
        root_node = self.model.get_parent(mul_half, root_index, output_name_to_node)

        branch = self._match_pow_branch(tanh_node, root_node, output_name_to_node)
        if branch is None:
            return None
        mul_before_tanh, add_before_tanh, mul_after_pow, pow_node = branch

        # The cubed value and the 0.5 branch must originate from the same tensor.
        if pow_node.input[0] != root_input:
            return None

        subgraph_nodes = [
            mul_after_tanh,
            mul_half,
            add_after_tanh,
            tanh_node,
            mul_before_tanh,
            add_before_tanh,
            mul_after_pow,
            pow_node,
        ]
        return self._emit_fast_gelu(
            subgraph_nodes, root_input, mul_after_tanh, input_name_to_nodes, output_name_to_node
        )

    def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None:
        """
        This pattern is from Tensorflow model.
        Fuse Gelu with tanh into one node:

              +----------------------------+
              |                            |
              |                            v
            [root] --> Pow --> Mul -----> Add --> Mul --> Tanh --> Add --> Mul(B=0.5) --> Mul -->
              |      (Y=3) (B=0.0447...)     (B=0.7978...)        (B=1)                    ^
              |                                                                            |
              +----------------------------------------------------------------------------+

        Note that constant input for Add and Mul could be first or second input:
        like either A=0.5 or B=0.5 is fine.

        Returns:
            ``True`` when the subgraph was fused, ``None`` when it did not match.
        """
        tail = self._match_after_tanh(tanh_node, input_name_to_nodes)
        if tail is None:
            return None
        add_after_tanh, mul_half = tail

        if self.model.find_constant_input(mul_half, 0.5) < 0:
            return None

        mul_after_mul_half = self._only_child(mul_half, "Mul", input_name_to_nodes)
        if mul_after_mul_half is None:
            return None

        # The input of the last Mul that is not fed by mul_half is the root.
        root_index = 0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1
        # root_node could be None when root_input is graph input
        root_node = self.model.get_parent(mul_after_mul_half, root_index, output_name_to_node)

        branch = self._match_pow_branch(tanh_node, root_node, output_name_to_node)
        if branch is None:
            return None
        mul_before_tanh, add_before_tanh, mul_after_pow, pow_node = branch

        root_input = mul_after_mul_half.input[root_index]
        if pow_node.input[0] != root_input:
            return None

        subgraph_nodes = [
            mul_after_mul_half,
            mul_half,
            add_after_tanh,
            tanh_node,
            mul_before_tanh,
            add_before_tanh,
            mul_after_pow,
            pow_node,
        ]
        return self._emit_fast_gelu(
            subgraph_nodes, root_input, mul_after_mul_half, input_name_to_nodes, output_name_to_node
        )

    def fuse_3(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None:
        """
        OpenAI's gelu implementation, also used in Megatron:
            Gelu(x) = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

        Fuse subgraph into a FastGelu node:

              +------------ Mul (B=0.79788456) -------------------+
              |                                                   |
              +------------------------------+                    |
              |                              |                    |
              |                              v                    v
            [root] --> Mul (B=0.044715) --> Mul --> Add(B=1) --> Mul --> Tanh --> Add(B=1) --> Mul -->
              |                                                                                 ^
              |                                                                                 |
              +-----------> Mul (B=0.5) --------------------------------------------------------+

        Returns:
            ``True`` when the subgraph was fused, ``None`` when it did not match.
        """
        tail = self._match_after_tanh(tanh_node, input_name_to_nodes)
        if tail is None:
            return None
        add_after_tanh, mul_last = tail

        mul_half = self.model.match_parent(mul_last, "Mul", None, output_name_to_node)
        if mul_half is None:
            return None
        half_index = self.model.find_constant_input(mul_half, 0.5)
        if half_index < 0:
            return None
        root_input = mul_half.input[self._other_input_index(half_index)]

        mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
        if mul_before_tanh is None:
            return None

        # One parent of mul_before_tanh is Add(1.0), the other is Mul(root, 0.7978).
        add_1 = self.model.match_parent(mul_before_tanh, "Add", None, output_name_to_node)
        if add_1 is None:
            return None
        one_index = self.model.find_constant_input(add_1, 1.0)
        if one_index < 0:
            return None

        mul_7978 = self.model.match_parent(mul_before_tanh, "Mul", None, output_name_to_node)
        if mul_7978 is None:
            return None
        scale_index = self.model.find_constant_input(mul_7978, 0.7978, delta=0.0001)
        if scale_index < 0:
            return None
        if mul_7978.input[self._other_input_index(scale_index)] != root_input:
            return None

        mul_before_add_1 = self.model.match_parent(
            add_1, "Mul", self._other_input_index(one_index), output_name_to_node
        )
        if mul_before_add_1 is None:
            return None

        # mul_before_add_1 multiplies root by (0.0447 * root); locate the non-root side.
        another = self._non_root_input_index(mul_before_add_1, root_input)
        if another is None:
            return None

        mul_0447 = self.model.match_parent(mul_before_add_1, "Mul", another, output_name_to_node)
        if mul_0447 is None:
            return None
        coeff_index = self.model.find_constant_input(mul_0447, 0.0447, delta=0.0001)
        if coeff_index < 0:
            return None
        if mul_0447.input[self._other_input_index(coeff_index)] != root_input:
            return None

        subgraph_nodes = [
            mul_0447,
            mul_before_add_1,
            add_1,
            mul_before_tanh,
            tanh_node,
            add_after_tanh,
            mul_7978,
            mul_half,
            mul_last,
        ]
        return self._emit_fast_gelu(
            subgraph_nodes, root_input, mul_last, input_name_to_nodes, output_name_to_node
        )

    def fuse_4(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None:
        """
        PyTorch's gelu implementation with tanh approximation:
            Gelu(x) = 0.5 * x * (1 + torch.tanh(0.7978845834732056 * (x + 0.044714998453855515 * x * x * x)))

        Fuse Gelu with tanh into one node:

              +-----------------+------------------+
              |                 |                  |
              |                 v                  v
            [root] ==> Mul --> Mul --> Mul -----> Add --> Mul --> Tanh --> Add -----> Mul --> Mul -->
              |                     (A=0.0447)         (A=0.7978)         (A=1)        ^    (A=0.5)
              |                                                                        |
              +------------------------------------------------------------------------+

        Note that constant input for Add and Mul could be first or second input.
        The constants 0.7978 and 0.0447 are matched with a wider tolerance
        (delta=0.01) here than in the other patterns.

        Returns:
            ``True`` when the subgraph was fused, ``None`` when it did not match.
        """
        tail = self._match_after_tanh(tanh_node, input_name_to_nodes)
        if tail is None:
            return None
        add_after_tanh, mul_after_tanh = tail

        mul_half = self._only_child(mul_after_tanh, "Mul", input_name_to_nodes)
        if mul_half is None:
            return None
        if not self.model.has_constant_input(mul_half, 0.5):
            return None

        # The input of mul_after_tanh that does not come from the Add is the root.
        root_input = mul_after_tanh.input[0 if mul_after_tanh.input[1] == add_after_tanh.output[0] else 1]

        mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
        if mul_before_tanh is None:
            return None
        scale_index = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.01)
        if scale_index < 0:
            return None

        add_before_tanh = self.model.match_parent(
            mul_before_tanh, "Add", self._other_input_index(scale_index), output_name_to_node
        )
        if add_before_tanh is None:
            return None

        # add_before_tanh computes root + 0.0447 * root^3; follow the non-root side.
        another = self._non_root_input_index(add_before_tanh, root_input)
        if another is None:
            return None

        mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", another, output_name_to_node)
        if mul_after_pow is None:
            return None
        coeff_index = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.01)
        if coeff_index < 0:
            return None

        mul_cubed = self.model.match_parent(
            mul_after_pow, "Mul", self._other_input_index(coeff_index), output_name_to_node
        )
        if mul_cubed is None:
            return None

        # mul_cubed multiplies root by root^2; follow the non-root side.
        another = self._non_root_input_index(mul_cubed, root_input)
        if another is None:
            return None

        mul_squared = self.model.match_parent(mul_cubed, "Mul", another, output_name_to_node)
        if mul_squared is None:
            return None
        # Squaring means both inputs of this Mul are the root tensor.
        if mul_squared.input[0] != root_input or mul_squared.input[1] != root_input:
            return None

        subgraph_nodes = [
            mul_squared,
            mul_cubed,
            mul_after_pow,
            add_before_tanh,
            mul_before_tanh,
            tanh_node,
            add_after_tanh,
            mul_after_tanh,
            mul_half,
        ]
        return self._emit_fast_gelu(
            subgraph_nodes, root_input, mul_half, input_name_to_nodes, output_name_to_node
        )
|