activation.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import onnx
  2. from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
  3. from .base_operator import QuantOperatorBase
  4. from .qdq_base_operator import QDQOperatorBase
  class QLinearActivation(QuantOperatorBase):
      """QLinearOps-mode quantizer for activation nodes.

      Relu and Clip are folded away: their output tensor is mapped to the same QuantizedValue as
      their input. Any other activation routed here is replaced by a ``"QLinear" + op_type`` node
      in ``ms_domain`` that consumes the quantized input plus the input/output scales and zero
      points.
      """

      def __init__(self, onnx_quantizer, onnx_node):
          super().__init__(onnx_quantizer, onnx_node)

      def QuantizeClipRelu(self):  # noqa: N802
          """Fold a Relu/Clip node into the quantized value of its input.

          In QLinearOps mode the output quantization params are calculated from the outputs of
          activation nodes, so a Relu/Clip that follows a quantized op can be dropped: its output
          simply reuses the input's QuantizedValue. The node is instead handed to the base-class
          ``quantize`` (i.e. kept) when its input is not in ``quantized_value_map`` or when
          activations are quantized symmetrically.
          """
          node = self.node
          assert node.op_type in ("Relu", "Clip")

          value_map = self.quantizer.quantized_value_map
          if node.input[0] not in value_map or self.quantizer.is_activation_symmetric:
              return super().quantize()

          # Alias the output to the input's quantized value; no new node is emitted.
          value_map[node.output[0]] = value_map[node.input[0]]

      def quantize(self):
          """Emit the QLinear counterpart of this activation, or fall back to the base class.

          Relu/Clip are delegated to ``QuantizeClipRelu``. For every other op the node is handed
          to the base-class ``quantize`` when no quantization parameters were found for its output
          or when its input could not be quantized. On success the new node (plus any nodes
          produced while quantizing the input) is appended to ``quantizer.new_nodes`` and the
          quantized output is recorded in ``quantized_value_map``.
          """
          node = self.node
          if node.op_type in ("Relu", "Clip"):
              self.QuantizeClipRelu()
              return

          # Sigmoid with the "extra.Sigmoid.nnapi" extra option enabled passes scale 1/256 and
          # zero point 0 to _get_quantization_params (presumably pinning the output params);
          # every other case passes None for both.
          option_key = "extra.Sigmoid.nnapi"
          sigmoid_nnapi_mode = node.op_type == "Sigmoid" and (
              option_key in self.quantizer.extra_options and self.quantizer.extra_options[option_key]
          )
          if sigmoid_nnapi_mode:
              use_scale, use_zeropoint = 1 / 256.0, 0
          else:
              use_scale, use_zeropoint = None, None

          # No assert on op_type: which ops reach this point is controlled by the registry.
          # Only quantize when quantization parameters are available for the output.
          data_found, output_scale_name, output_zp_name, _, _ = self.quantizer._get_quantization_params(
              node.output[0], use_scale, use_zeropoint
          )
          quantized_input_names, zero_point_names, scale_names, nodes = self.quantizer.quantize_activation(node, [0])
          if not data_found or quantized_input_names is None:
              return super().quantize()

          qlinear_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
          qlinear_name = node.name + "_quant" if node.name else ""

          # Carry the original node's attributes over, then place the op in the MS domain.
          kwargs = {}
          for attr in node.attribute:
              kwargs.update(attribute_to_kwarg(attr))
          kwargs["domain"] = ms_domain

          qlinear_node = onnx.helper.make_node(
              "QLinear" + node.op_type,
              [quantized_input_names[0], scale_names[0], zero_point_names[0], output_scale_name, output_zp_name],
              [qlinear_output],
              qlinear_name,
              **kwargs,
          )

          # Record the quantized output so later lookups of node.output[0] find it.
          self.quantizer.quantized_value_map[node.output[0]] = QuantizedValue(
              node.output[0],
              qlinear_output,
              output_scale_name,
              output_zp_name,
              QuantizedValueType.Input,
          )
          nodes.append(qlinear_node)
          self.quantizer.new_nodes += nodes
  class QDQRemovableActivation(QDQOperatorBase):
      """QDQ-format handler for activation nodes that can, under some conditions, be removed."""

      def __init__(self, onnx_quantizer, onnx_node):
          super().__init__(onnx_quantizer, onnx_node)

      def quantize(self):
          """Remove the activation when allowed, otherwise mark its input for quantization.

          Does nothing if the node's input tensor is not quantized. The node is removed only when
          activations are asymmetric, ``qdq_keep_removable_activations`` is off, and
          ``try_replacing_upstream_output`` succeeds; otherwise the input tensor is passed to
          ``quantize_activation_tensor``. In both cases the output tensor is then passed to
          ``quantize_activation_tensor`` unless ``disable_qdq_for_node_output`` is set.
          """
          node = self.node
          quantizer = self.quantizer
          input_name = node.input[0]

          # An unquantized input means this node stays as it is.
          if not quantizer.is_tensor_quantized(input_name):
              return

          may_remove = not (quantizer.is_activation_symmetric or quantizer.qdq_keep_removable_activations)
          if may_remove and quantizer.try_replacing_upstream_output(input_name, node.output[0]):
              quantizer.remove_node(node)
          else:
              quantizer.quantize_activation_tensor(input_name)

          if not self.disable_qdq_for_node_output:
              quantizer.quantize_activation_tensor(node.output[0])