norm.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from .qdq_base_operator import QDQOperatorBase
  6. class QDQNormalization(QDQOperatorBase):
  7. def __init__(self, onnx_quantizer, onnx_node):
  8. super().__init__(onnx_quantizer, onnx_node)
  9. def quantize(self):
  10. node = self.node
  11. assert node.op_type in {"InstanceNormalization", "LayerNormalization", "BatchNormalization"}
  12. # Input
  13. self.quantizer.quantize_activation_tensor(node.input[0])
  14. # Scale
  15. scale_is_initializer = self.quantizer.is_input_a_initializer(node.input[1])
  16. scale_is_per_channel, scale_channel_axis = self.quantizer.is_tensor_per_channel(
  17. node.input[1], default_axis=1, op_type=node.op_type
  18. )
  19. if scale_is_per_channel:
  20. self.quantizer.quantize_weight_tensor_per_channel(node.input[1], axis=scale_channel_axis)
  21. elif scale_is_initializer:
  22. self.quantizer.quantize_weight_tensor(node.input[1])
  23. else:
  24. self.quantizer.quantize_activation_tensor(node.input[1])
  25. # Bias
  26. if len(node.input) > 2 and node.input[2]:
  27. self.quantizer.quantize_bias_tensor(node.name, node.input[2], node.input[0], node.input[1])
  28. # Output
  29. if not self.disable_qdq_for_node_output:
  30. for output_name in node.output:
  31. self.quantizer.quantize_activation_tensor(output_name)