gemm.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import logging
  2. import numpy as np # noqa: F401
  3. import onnx
  4. from ..quant_utils import (
  5. TENSOR_NAME_QUANT_SUFFIX,
  6. QuantizedValue,
  7. QuantizedValueType,
  8. attribute_to_kwarg,
  9. find_by_name, # noqa: F401
  10. get_mul_node, # noqa: F401
  11. ms_domain,
  12. )
  13. from .base_operator import QuantOperatorBase # noqa: F401
  14. from .matmul import QOpMatMul
  15. from .qdq_base_operator import QDQOperatorBase
  16. def is_B_transposed(gemm_node): # noqa: N802
  17. transB_attribute = [attr for attr in gemm_node.attribute if attr.name == "transB"] # noqa: N806
  18. if transB_attribute:
  19. return onnx.helper.get_attribute_value(transB_attribute[0]) > 0
  20. return False
  21. def get_beta(gemm_node):
  22. beta_attribute = [attr for attr in gemm_node.attribute if attr.name == "beta"]
  23. if beta_attribute:
  24. return onnx.helper.get_attribute_value(beta_attribute[0])
  25. return 1.0
  26. def set_default_beta(gemm_node):
  27. beta_attribute = [attr for attr in gemm_node.attribute if attr.name == "beta"]
  28. if beta_attribute:
  29. beta_attribute[0].f = 1.0
  30. return 1.0
  31. class QLinearGemm(QOpMatMul):
  32. def __init__(self, onnx_quantizer, onnx_node):
  33. super().__init__(onnx_quantizer, onnx_node)
  34. def quantize(self):
  35. node = self.node
  36. assert node.op_type == "Gemm"
  37. (
  38. data_found,
  39. output_scale_name,
  40. output_zp_name,
  41. _,
  42. _,
  43. ) = self.quantizer._get_quantization_params(node.output[0])
  44. if self.quantizer.is_input_a_initializer(node.input[1]) and self.quantizer.is_per_channel():
  45. (
  46. quantized_input_names,
  47. zero_point_names,
  48. scale_names,
  49. nodes,
  50. ) = self.quantizer.quantize_activation(node, [0])
  51. quant_weight_tuple = self.quantizer.quantize_weight_per_channel(
  52. node.input[1],
  53. self.quantizer.weight_qType,
  54. 0 if is_B_transposed(node) else 1,
  55. )
  56. quantized_input_names.append(quant_weight_tuple[0])
  57. zero_point_names.append(quant_weight_tuple[1])
  58. scale_names.append(quant_weight_tuple[2])
  59. else:
  60. # Get Quantized from both activation(input[0]) and weight(input[1])
  61. (
  62. quantized_input_names,
  63. zero_point_names,
  64. scale_names,
  65. nodes,
  66. ) = self.quantizer.quantize_activation(node, [0])
  67. (
  68. quantized_input_names_weight,
  69. zero_point_names_weight,
  70. scale_names_weight,
  71. nodes_weight,
  72. ) = self.quantizer.quantize_weight(node, [1], reduce_range=self.quantizer.reduce_range)
  73. quantized_input_names.extend(quantized_input_names_weight)
  74. zero_point_names.extend(zero_point_names_weight)
  75. scale_names.extend(scale_names_weight)
  76. nodes.extend(nodes_weight)
  77. if not data_found or quantized_input_names is None:
  78. return super().quantize()
  79. quantized_bias_name = ""
  80. if len(node.input) == 3:
  81. if not self.quantizer.is_input_a_initializer(node.input[2]):
  82. return super().quantize()
  83. # Note: if the quantized type is float 8, the bias is converted into float 16.
  84. # cublasLtMatMul only supports (b)float16 or float32 bias.
  85. quantized_bias_name = self.quantizer.quantize_bias_static(
  86. node.input[2], node.input[0], node.input[1], get_beta(self.node)
  87. )
  88. qgemm_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
  89. qgemm_name = node.name + "_quant" if node.name else ""
  90. kwargs = {}
  91. for attribute in node.attribute:
  92. if attribute.name != "beta":
  93. kwargs.update(attribute_to_kwarg(attribute))
  94. kwargs["domain"] = ms_domain
  95. # generate input
  96. qgemm_inputs = []
  97. for i in range(2):
  98. qgemm_inputs.extend([quantized_input_names[i], scale_names[i], zero_point_names[i]])
  99. qgemm_inputs.extend([quantized_bias_name, output_scale_name, output_zp_name])
  100. qgemm_node = onnx.helper.make_node("QGemm", qgemm_inputs, [qgemm_output], qgemm_name, **kwargs)
  101. nodes.append(qgemm_node)
  102. # Create an entry for this quantized value
  103. q_output = QuantizedValue(
  104. node.output[0],
  105. qgemm_output,
  106. output_scale_name,
  107. output_zp_name,
  108. QuantizedValueType.Input,
  109. node_type=node.op_type,
  110. node_qtype=self.quantizer.weight_qType,
  111. )
  112. self.quantizer.quantized_value_map[node.output[0]] = q_output
  113. self.quantizer.new_nodes += nodes
  114. class QDQGemm(QDQOperatorBase):
  115. def __init__(self, onnx_quantizer, onnx_node):
  116. super().__init__(onnx_quantizer, onnx_node)
  117. def quantize(self):
  118. node = self.node
  119. assert node.op_type == "Gemm"
  120. self.quantizer.quantize_activation_tensor(node.input[0])
  121. if not self.disable_qdq_for_node_output:
  122. self.quantizer.quantize_activation_tensor(node.output[0])
  123. is_weight_per_channel, weight_axis = self.quantizer.is_tensor_per_channel(
  124. node.input[1], default_axis=0 if is_B_transposed(node) else 1
  125. )
  126. if is_weight_per_channel:
  127. self.quantizer.quantize_weight_tensor_per_channel(node.input[1], weight_axis)
  128. else:
  129. self.quantizer.quantize_weight_tensor(node.input[1])
  130. if len(node.input) == 3:
  131. if self.quantizer.is_input_a_initializer(node.input[2]):
  132. self.quantizer.quantize_bias_tensor(
  133. node.name, node.input[2], node.input[0], node.input[1], get_beta(self.node)
  134. )
  135. set_default_beta(self.node)
  136. else:
  137. logging.warning(
  138. f"Bias of Gemm node '{self.node.name}' is not constant. Please exclude this node for better performance."
  139. )