| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- import onnx
- from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
- from .base_operator import QuantOperatorBase
- from .qdq_base_operator import QDQOperatorBase
class QLinearWhere(QuantOperatorBase):
    """Operator handler that rewrites a ``Where`` node as a ``QLinearWhere`` node."""

    def should_quantize(self):
        """Always report this node as a quantization candidate."""
        return True

    def quantize(self):
        """Emit a ``QLinearWhere`` node in place of the wrapped ``Where`` node.

        The original node is passed through unchanged when the quantizer's
        ``force_quantize_no_input_check`` flag is unset.  When output
        quantization parameters are missing, or the value inputs could not be
        quantized, the work is delegated to the base-class ``quantize``.
        """
        node = self.node
        assert node.op_type == "Where"
        quantizer = self.quantizer

        # NOTE(review): this path appends the node directly, whereas the
        # fallback further down goes through super().quantize(); confirm the
        # difference is intentional.
        if not quantizer.force_quantize_no_input_check:
            quantizer.new_nodes += [node]
            return

        # Output parameters are looked up before the inputs are quantized;
        # only the first three returned values are used here.
        data_found, output_scale_name, output_zp_name, _, _ = quantizer._get_quantization_params(node.output[0])

        # Inputs 1 and 2 are quantized; input 0 is forwarded as-is below.
        q_input_names, zero_point_names, scale_names, nodes = quantizer.quantize_activation(node, [1, 2])

        can_fuse = data_found and q_input_names is not None
        if not can_fuse:
            return super().quantize()

        original_output = node.output[0]
        qlinear_output = original_output + TENSOR_NAME_QUANT_SUFFIX
        if node.name:
            qlinear_output_name = node.name + "_quant"
        else:
            qlinear_output_name = ""

        # Register the quantized output so it can be looked up by its
        # original tensor name.
        quantizer.quantized_value_map[original_output] = QuantizedValue(
            original_output,
            qlinear_output,
            output_scale_name,
            output_zp_name,
            QuantizedValueType.Input,
        )

        # Carry over every attribute of the original node, then set the domain.
        node_kwargs = {}
        for attr in node.attribute:
            node_kwargs.update(attribute_to_kwarg(attr))
        node_kwargs["domain"] = ms_domain

        # Input layout: input 0, then (tensor, scale, zero point) for each
        # quantized input, then the output scale and zero point.
        qlwhere_inputs = [node.input[0]]
        for idx in (0, 1):
            qlwhere_inputs += [q_input_names[idx], scale_names[idx], zero_point_names[idx]]
        qlwhere_inputs += [output_scale_name, output_zp_name]

        qlwhere_node = onnx.helper.make_node(
            "QLinearWhere",
            qlwhere_inputs,
            [qlinear_output],
            qlinear_output_name,
            **node_kwargs,
        )

        # Nodes produced while quantizing the inputs go first, then the fused node.
        quantizer.new_nodes += [*nodes, qlwhere_node]
class QDQWhere(QDQOperatorBase):
    """QDQ-mode handler for ``Where`` nodes."""

    def quantize(self):
        """Mark the tensors around the ``Where`` node for activation quantization.

        With the quantizer's ``force_quantize_no_input_check`` flag set,
        inputs 1 and 2 are quantized if they are not already, and the outputs
        are quantized unless ``disable_qdq_for_node_output`` is set.  Without
        the flag, the outputs are quantized only when both of those inputs are
        already quantized and ``disable_qdq_for_node_output`` is not set.
        """
        node = self.node
        assert node.op_type == "Where"
        quantizer = self.quantizer

        if quantizer.force_quantize_no_input_check:
            # Input 0 is never touched here; only inputs 1 and 2 are handled.
            for idx in (1, 2):
                tensor_name = node.input[idx]
                if not quantizer.is_tensor_quantized(tensor_name):
                    quantizer.quantize_activation_tensor(tensor_name)
            quantize_outputs = not self.disable_qdq_for_node_output
        else:
            quantize_outputs = (
                all(quantizer.is_tensor_quantized(node.input[idx]) for idx in (1, 2))
                and not self.disable_qdq_for_node_output
            )

        if quantize_outputs:
            for output in node.output:
                quantizer.quantize_activation_tensor(output)
|