matmul_bnb4_quantizer.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import argparse
  7. import logging
  8. import os
  9. import numpy as np
  10. import numpy.typing as npt
  11. import onnx
  12. from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto
  13. from onnxruntime.capi._pybind_state import quantize_matmul_bnb4
  14. from .onnx_model import ONNXModel
  15. from .quant_utils import attribute_to_kwarg
  16. logger = logging.getLogger(__name__)
  17. class MatMulBnb4Quantizer:
  18. """Perform 4b quantization of constant MatMul weights using FP4 or NF4 data type"""
  19. ##################
  20. # quantization types, must be consistent with native code type
  21. # Bnb_DataType_t defined in blockwise_quant_block_bnb4.h
  22. # 4b floating point with bias of 3
  23. FP4 = 0
  24. # 4b NormalFloat
  25. NF4 = 1
  26. def __init__(self, model: ModelProto, quant_type: int, block_size: int, nodes_to_exclude=None):
  27. nodes_to_exclude = nodes_to_exclude or []
  28. assert quant_type in [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4]
  29. self.model = ONNXModel(model)
  30. self.quant_type = quant_type
  31. self.block_size = block_size
  32. self.nodes_to_exclude = set(nodes_to_exclude)
  33. @staticmethod
  34. def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
  35. for gid in range(len(graph_path) - 1, -1, -1):
  36. graph = graph_path[gid]
  37. for tensor in graph.initializer:
  38. if tensor.name == name:
  39. return tensor, graph
  40. return None, None
  41. def bnb4_block_quant(self, fpweight: npt.ArrayLike) -> np.ndarray:
  42. """4b quantize fp32/fp16 weight"""
  43. if len(fpweight.shape) != 2:
  44. raise ValueError("Current bnb4 block quantization only supports 2D tensors!")
  45. # need to copy since the transposed weight still has the original memory layout
  46. # Linear4bit quantizes its weight data which is the transposed weight
  47. fpweight_t = fpweight.transpose().copy()
  48. rows, cols = fpweight.shape
  49. numel = rows * cols
  50. block_size = self.block_size
  51. num_blocks = (numel + block_size - 1) // block_size
  52. quantized_numel = (numel + 1) // 2
  53. packed = np.zeros(quantized_numel, dtype="uint8")
  54. absmax = np.zeros(num_blocks, dtype=fpweight.dtype)
  55. # block wise quantization, fpweight_t is flattened and divided into blocks
  56. quantize_matmul_bnb4(packed, fpweight_t, absmax, block_size, self.quant_type, cols, rows)
  57. return (packed, absmax)
  58. def _bnb4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto:
  59. """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
  60. if node.op_type != "MatMul":
  61. return node # only care about MatMul for now
  62. logger.debug(f"start to quantize {node.name} ...")
  63. if node.name in self.nodes_to_exclude:
  64. logger.debug(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
  65. return node
  66. inputB = node.input[1] # noqa: N806
  67. B, Bs_graph = MatMulBnb4Quantizer.__get_initializer(inputB, graph_stack) # noqa: N806
  68. if B is None:
  69. logger.debug("MatMul doesn't have const weight. Skip to quantize")
  70. return node # only care about constant weight
  71. B_array = onnx.numpy_helper.to_array(B) # noqa: N806
  72. if len(B_array.shape) != 2:
  73. logger.debug("MatMul weight is not 2D. Skip to quantize")
  74. return node # can only process 2-D matrix
  75. packed, absmax = self.bnb4_block_quant(B_array)
  76. B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806
  77. B_quant.name = B.name + "_Bnb4"
  78. for input in Bs_graph.input:
  79. if input.name == inputB:
  80. Bs_graph.input.remove(input)
  81. break
  82. absmax_tensor = onnx.numpy_helper.from_array(absmax)
  83. absmax_tensor.name = B.name + "_absmax"
  84. Bs_graph.initializer.extend([B_quant, absmax_tensor])
  85. kwargs = {}
  86. rows, cols = B_array.shape
  87. kwargs["K"] = rows
  88. kwargs["N"] = cols
  89. kwargs["block_size"] = self.block_size
  90. kwargs["quant_type"] = self.quant_type
  91. matmul_bnb4_node = onnx.helper.make_node(
  92. "MatMulBnb4",
  93. inputs=[node.input[0], B_quant.name, absmax_tensor.name],
  94. outputs=[node.output[0]],
  95. name=node.name + "_Bnb4" if node.name else "",
  96. domain="com.microsoft",
  97. **kwargs,
  98. )
  99. logger.debug(f"complete quantization of {node.name} ...")
  100. return matmul_bnb4_node
  101. def _process_subgraph(self, graph_stack: list[GraphProto]):
  102. new_nodes = []
  103. graph = graph_stack[-1]
  104. for node in graph.node:
  105. graph_attrs = [
  106. attr
  107. for attr in node.attribute
  108. if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
  109. ]
  110. if graph_attrs:
  111. kwargs = {}
  112. for attr in node.attribute:
  113. if attr.type == onnx.AttributeProto.GRAPH:
  114. # recursive call to take care of sub-graph
  115. graph_stack.append(attr.g)
  116. kv = {attr.name: self._process_subgraph(graph_stack)}
  117. elif attr.type == onnx.AttributeProto.GRAPHS:
  118. value = []
  119. for subgraph in attr.graphs:
  120. # recursive call to take care of sub-graph
  121. graph_stack.append(subgraph)
  122. value.extend([self._process_subgraph(graph_stack)])
  123. kv = {attr.name: value}
  124. else:
  125. kv = attribute_to_kwarg(attr)
  126. kwargs.update(kv)
  127. node = onnx.helper.make_node( # noqa: PLW2901
  128. node.op_type, node.input, node.output, name=node.name, **kwargs
  129. )
  130. new_nodes.append(self._bnb4_matmul_node_weight(node, graph_stack))
  131. graph.ClearField("node")
  132. graph.node.extend(new_nodes)
  133. graph_stack.pop()
  134. return graph
  135. def process(self):
  136. # use a stack to keep track of sub-graphs
  137. graph_stack = [self.model.graph()]
  138. opset_import = self.model.opset_import()
  139. has_ms_domain = False
  140. for opset in opset_import:
  141. if opset.domain == "com.microsoft":
  142. has_ms_domain = True
  143. if not has_ms_domain:
  144. opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])
  145. self._process_subgraph(graph_stack)
  146. self.model.clean_initializers()
  147. def parse_args():
  148. parser = argparse.ArgumentParser(
  149. description="""Blockwise FP4/NF4 quantization for MatMul 2D weight matrices.
  150. A weight matrix is partitioned into blocks, where each block is a contiguous
  151. subset inside the flattened transposed weight matrix. Each block is quantized
  152. into a set of 4b integers with an absolute value scaling factor.
  153. """
  154. )
  155. parser.add_argument("--input_model", required=True, help="Path to the input model file")
  156. parser.add_argument("--output_model", required=True, help="Path to the output model file")
  157. parser.add_argument(
  158. "--quant_type",
  159. required=False,
  160. default=1,
  161. choices=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4],
  162. help="Quantization data type. 0: FP4, 1: NF4",
  163. )
  164. parser.add_argument(
  165. "--block_size",
  166. required=False,
  167. default=64,
  168. help="Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64",
  169. )
  170. parser.add_argument("-v", "--verbose", required=False, action="store_true")
  171. parser.set_defaults(verbose=False)
  172. parser.add_argument(
  173. "--nodes_to_exclude",
  174. nargs="+",
  175. type=str,
  176. required=False,
  177. default=[],
  178. help="Specify the nodes to be excluded from quantization with node names",
  179. )
  180. return parser.parse_args()
  181. if __name__ == "__main__":
  182. args = parse_args()
  183. if args.verbose:
  184. logger.setLevel(logging.DEBUG)
  185. input_model_path = args.input_model
  186. output_model_path = args.output_model
  187. if os.path.exists(output_model_path):
  188. logger.error(f"file {output_model_path} already exists")
  189. raise Exception(f"file {output_model_path} already exists")
  190. model = onnx.load(input_model_path)
  191. quant = MatMulBnb4Quantizer(model, args.quant_type, args.block_size, nodes_to_exclude=args.nodes_to_exclude)
  192. quant.process()
  193. quant.model.save_model_to_file(output_model_path, True)