registry.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from .operators.activation import QDQRemovableActivation, QLinearActivation
  2. from .operators.argmax import QArgMax
  3. from .operators.attention import AttentionQuant
  4. from .operators.base_operator import QuantOperatorBase
  5. from .operators.binary_op import QLinearBinaryOp
  6. from .operators.concat import QLinearConcat
  7. from .operators.conv import ConvInteger, QDQConv, QLinearConv
  8. from .operators.direct_q8 import Direct8BitOp, QDQDirect8BitOp
  9. from .operators.embed_layernorm import EmbedLayerNormalizationQuant
  10. from .operators.gather import GatherQuant, QDQGather
  11. from .operators.gavgpool import QGlobalAveragePool
  12. from .operators.gemm import QDQGemm, QLinearGemm
  13. from .operators.lstm import LSTMQuant
  14. from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul
  15. from .operators.maxpool import QDQMaxPool, QMaxPool
  16. from .operators.norm import QDQNormalization
  17. from .operators.pad import QDQPad, QPad
  18. from .operators.pooling import QLinearPool
  19. from .operators.qdq_base_operator import QDQOperatorBase
  20. from .operators.resize import QDQResize, QResize
  21. from .operators.softmax import QLinearSoftmax
  22. from .operators.split import QDQSplit, QSplit
  23. from .operators.where import QDQWhere, QLinearWhere
  24. from .quant_utils import QuantizationMode
  25. CommonOpsRegistry = {
  26. "Gather": GatherQuant,
  27. "Transpose": Direct8BitOp,
  28. "EmbedLayerNormalization": EmbedLayerNormalizationQuant,
  29. }
  30. IntegerOpsRegistry = {
  31. "Conv": ConvInteger,
  32. "MatMul": MatMulInteger,
  33. "Attention": AttentionQuant,
  34. "LSTM": LSTMQuant,
  35. }
  36. IntegerOpsRegistry.update(CommonOpsRegistry)
  37. QLinearOpsRegistry = {
  38. "ArgMax": QArgMax,
  39. "Conv": QLinearConv,
  40. "Gemm": QLinearGemm,
  41. "MatMul": QLinearMatMul,
  42. "Add": QLinearBinaryOp,
  43. "Mul": QLinearBinaryOp,
  44. "Relu": QLinearActivation,
  45. "Clip": QLinearActivation,
  46. "LeakyRelu": QLinearActivation,
  47. "Sigmoid": QLinearActivation,
  48. "MaxPool": QMaxPool,
  49. "GlobalAveragePool": QGlobalAveragePool,
  50. "Split": QSplit,
  51. "Pad": QPad,
  52. "Reshape": Direct8BitOp,
  53. "Squeeze": Direct8BitOp,
  54. "Unsqueeze": Direct8BitOp,
  55. "Resize": QResize,
  56. "AveragePool": QLinearPool,
  57. "Concat": QLinearConcat,
  58. "Softmax": QLinearSoftmax,
  59. "Where": QLinearWhere,
  60. }
  61. QLinearOpsRegistry.update(CommonOpsRegistry)
  62. QDQRegistry = {
  63. "Conv": QDQConv,
  64. "ConvTranspose": QDQConv,
  65. "Gemm": QDQGemm,
  66. "Clip": QDQRemovableActivation,
  67. "Relu": QDQRemovableActivation,
  68. "Reshape": QDQDirect8BitOp,
  69. "Transpose": QDQDirect8BitOp,
  70. "Squeeze": QDQDirect8BitOp,
  71. "Unsqueeze": QDQDirect8BitOp,
  72. "Resize": QDQResize,
  73. "MaxPool": QDQMaxPool,
  74. "AveragePool": QDQDirect8BitOp,
  75. "Slice": QDQDirect8BitOp,
  76. "Pad": QDQPad,
  77. "MatMul": QDQMatMul,
  78. "Split": QDQSplit,
  79. "Gather": QDQGather,
  80. "GatherElements": QDQGather,
  81. "Where": QDQWhere,
  82. "InstanceNormalization": QDQNormalization,
  83. "LayerNormalization": QDQNormalization,
  84. "BatchNormalization": QDQNormalization,
  85. "TopK": QDQDirect8BitOp,
  86. "CumSum": QDQOperatorBase,
  87. }
  88. def CreateDefaultOpQuantizer(onnx_quantizer, node): # noqa: N802
  89. return QuantOperatorBase(onnx_quantizer, node)
  90. def CreateOpQuantizer(onnx_quantizer, node): # noqa: N802
  91. registry = IntegerOpsRegistry if onnx_quantizer.mode == QuantizationMode.IntegerOps else QLinearOpsRegistry
  92. if node.op_type in registry:
  93. op_quantizer = registry[node.op_type](onnx_quantizer, node)
  94. if op_quantizer.should_quantize():
  95. return op_quantizer
  96. return QuantOperatorBase(onnx_quantizer, node)
  97. def CreateQDQQuantizer(onnx_quantizer, node): # noqa: N802
  98. if node.op_type in QDQRegistry:
  99. return QDQRegistry[node.op_type](onnx_quantizer, node)
  100. return QDQOperatorBase(onnx_quantizer, node)