| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- import onnx
- from onnx import onnx_pb as onnx_proto # noqa: F401
- from ..quant_utils import attribute_to_kwarg, ms_domain
- from .base_operator import QuantOperatorBase
- """
- Quantize Attention
- """
- class AttentionQuant(QuantOperatorBase):
- def __init__(self, onnx_quantizer, onnx_node):
- super().__init__(onnx_quantizer, onnx_node)
- def should_quantize(self):
- return self.quantizer.should_quantize_node(self.node)
- def quantize(self):
- """
- parameter node: Attention node.
- parameter new_nodes_list: List of new nodes created before processing this node.
- return: a list of nodes in topological order that represents quantized Attention node.
- """
- node = self.node
- assert node.op_type == "Attention"
- # TODO This is a temporary fix to stop exporting QAttention with qkv_hidden_sizes
- # attribute. This needs to be removed once the QAttention for varied q,k,v sizes
- # is implemented
- for attr in node.attribute:
- if attr.name == "qkv_hidden_sizes":
- return super().quantize()
- (
- quantized_input_names,
- zero_point_names,
- scale_names,
- nodes,
- ) = self.quantizer.quantize_activation(node, [0])
- (
- quantized_input_names_weight,
- zero_point_names_weight,
- scale_names_weight,
- nodes_weight,
- ) = self.quantizer.quantize_weight(node, [1], reduce_range=True, op_level_per_channel=True)
- quantized_input_names.extend(quantized_input_names_weight)
- zero_point_names.extend(zero_point_names_weight)
- scale_names.extend(scale_names_weight)
- nodes.extend(nodes_weight)
- if quantized_input_names is None:
- return super().quantize()
- qattention_name = "" if not node.name else node.name + "_quant"
- inputs = []
- inputs.extend(quantized_input_names)
- inputs.extend([node.input[2]])
- inputs.extend(scale_names)
- inputs.extend([node.input[3] if len(node.input) > 3 else ""])
- inputs.extend(zero_point_names)
- inputs.extend([node.input[4] if len(node.input) > 4 else ""])
- kwargs = {}
- for attribute in node.attribute:
- kwargs.update(attribute_to_kwarg(attribute))
- kwargs["domain"] = ms_domain
- qattention_node = onnx.helper.make_node("QAttention", inputs, node.output, qattention_name, **kwargs)
- nodes.append(qattention_node)
- self.quantizer.new_nodes += nodes
|