quantize_helper.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. import os
  8. import onnx
  9. import torch
  10. from transformers.modeling_utils import Conv1D
  11. logger = logging.getLogger(__name__)
  12. def _conv1d_to_linear(module):
  13. in_size, out_size = module.weight.shape
  14. linear = torch.nn.Linear(in_size, out_size)
  15. linear.weight.data = module.weight.data.T.contiguous()
  16. linear.bias.data = module.bias.data
  17. return linear
  18. def conv1d_to_linear(model):
  19. """in-place
  20. This is for Dynamic Quantization, as Conv1D is not recognized by PyTorch, convert it to nn.Linear
  21. """
  22. logger.debug("replace Conv1D with Linear")
  23. for name in list(model._modules):
  24. module = model._modules[name]
  25. if isinstance(module, Conv1D):
  26. linear = _conv1d_to_linear(module)
  27. model._modules[name] = linear
  28. else:
  29. conv1d_to_linear(module)
  30. def _get_size_of_pytorch_model(model):
  31. torch.save(model.state_dict(), "temp.p")
  32. size = os.path.getsize("temp.p") / (1024 * 1024)
  33. os.remove("temp.p")
  34. return size
  35. class QuantizeHelper:
  36. @staticmethod
  37. def quantize_torch_model(model, dtype=torch.qint8):
  38. """
  39. Usage: model = quantize_model(model)
  40. TODO: mix of in-place and return, but results are different
  41. """
  42. conv1d_to_linear(model)
  43. quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=dtype)
  44. logger.info(f"Size of full precision Torch model(MB):{_get_size_of_pytorch_model(model)}")
  45. logger.info(f"Size of quantized Torch model(MB):{_get_size_of_pytorch_model(quantized_model)}")
  46. return quantized_model
  47. @staticmethod
  48. def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data_format=False):
  49. from pathlib import Path # noqa: PLC0415
  50. from onnxruntime.quantization import quantize_dynamic # noqa: PLC0415
  51. Path(quantized_model_path).parent.mkdir(parents=True, exist_ok=True)
  52. logger.info(f"Size of full precision ONNX model(MB):{os.path.getsize(onnx_model_path) / (1024 * 1024)}")
  53. quantize_dynamic(
  54. onnx_model_path,
  55. quantized_model_path,
  56. use_external_data_format=use_external_data_format,
  57. extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
  58. )
  59. logger.info(f"quantized model saved to:{quantized_model_path}")
  60. # TODO: inlcude external data in total model size.
  61. logger.info(f"Size of quantized ONNX model(MB):{os.path.getsize(quantized_model_path) / (1024 * 1024)}")