| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- from collections.abc import Sequence
- from logging import getLogger
- from typing import Any
- import numpy as np
- import onnx
- from onnx import helper
- from onnx_model import OnnxModel
- logger = getLogger(__name__)
class DynamoOnnxHelper:
    """
    Helper class for processing ONNX models exported by Torch Dynamo.

    The model passed to the constructor is wrapped in an ``OnnxModel`` and kept as
    ``self.model``. The graph and the function list are reached through
    ``self.model.model`` and are edited in place by the methods below.
    """

    def __init__(self, model: onnx.ModelProto):
        """
        Wraps the given model in an ``OnnxModel``.
        """
        self.model = OnnxModel(model)

    @staticmethod
    def _rename_in_place(names, edge_mapping: dict) -> None:
        """
        Replaces, in place, every entry of `names` that is a key of `edge_mapping`
        with the mapped value.
        """
        # Walk a snapshot so that the container is never mutated while it is being iterated.
        for index, name in enumerate(list(names)):
            if name in edge_mapping:
                names[index] = edge_mapping[name]

    def update_edges(self, edge_mapping: dict) -> None:
        """
        Updates the edges in the model according to the given mapping.

        Every node input, node output, graph input and graph output whose name is a key
        of `edge_mapping` is renamed to the mapped value. Names that are not keys of the
        mapping are left untouched.
        """
        graph = self.model.model.graph
        for node in graph.node:
            self._rename_in_place(node.input, edge_mapping)
            self._rename_in_place(node.output, edge_mapping)
        for graph_input in graph.input:
            if graph_input.name in edge_mapping:
                graph_input.name = edge_mapping[graph_input.name]
        for graph_output in graph.output:
            if graph_output.name in edge_mapping:
                graph_output.name = edge_mapping[graph_output.name]

    def unroll_function(self, func_name: str) -> None:
        """
        Unrolls the function with the given name in the model.

        Graph nodes whose op_type equals `func_name` are removed, the body nodes of the
        model function with that name are appended to the graph, the function definition
        is dropped, and the edges of the removed call nodes are renamed to the function's
        formal input/output names.

        The call-site edges and the formal edges are paired purely by position, so their
        total counts must match; an AssertionError is raised otherwise.
        """
        logger.debug("Unrolling function %s...", func_name)
        graph = self.model.model.graph
        functions = self.model.model.functions

        # Call sites: nodes that invoke the function, plus all of their edges in order.
        nodes_to_remove = [node for node in graph.node if node.op_type == func_name]
        edges_to_remove = []
        for node in nodes_to_remove:
            edges_to_remove.extend(node.input)
            edges_to_remove.extend(node.output)

        # Definition: the body nodes and the formal input/output names of the function.
        nodes_to_add = []
        edges_to_add = []
        func_to_remove = None
        for func in functions:
            if func.name == func_name:
                nodes_to_add.extend(func.node)
                edges_to_add.extend(func.input)
                edges_to_add.extend(func.output)
                func_to_remove = func

        assert len(edges_to_remove) == len(edges_to_add), (
            f"Cannot unroll {func_name}: {len(edges_to_remove)} call-site edges "
            f"vs {len(edges_to_add)} function edges"
        )

        for node in nodes_to_remove:
            graph.node.remove(node)
        for node in nodes_to_add:
            graph.node.append(node)
        if func_to_remove is not None:
            functions.remove(func_to_remove)

        # Rename each call-site edge to the formal name at the same position.
        edge_mapping = {old: new for old, new in zip(edges_to_remove, edges_to_add) if old != new}
        self.update_edges(edge_mapping)

    def remove_function(self, func_name: str, input_id: int, output_id: int) -> None:
        """
        Removes the function in the model.

        Every graph node whose op_type contains `func_name` as a substring is removed.
        To keep the graph connected, the name of the node's input at index `input_id`
        is renamed throughout the model to the name of its output at index `output_id`.
        """
        graph = self.model.model.graph
        edge_mapping = {}
        nodes_to_remove = []
        for node in graph.node:
            if func_name in node.op_type:
                edge_mapping[node.input[input_id]] = node.output[output_id]
                nodes_to_remove.append(node)
        for node in nodes_to_remove:
            graph.node.remove(node)
        self.update_edges(edge_mapping)

    def remove_dropout_layer(self) -> None:
        """
        Removes the dropout layer in the model.

        Nodes whose op_type contains "Dropout" are bypassed from input 0 to output 0.
        """
        logger.debug("Removing dropout layer...")
        self.remove_function("Dropout", 0, 0)

    def remove_lm_head_layer(self) -> None:
        """
        Removes the LM head layer in the model.

        Nodes whose op_type contains "Linear_lm_head" are bypassed from input 2 to output 0.
        """
        logger.debug("Removing LM head layer...")
        # TODO: known issue carried over from the original "bugbug" note: the right "vi"
        # (presumably the value_info of the surviving edge) still needs to be copied over.
        self.remove_function("Linear_lm_head", 2, 0)

    def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True):
        """
        Creates a tensor, adds it to the model as an initializer, and returns it.

        When `raw` is True, `vals` is converted to a NumPy array of the dtype that
        corresponds to `data_type` and stored as raw bytes. Otherwise `vals` is handed
        to `helper.make_tensor` unchanged.
        """
        if raw:
            np_type = helper.tensor_dtype_to_np_dtype(data_type)
            if isinstance(vals, np.ndarray):
                raw_bytes = vals.astype(np_type).tobytes()
            else:
                raw_bytes = np.array(vals, dtype=np_type).tobytes()
            tensor = helper.make_tensor(
                name=name,
                data_type=data_type,
                dims=dims,
                vals=raw_bytes,
                raw=True,
            )
        else:
            tensor = helper.make_tensor(
                name=name,
                data_type=data_type,
                dims=dims,
                vals=vals,
                raw=False,
            )

        self.model.add_initializer(tensor)
        return tensor

    def convert_constants_to_initializers(self, min_size: int = 1) -> None:
        """
        Converts Constant ops of size [min_size] or higher to initializers

        The new initializer reuses the name of the Constant op's output, so consumers of
        that output need no rewiring. Constant ops with fewer than `min_size` elements,
        or whose value cannot be read, are left in the graph.
        """
        logger.debug("Converting constants greater than size %s to initializers", min_size)

        nodes_to_remove = []
        for node in self.model.get_nodes_by_op_type("Constant"):
            # Get info from Constant op
            np_data = self.model.get_constant_value(node.output[0])

            # Skip if there are less than [min_size] elements
            if np_data is None or np_data.size < min_size:
                continue

            # Add new initializer with same name as Constant op's output
            value_attr = next((att for att in node.attribute if att.name == "value"), None)
            if value_attr is not None:
                self.add_initializer(
                    name=node.output[0],
                    data_type=value_attr.t.data_type,
                    dims=list(np_data.shape),
                    vals=np_data,
                )

            # NOTE(review): the node is queued for removal even when it has no "value"
            # attribute; presumably get_constant_value only yields data when a tensor for
            # this name exists - confirm against OnnxModel.
            nodes_to_remove.append(node)

        # Remove Constant ops from graph
        self.model.remove_nodes(nodes_to_remove)

    def clear_metadata(self) -> None:
        """
        Clear metadata fields in all nodes

        The "metadata_props" field is cleared on every graph returned by
        `self.model.graphs()` and on every node returned by `self.model.nodes()`.
        """
        for graph in self.model.graphs():
            graph.ClearField("metadata_props")
        for node in self.model.nodes():
            node.ClearField("metadata_props")

    @staticmethod
    def fold_transpose_initializers(model) -> None:
        """
        Constant fold Transpose initializers without changing the initializer names

        For every initializer whose only consumer is a Transpose node, the transposed
        tensor is computed eagerly, registered under the original initializer name, the
        Transpose output is rewired to it, and the Transpose node is removed.

        `model` is presumably an onnxscript `ir.Model` (it must expose
        `graph.initializers`) - confirm against callers.
        """
        from onnxscript import ir  # noqa: PLC0415

        # Iterate over a snapshot: entries of the mapping are reassigned inside the loop.
        for name, initializer in list(model.graph.initializers.items()):
            user_nodes = initializer.consumers()
            if len(user_nodes) != 1 or user_nodes[0].op_type != "Transpose":
                continue
            if initializer.const_value is None:
                # No concrete tensor is attached, so there is nothing to fold.
                continue

            transpose_node = user_nodes[0]
            perm = transpose_node.attributes.get("perm")
            array = initializer.const_value.numpy()
            if perm is None:
                # Without an explicit perm, numpy's transpose() reverses the axes.
                transposed_tensor = ir.tensor(array.transpose())
            else:
                transposed_tensor = ir.tensor(array.transpose(perm.as_ints()))

            new_initializer = ir.Value(
                name=initializer.name,
                shape=transposed_tensor.shape,
                type=ir.TensorType(transposed_tensor.dtype),
                const_value=transposed_tensor,
            )
            ir.convenience.replace_all_uses_with(transpose_node.outputs[0], new_initializer)
            model.graph.initializers[name] = new_initializer
            transpose_node.graph.remove(transpose_node, safe=True)
|