concat.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import onnx
  2. from ..quant_utils import ( # noqa: F401
  3. TENSOR_NAME_QUANT_SUFFIX,
  4. QuantizedValue,
  5. QuantizedValueType,
  6. attribute_to_kwarg,
  7. ms_domain,
  8. )
  9. from .base_operator import QuantOperatorBase
  10. from .qdq_base_operator import QDQOperatorBase # noqa: F401
  11. class QLinearConcat(QuantOperatorBase):
  12. def __init__(self, onnx_quantizer, onnx_node):
  13. super().__init__(onnx_quantizer, onnx_node)
  14. def quantize(self):
  15. node = self.node
  16. (
  17. data_found,
  18. output_scale_name,
  19. output_zp_name,
  20. _,
  21. _,
  22. ) = self.quantizer._get_quantization_params(node.output[0])
  23. (
  24. q_input_names,
  25. zero_point_names,
  26. scale_names,
  27. nodes,
  28. ) = self.quantizer.quantize_activation(node, [*range(len(node.input))])
  29. if not data_found or q_input_names is None:
  30. return super().quantize()
  31. # Create an entry for output quantized value
  32. quantized_input_value = self.quantizer.quantized_value_map[node.input[0]]
  33. quantized_output_value = QuantizedValue(
  34. node.output[0],
  35. node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
  36. output_scale_name,
  37. output_zp_name,
  38. quantized_input_value.value_type,
  39. )
  40. self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value
  41. kwargs = {}
  42. for attribute in node.attribute:
  43. kwargs.update(attribute_to_kwarg(attribute))
  44. kwargs["domain"] = ms_domain
  45. qnode_name = node.name + "_quant" if node.name else ""
  46. qlconcat_inputs = [output_scale_name, output_zp_name]
  47. for i in range(len(q_input_names)):
  48. qlconcat_inputs.extend([q_input_names[i], scale_names[i], zero_point_names[i]])
  49. qlconcat_node = onnx.helper.make_node(
  50. "QLinearConcat", qlconcat_inputs, [quantized_output_value.q_name], qnode_name, **kwargs
  51. )
  52. self.quantizer.new_nodes += nodes
  53. self.quantizer.new_nodes += [qlconcat_node]