| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- import onnx
- from ..quant_utils import ( # noqa: F401
- TENSOR_NAME_QUANT_SUFFIX,
- QuantizedValue,
- QuantizedValueType,
- attribute_to_kwarg,
- ms_domain,
- )
- from .base_operator import QuantOperatorBase
- from .qdq_base_operator import QDQOperatorBase # noqa: F401
class QLinearConcat(QuantOperatorBase):
    """Operator quantizer that emits a ``QLinearConcat`` node in ``ms_domain``."""

    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        """Replace the wrapped node with a ``QLinearConcat`` node.

        Quantization parameters are looked up for the node's first output and
        every input is quantized as an activation. If the output parameters
        are not found, or activation quantization yields ``None`` for the
        quantized input names, the base-class ``quantize`` is used instead and
        its result is returned.

        On success the quantized output is registered in the quantizer's
        ``quantized_value_map`` and the nodes produced while quantizing the
        inputs, followed by the new ``QLinearConcat`` node, are appended to
        the quantizer's ``new_nodes``.
        """
        quantizer = self.quantizer
        src_node = self.node
        output_name = src_node.output[0]

        # The output parameters are fetched before the inputs are quantized;
        # only the first three of the five returned values are used here.
        params_found, out_scale, out_zero_point, _, _ = quantizer._get_quantization_params(output_name)

        every_input_index = list(range(len(src_node.input)))
        quantized_names, zp_names, scale_names, input_nodes = quantizer.quantize_activation(
            src_node, every_input_index
        )

        if not (params_found and quantized_names is not None):
            return super().quantize()

        # Register the quantized output; its value type is copied from the
        # quantized entry of the first input.
        first_input_entry = quantizer.quantized_value_map[src_node.input[0]]
        output_entry = QuantizedValue(
            output_name,
            output_name + TENSOR_NAME_QUANT_SUFFIX,
            out_scale,
            out_zero_point,
            first_input_entry.value_type,
        )
        quantizer.quantized_value_map[output_name] = output_entry

        # Carry the original attributes over; "domain" is assigned last so it
        # takes precedence over any attribute of the same name.
        node_kwargs = {}
        for attr in src_node.attribute:
            node_kwargs.update(attribute_to_kwarg(attr))
        node_kwargs["domain"] = ms_domain

        if src_node.name:
            quant_node_name = src_node.name + "_quant"
        else:
            quant_node_name = ""

        # Input layout: output scale, output zero point, then one
        # (quantized tensor, scale, zero point) triple per input.
        concat_inputs = [out_scale, out_zero_point]
        for idx, quantized_name in enumerate(quantized_names):
            concat_inputs += [quantized_name, scale_names[idx], zp_names[idx]]

        concat_qnode = onnx.helper.make_node(
            "QLinearConcat",
            concat_inputs,
            [output_entry.q_name],
            quant_node_name,
            **node_kwargs,
        )

        quantizer.new_nodes += input_nodes
        quantizer.new_nodes += [concat_qnode]
|