maxpool.py 927 B

12345678910111213141516171819202122232425262728293031323334
  1. from .direct_q8 import Direct8BitOp, QDQDirect8BitOp
  2. class QMaxPool(Direct8BitOp):
  3. def __init__(self, onnx_quantizer, onnx_node):
  4. super().__init__(onnx_quantizer, onnx_node)
  5. def quantize(self):
  6. node = self.node
  7. assert node.op_type == "MaxPool"
  8. # if version is less than 12, go to normal quantize.
  9. if self.quantizer.opset_version < 12:
  10. super(Direct8BitOp, self).quantize()
  11. return
  12. # Direct 8bits op
  13. return super().quantize()
  14. class QDQMaxPool(QDQDirect8BitOp):
  15. def __init__(self, onnx_quantizer, onnx_node):
  16. super().__init__(onnx_quantizer, onnx_node)
  17. def quantize(self):
  18. node = self.node
  19. assert node.op_type == "MaxPool"
  20. # if version is less than 12, just no change
  21. if self.quantizer.opset_version < 12:
  22. return
  23. # Direct 8bits op
  24. return super().quantize()