matmul.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import itertools
  2. import logging
  3. import onnx
  4. from onnx import onnx_pb as onnx_proto
  5. from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, find_by_name, get_mul_node
  6. from .base_operator import QuantOperatorBase
  7. from .qdq_base_operator import QDQOperatorBase
  class QOpMatMul(QuantOperatorBase):
      """Shared base for the QOperator-format MatMul quantizers.

      ``MatMulInteger`` and ``QLinearMatMul`` in this module derive from it to
      reuse the ``should_quantize`` eligibility check.
      """

      def __init__(self, onnx_quantizer, onnx_node):
          super().__init__(onnx_quantizer, onnx_node)

      def should_quantize(self):
          """Decide whether this MatMul node is eligible for quantization.

          Returns:
              False (after logging the reason) when the quantizer's
              ``should_quantize_node`` rejects the node, when neither input is a
              float tensor, or when ``q_matmul_const_b_only`` is enabled and no
              initializer is found on the path of input B (``input[1]``);
              True otherwise.
          """
          node = self.node
          quantizer = self.quantizer

          if not quantizer.should_quantize_node(node):
              logging.debug("Ignore MatMul %s", node.name)
              return False

          # The node is skipped only when *neither* input is a float tensor.
          if not quantizer.is_float_tensor(node.input[1]) and not quantizer.is_float_tensor(node.input[0]):
              logging.info("Ignore MatMul due to non float inputs %s", node.name)
              return False

          # Optionally restrict quantization to MatMuls whose B matrix is constant.
          if quantizer.q_matmul_const_b_only and not quantizer.find_initializer_in_path(node.input[1]):
              logging.info("Ignore MatMul due to non constant B: %s[%s]", quantizer.graph_scope, node.name)
              return False

          return True
  26. """
  27. Used when quantize mode is QuantizationMode.IntegerOps.
  28. """
  class MatMulInteger(QOpMatMul):
      """Quantize a float MatMul with MatMulInteger (QuantizationMode.IntegerOps).

      Emitted chain, whose final Mul writes the original MatMul output name::

          MatMulInteger(A_q, B_q, A_zp, B_zp) -> Cast -> Mul(scale_A * scale_B)
      """

      def __init__(self, onnx_quantizer, onnx_node):
          super().__init__(onnx_quantizer, onnx_node)

      def quantize(self):
          """Append the MatMulInteger -> Cast -> Mul replacement to ``quantizer.new_nodes``.

          The scale-product Mul is reused when a node with the same name already
          exists in ``quantizer.new_nodes``; otherwise a new one is emitted.
          """
          node = self.node
          assert node.op_type == "MatMul"
          quantizer = self.quantizer

          # Each call yields (quantized names, zero points, scales, nodes): A first, then B.
          q_names, zps, scales, emitted = quantizer.quantize_activation(node, [0])
          w_q_names, w_zps, w_scales, w_emitted = quantizer.quantize_weight(
              node, [1], reduce_range=True, op_level_per_channel=True
          )
          # Merge B's results into A's lists in place: index 0 is A, index 1 is B.
          for target, addition in ((q_names, w_q_names), (zps, w_zps), (scales, w_scales), (emitted, w_emitted)):
              target.extend(addition)

          int_output = node.output[0] + "_output_quantized"
          int_name = node.name + "_quant" if node.name else ""
          emitted.append(onnx.helper.make_node("MatMulInteger", q_names + zps, [int_output], int_name))

          # Cast the MatMulInteger result to the element type of the original output.
          cast_output = int_output + "_cast_output"
          target_type = quantizer.get_tensor_type(node.output[0], mandatory=True)
          emitted.append(
              onnx.helper.make_node("Cast", [int_output], [cast_output], int_output + "_cast", to=target_type)
          )

          # Product of the two input scales; an identically named Mul is reused if present.
          assert len(scales) == 2
          scales_mul_name = int_name + "_scales_mul" if int_name else scales[0] + "_" + scales[1] + "_mul"
          scales_mul_node = find_by_name(scales_mul_name, quantizer.new_nodes)
          if scales_mul_node is None:
              scales_mul_node = get_mul_node(scales, scales_mul_name + ":0", scales_mul_name)
              emitted.append(scales_mul_node)

          # Rescale the cast result; writing node.output[0] leaves downstream consumers untouched.
          final_mul_name = int_name + "_output_scale_mul" if int_name else ""
          emitted.append(get_mul_node([cast_output, scales_mul_node.output[0]], node.output[0], final_mul_name))

          quantizer.new_nodes += emitted
  97. """
  98. Used when quantize mode is QuantizationMode.QLinearOps
  99. """
  class QLinearMatMul(QOpMatMul):
      """Quantize a float MatMul into a single QLinearMatMul node.

      Used when the quantization mode is QuantizationMode.QLinearOps. The emitted
      node takes (A_q, A_scale, A_zp, B_q, B_scale, B_zp, Y_scale, Y_zp) and
      produces the quantized form of the original MatMul output.
      """

      def __init__(self, onnx_quantizer, onnx_node):
          super().__init__(onnx_quantizer, onnx_node)

      def quantize(self):
          """Emit a QLinearMatMul replacing this MatMul, or fall back to the base class.

          Falls back to ``super().quantize()`` (and returns its result) when the
          quantizer returned ``None`` instead of a name list for either input, or
          when no quantization parameters were found for the output tensor.
          """
          node = self.node
          assert node.op_type == "MatMul"

          # Quantize activation A (input 0) and weight B (input 1).
          (
              quantized_input_names,
              zero_point_names,
              scale_names,
              nodes,
          ) = self.quantizer.quantize_activation(node, [0])
          (
              quantized_input_names_weight,
              zero_point_names_weight,
              scale_names_weight,
              nodes_weight,
          ) = self.quantizer.quantize_weight(node, [1], reduce_range=True, op_level_per_channel=True)

          # The quantizer may hand back None instead of a name list (presumably when
          # an input has no quantization parameters -- TODO confirm). Check before
          # merging, since .extend() on / with None would raise.
          if quantized_input_names is None or quantized_input_names_weight is None:
              return super().quantize()

          quantized_input_names.extend(quantized_input_names_weight)
          zero_point_names.extend(zero_point_names_weight)
          scale_names.extend(scale_names_weight)
          nodes.extend(nodes_weight)

          (
              data_found,
              output_scale_name,
              output_zp_name,
              _,
              _,
          ) = self.quantizer._get_quantization_params(node.output[0])
          if not data_found:
              return super().quantize()

          qlinear_matmul_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
          qlinear_matmul_name = node.name + "_quant" if node.name else ""

          # Input order: A, A_scale, A_zp, B, B_scale, B_zp, Y_scale, Y_zp.
          qlinear_matmul_inputs = [
              quantized_input_names[0],
              scale_names[0],
              zero_point_names[0],
              quantized_input_names[1],
              scale_names[1],
              zero_point_names[1],
              output_scale_name,
              output_zp_name,
          ]

          # Float8 weight types use the com.microsoft-domain QLinearMatMul;
          # every other weight type uses the default ONNX domain.
          domain = (
              "com.microsoft"
              if self.quantizer.weight_qType
              in {
                  onnx_proto.TensorProto.FLOAT8E4M3FN,
                  onnx_proto.TensorProto.FLOAT8E4M3FNUZ,
                  onnx_proto.TensorProto.FLOAT8E5M2,
                  onnx_proto.TensorProto.FLOAT8E5M2FNUZ,
              }
              else ""
          )
          qlinear_matmul_node = onnx.helper.make_node(
              "QLinearMatMul",
              qlinear_matmul_inputs,
              [qlinear_matmul_output],
              qlinear_matmul_name,
              domain=domain,
          )
          nodes.append(qlinear_matmul_node)

          # Record the quantized output so later consumers can pick it up.
          q_output = QuantizedValue(
              node.output[0],
              qlinear_matmul_output,
              output_scale_name,
              output_zp_name,
              QuantizedValueType.Input,
          )
          self.quantizer.quantized_value_map[node.output[0]] = q_output
          self.quantizer.new_nodes += nodes
  class QDQMatMul(QDQOperatorBase):
      """QDQ-format quantizer for MatMul: requests quantization of each of its tensors."""

      def __init__(self, onnx_quantizer, onnx_node):
          super().__init__(onnx_quantizer, onnx_node)

      def quantize(self):
          """Ask the quantizer to quantize every input (and, unless disabled, output) tensor.

          Tensors found among the model initializers go through weight
          quantization (per-channel when ``is_tensor_per_channel`` says so, with
          default axis 1); every other tensor is quantized as an activation.
          """
          node = self.node
          assert node.op_type == "MatMul"

          tensor_names = (
              node.input if self.disable_qdq_for_node_output else itertools.chain(node.input, node.output)
          )

          for name in tensor_names:
              if not find_by_name(name, self.quantizer.model.initializer()):
                  self.quantizer.quantize_activation_tensor(name)
                  continue

              per_channel, axis = self.quantizer.is_tensor_per_channel(
                  name, default_axis=1, op_type=node.op_type
              )
              if per_channel:
                  self.quantizer.quantize_weight_tensor_per_channel(name, axis)
              else:
                  self.quantizer.quantize_weight_tensor(name)