matmul_nbits_quantizer.py 66 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. from __future__ import annotations
  7. import argparse
  8. import copy
  9. import logging
  10. import os
  11. import ml_dtypes
  12. import numpy as np
  13. import numpy.typing as npt
  14. import onnx
  15. import onnx_ir as ir
  16. from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto
  17. from onnxruntime.capi._pybind_state import (
  18. quantize_matmul_2bits,
  19. quantize_matmul_4bits,
  20. quantize_matmul_8bits,
  21. quantize_qdq_matmul_4bits,
  22. )
  23. from .calibrate import CalibrationDataReader
  24. from .neural_compressor import gptq_quantize, rtn_quantize
  25. from .onnx_model import ONNXModel
  26. from .quant_utils import QuantFormat, attribute_to_kwarg
  27. logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.INFO)
  28. logger = logging.getLogger(__name__)
  29. class WeightOnlyQuantConfig:
  30. def __init__(
  31. self,
  32. algorithm: str,
  33. quant_format: QuantFormat,
  34. op_types_to_quantize: tuple[str, ...] | None = None,
  35. quant_axes: tuple[tuple[str, int], ...] | None = None,
  36. customized_weight_config: dict | None = None,
  37. ):
  38. """This is the Base class for Weight Only blockwise quantization Configuration.
  39. Args:
  40. algorithm:
  41. weight only quantize algorithm name.
  42. quant_format: QuantFormat{QOperator, QDQ}.
  43. QOperator format quantizes the model with quantized operators directly.
  44. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
  45. op_types_to_quantize (optional):
  46. set of operator types to quantize. Default {MatMul}
  47. quant_axes (dict[str, int], optional):
  48. op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
  49. customized_weight_config:
  50. customized weight config for nodes if needed. It is dictionary with node name as key,
  51. and the value is a dict of customized config.
  52. """
  53. self.algorithm = algorithm
  54. self.quant_format = quant_format
  55. self.op_types_to_quantize = set(op_types_to_quantize) if op_types_to_quantize else {"MatMul"}
  56. self.quant_axes = dict(quant_axes) if quant_axes else {"MatMul": 0, "Gather": 1}
  57. self.customized_weight_config = customized_weight_config
  58. class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig):
  59. def __init__(
  60. self,
  61. ratios=None,
  62. quant_format=QuantFormat.QOperator,
  63. op_types_to_quantize: tuple[str, ...] | None = None,
  64. customized_weight_config: dict | None = None,
  65. ):
  66. """
  67. This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration.
  68. RTN is the most straightforward way to quantize weight using scale maps.
  69. Args:
  70. ratios:
  71. percentile of clip. Defaults to {}.
  72. quant_format (QuantFormat{QOperator, QDQ}, optional):
  73. QOperator format quantizes the model with quantized operators directly.
  74. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
  75. Defaults to QuantFormat.QOperator.
  76. op_types_to_quantize (optional):
  77. set of operator types to quantize.
  78. customized_weight_config:
  79. customized weight config for nodes if needed. It is dictionary with node name as key,
  80. and the value is a dict of customized config.
  81. """
  82. assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format"
  83. if ratios is None:
  84. ratios = {}
  85. super().__init__(
  86. algorithm="RTN",
  87. quant_format=quant_format,
  88. op_types_to_quantize=op_types_to_quantize,
  89. customized_weight_config=customized_weight_config,
  90. )
  91. self.ratios = ratios
  92. class KQuantWeightOnlyQuantConfig(WeightOnlyQuantConfig):
  93. def __init__(
  94. self,
  95. ratios=None,
  96. quant_format=QuantFormat.QOperator,
  97. op_types_to_quantize: tuple[str, ...] | None = None,
  98. customized_weight_config: dict | None = None,
  99. ):
  100. """
  101. This is a class for k-quant algorithm Weight Only Quant Configuration.
  102. Args:
  103. ratios:
  104. percentile of clip. Defaults to {}.
  105. quant_format (QuantFormat{QOperator, QDQ}, optional):
  106. QOperator format quantizes the model with quantized operators directly.
  107. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
  108. Defaults to QuantFormat.QOperator.
  109. op_types_to_quantize (optional):
  110. set of operator types to quantize.
  111. """
  112. assert quant_format == QuantFormat.QOperator, "k-quant only supports QOperator format"
  113. if ratios is None:
  114. ratios = {}
  115. super().__init__(
  116. algorithm="k_quant",
  117. quant_format=quant_format,
  118. op_types_to_quantize=op_types_to_quantize,
  119. customized_weight_config=customized_weight_config,
  120. )
  121. self.ratios = ratios
  122. class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
  123. def __init__(
  124. self,
  125. calibration_data_reader: CalibrationDataReader | None = None,
  126. percdamp=0.01,
  127. block_size=128,
  128. actorder=False,
  129. mse=False,
  130. perchannel=True,
  131. quant_format=QuantFormat.QOperator,
  132. op_types_to_quantize: tuple[str, ...] | None = None,
  133. ):
  134. """
  135. This is a class for GPTQ algorithm Weight Only Quant Configuration.
  136. GPTQ algorithm provides more accurate quantization but requires more computational resources.
  137. Args:
  138. calibration_data_reader:
  139. a calibration data reader. It enumerates calibration data and generates inputs for the original model.
  140. percdamp:
  141. percent of the average Hessian diagonal to use for dampening.
  142. block_size (int, optional):
  143. channel number in one block to execute a GPTQ quantization iteration.
  144. actorder (bool, optional):
  145. whether rearrange Hessian matrix considering the diag's value.
  146. mse (bool, optional):
  147. whether get scale and zero point with mse error.
  148. perchannel (bool, optional):
  149. whether quantize weight per-channel.
  150. quant_format (QuantFormat{QOperator, QDQ}, optional):
  151. QOperator format quantizes the model with quantized operators directly.
  152. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
  153. Defaults to QuantFormat.QOperator.
  154. op_types_to_quantize (optional):
  155. set of operator types to quantize.
  156. """
  157. assert quant_format == QuantFormat.QOperator, "GPTQ only supports QOperator format"
  158. super().__init__(
  159. algorithm="GPTQ",
  160. quant_format=quant_format,
  161. op_types_to_quantize=op_types_to_quantize,
  162. )
  163. self.calibration_data_reader = calibration_data_reader
  164. self.percdamp = percdamp
  165. self.block_size = block_size
  166. self.actorder = actorder
  167. self.mse = mse
  168. self.perchannel = perchannel
  169. class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
  170. def __init__(
  171. self,
  172. block_size=128,
  173. bits=4,
  174. axis=1,
  175. quant_format=QuantFormat.QOperator,
  176. op_types_to_quantize: tuple[str, ...] | None = None,
  177. quant_axes: tuple[tuple[str, int], ...] | None = None,
  178. ):
  179. """
  180. This is a class for HQQ algorithm Weight Only Quant Configuration.
  181. HQQ algorithm quant weight without needing calibrate data.
  182. Args:
  183. block_size (int, optional):
  184. channel number in one block to execute a HQQ quantization iteration.
  185. bits (int, optional):
  186. how many bits to represent weight.
  187. axis (int, optional):
  188. 0 or 1. which axis to quantize. https://arxiv.org/pdf/2309.15531.pdf
  189. quant_format (QuantFormat{QOperator, QDQ}, optional):
  190. QOperator format quantizes the model with quantized operators directly.
  191. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
  192. Defaults to QuantFormat.QOperator.
  193. op_types_to_quantize (optional):
  194. set of operator types to quantize.
  195. quant_axes (dict[str, int], optional):
  196. op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
  197. """
  198. assert quant_format == QuantFormat.QOperator, "HQQ only supports QOperator format"
  199. super().__init__(
  200. algorithm="HQQ",
  201. quant_format=quant_format,
  202. op_types_to_quantize=op_types_to_quantize,
  203. quant_axes=quant_axes,
  204. )
  205. self.block_size = block_size
  206. self.bits = bits
  207. self.axis = axis
  208. class DefaultWeightOnlyQuantConfig(WeightOnlyQuantConfig):
  209. def __init__(
  210. self,
  211. block_size: int = 128,
  212. is_symmetric: bool = False,
  213. accuracy_level: int | None = None,
  214. quant_format=QuantFormat.QOperator,
  215. op_types_to_quantize: tuple[str, ...] | None = None,
  216. quant_axes: tuple[tuple[str, int], ...] | None = None,
  217. bits: int = 4,
  218. channel_wised_quantize: bool = False,
  219. ):
  220. """
  221. This is a class for weight only affine quantization configuration.
  222. Args:
  223. block_size (int, optional):
  224. channel number in one block to execute an affine quantization iteration.
  225. is_symmetric (bool, optional):
  226. whether quantize weight symmetrically.
  227. accuracy_level (int, optional):
  228. Accuracy level of the 4-bit quantized MatMul computation.
  229. Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details.
  230. (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits)
  231. quant_format (QuantFormat{QOperator, QDQ}, optional):
  232. QOperator format quantizes the model with quantized operators directly.
  233. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
  234. Defaults to QuantFormat.QOperator.
  235. op_types_to_quantize (optional):
  236. set of operator types to quantize.
  237. quant_axes (dict[str, int], optional):
  238. op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
  239. bits (int, optional):
  240. number of bits per element after quantization. Default 4.
  241. """
  242. super().__init__(
  243. algorithm="DEFAULT",
  244. quant_format=quant_format,
  245. op_types_to_quantize=op_types_to_quantize,
  246. quant_axes=quant_axes,
  247. )
  248. self.block_size = block_size
  249. self.is_symmetric = is_symmetric
  250. self.bits = bits
  251. self.accuracy_level = accuracy_level
  252. self.channel_wised_quantize = channel_wised_quantize
  253. if channel_wised_quantize and quant_format == QuantFormat.QOperator:
  254. raise NotImplementedError("QuantFormat.QOperator is not supported channel_wised_quantize yet")
  255. class NVAWQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
  256. def __init__(
  257. self,
  258. tokenizer_dir,
  259. dataset_name="cnn",
  260. cache_dir="./cache",
  261. calibration_method="awq_lite",
  262. ):
  263. """
  264. Configuration for the nvidia_awq quantization method.
  265. Args:
  266. tokenizer_dir (str): pathof the tokenizer dir.
  267. dataset_name (str): Name of the dataset.
  268. cache_dir (str): Directory for caching.
  269. calibration_method (str): calib method for nvidia_awq.
  270. """
  271. # Import torch and DataLoader
  272. try:
  273. import torch # noqa: PLC0415
  274. from torch.utils.data import DataLoader # noqa: PLC0415
  275. self.torch = torch
  276. self.DataLoader = DataLoader
  277. except ImportError:
  278. print(
  279. "Error: The 'torch' library is required but not installed. Please install it using 'pip install torch'."
  280. )
  281. raise ImportError("torch is not installed. Exiting.") from None
  282. # Import datasets
  283. try:
  284. from datasets import load_dataset # noqa: PLC0415
  285. self.load_dataset = load_dataset
  286. except ImportError:
  287. print(
  288. "Error: The 'datasets' library is required but not installed. Please install it using 'pip install datasets'."
  289. )
  290. raise ImportError("datasets is not installed. Exiting.") from None
  291. # Import transformers
  292. try:
  293. from transformers import AutoConfig, AutoTokenizer # noqa: PLC0415
  294. self.AutoConfig = AutoConfig
  295. self.AutoTokenizer = AutoTokenizer
  296. except ImportError:
  297. print(
  298. "Error: The 'transformers' library is required but not installed. Please install it using 'pip install transformers'."
  299. )
  300. raise ImportError("transformers is not installed. Exiting.") from None
  301. super().__init__(
  302. algorithm="nvidia_awq",
  303. quant_format=QuantFormat.QDQ,
  304. op_types_to_quantize=None, # Assuming op_types_to_quantize is handled elsewhere
  305. quant_axes=None, # Assuming quant_axes is handled elsewhere
  306. )
  307. # Determine the device
  308. device = self.torch.device("cuda" if self.torch.cuda.is_available() else "cpu")
  309. calib_inputs = self.get_calib_inputs(
  310. dataset_name=dataset_name,
  311. model_name=tokenizer_dir,
  312. cache_dir=cache_dir,
  313. calib_size=32,
  314. batch_size=1,
  315. block_size=512,
  316. device=device,
  317. use_fp16=True,
  318. use_buffer_share=False,
  319. add_past_kv_inputs=True,
  320. max_calib_rows_to_load=128,
  321. add_position_ids=True,
  322. )
  323. self.calibration_data_reader = calib_inputs
  324. self.calibration_method = calibration_method
  325. def make_model_input(
  326. self,
  327. config,
  328. input_ids_arg,
  329. attention_mask_arg,
  330. add_past_kv_inputs,
  331. device,
  332. use_fp16,
  333. use_buffer_share,
  334. add_position_ids,
  335. ):
  336. # Access torch from the instance variable
  337. torch = self.torch
  338. input_ids = input_ids_arg
  339. attention_mask = attention_mask_arg
  340. if isinstance(input_ids_arg, list):
  341. input_ids = torch.tensor(input_ids_arg, device=device, dtype=torch.int64)
  342. attention_mask = torch.tensor(attention_mask_arg, device=device, dtype=torch.int64)
  343. inputs = {
  344. "input_ids": input_ids.contiguous(),
  345. "attention_mask": attention_mask.contiguous(),
  346. }
  347. if add_position_ids:
  348. position_ids = attention_mask.long().cumsum(-1) - 1
  349. position_ids.masked_fill_(attention_mask == 0, 1)
  350. inputs["position_ids"] = position_ids.contiguous()
  351. if add_past_kv_inputs:
  352. torch_dtype = torch.float16 if use_fp16 else torch.float32
  353. batch_size, sequence_length = input_ids.shape
  354. max_sequence_length = config.max_position_embeddings
  355. num_heads, head_size = (
  356. config.num_key_value_heads,
  357. config.hidden_size // config.num_attention_heads,
  358. )
  359. for i in range(config.num_hidden_layers):
  360. past_key = torch.zeros(
  361. batch_size,
  362. num_heads,
  363. max_sequence_length if use_buffer_share else 0,
  364. head_size,
  365. device=device,
  366. dtype=torch_dtype,
  367. )
  368. past_value = torch.zeros(
  369. batch_size,
  370. num_heads,
  371. max_sequence_length if use_buffer_share else 0,
  372. head_size,
  373. device=device,
  374. dtype=torch_dtype,
  375. )
  376. inputs.update(
  377. {
  378. f"past_key_values.{i}.key": past_key.contiguous(),
  379. f"past_key_values.{i}.value": past_value.contiguous(),
  380. }
  381. )
  382. return inputs
  383. def get_calib_inputs(
  384. self,
  385. dataset_name,
  386. model_name,
  387. cache_dir,
  388. calib_size,
  389. batch_size,
  390. block_size,
  391. device,
  392. use_fp16,
  393. use_buffer_share,
  394. add_past_kv_inputs,
  395. max_calib_rows_to_load,
  396. add_position_ids,
  397. ):
  398. # Access transformers and datasets from the instance variables
  399. auto_config = self.AutoConfig
  400. auto_tokenizer = self.AutoTokenizer
  401. load_dataset = self.load_dataset
  402. config = auto_config.from_pretrained(
  403. model_name, use_auth_token=True, cache_dir=cache_dir, trust_remote_code=True
  404. )
  405. tokenizer = auto_tokenizer.from_pretrained(
  406. model_name, use_auth_token=True, cache_dir=cache_dir, trust_remote_code=True
  407. )
  408. tokenizer.add_special_tokens({"pad_token": "[PAD]"})
  409. tokenizer.pad_token = tokenizer.eos_token
  410. assert calib_size <= max_calib_rows_to_load, "calib size should be no more than max_calib_rows_to_load"
  411. if "cnn" in dataset_name:
  412. dataset2 = load_dataset("cnn_dailymail", name="3.0.0", split="train").select(range(max_calib_rows_to_load))
  413. column = "article"
  414. elif "pile" in dataset_name:
  415. dataset2 = load_dataset("mit-han-lab/pile-val-backup", split="validation")
  416. column = "text"
  417. else:
  418. raise ValueError(f'dataset "{dataset_name}" not supported')
  419. dataset2 = dataset2[column][:calib_size]
  420. batch_encoded = tokenizer.batch_encode_plus(
  421. dataset2, return_tensors="pt", padding=True, truncation=True, max_length=block_size
  422. )
  423. batch_encoded = batch_encoded.to(device)
  424. batch_encoded_input_ids = batch_encoded["input_ids"]
  425. batch_encoded_attention_mask = batch_encoded["attention_mask"]
  426. # Access DataLoader from the instance variable
  427. data_loader = self.DataLoader
  428. calib_dataloader_input_ids = data_loader(batch_encoded_input_ids, batch_size=batch_size, shuffle=False)
  429. calib_dataloader_attention_mask = data_loader(
  430. batch_encoded_attention_mask, batch_size=batch_size, shuffle=False
  431. )
  432. assert len(calib_dataloader_input_ids.dataset) == len(calib_dataloader_attention_mask.dataset)
  433. assert len(calib_dataloader_input_ids) == len(calib_dataloader_attention_mask)
  434. number_of_batched_samples = calib_size // batch_size
  435. batched_input_ids = []
  436. for idx, data in enumerate(calib_dataloader_input_ids):
  437. batched_input_ids.append(data)
  438. if idx == (number_of_batched_samples - 1):
  439. break
  440. batched_attention_mask = []
  441. for idx, data in enumerate(calib_dataloader_attention_mask):
  442. batched_attention_mask.append(data)
  443. if idx == (number_of_batched_samples - 1):
  444. break
  445. print(
  446. f"\n--Quantize-Script-- number_of_batched_samples={number_of_batched_samples}, "
  447. f"batch-input-ids-list-len={len(batched_input_ids)}, batched_attention_mask={len(batched_attention_mask)}\n"
  448. )
  449. batched_inputs_list = []
  450. for i in range(number_of_batched_samples):
  451. input_ids = batched_input_ids[i]
  452. attention_mask = batched_attention_mask[i]
  453. inputs = self.make_model_input(
  454. config,
  455. input_ids,
  456. attention_mask,
  457. add_past_kv_inputs,
  458. device,
  459. use_fp16,
  460. use_buffer_share,
  461. add_position_ids,
  462. )
  463. inputs = {input_name: torch_tensor.cpu().numpy() for input_name, torch_tensor in inputs.items()}
  464. batched_inputs_list.append(inputs)
  465. print(f"\n--Quantize-Script-- number of batched inputs = {len(batched_inputs_list)}\n")
  466. return batched_inputs_list
  467. def is_divisible(val1, val2):
  468. return int(val2 * np.ceil(val1 / val2)) == val1
  469. class HQQWeightOnlyQuantizer:
  470. def __init__(
  471. self,
  472. config: HQQWeightOnlyQuantConfig,
  473. ):
  474. self.config = config
  475. # Proximal solver || weight - dequantize(quantize(weight))||_p^p
  476. @staticmethod
  477. def optimize_weights(
  478. tensor,
  479. scale,
  480. zero,
  481. min_max: list[int],
  482. axis: int = 0,
  483. opt_params: dict | None = None,
  484. verbose=False,
  485. ):
  486. import torch # noqa: PLC0415
  487. opt_params = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} if opt_params is None else opt_params
  488. lp_norm, beta, kappa, iters = (
  489. opt_params["lp_norm"],
  490. opt_params["beta"],
  491. opt_params["kappa"],
  492. opt_params["iters"],
  493. )
  494. dtype = torch.float16 if tensor.is_cuda else torch.float32
  495. w_f = tensor.to(dtype)
  496. scale = scale.to(dtype)
  497. zero = zero.to(dtype)
  498. def shrink_op(x, beta, p=lp_norm):
  499. if p == 1:
  500. return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
  501. else:
  502. return torch.sign(x) * torch.nn.functional.relu(
  503. torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x) + 1e-8, p - 1)
  504. )
  505. best_error = 1e4
  506. for i in range(iters):
  507. w_q = torch.round(w_f * scale + zero).clamp(min_max[0], min_max[1])
  508. w_r = (w_q - zero) / scale
  509. w_e = shrink_op(w_f - w_r, beta)
  510. zero = torch.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdim=True)
  511. beta *= kappa
  512. current_error = float(torch.abs(w_f - w_r).mean())
  513. if verbose:
  514. print(i, np.round(current_error, 6))
  515. if current_error < best_error:
  516. best_error = current_error
  517. else:
  518. break
  519. del w_f, w_q, w_r, w_e
  520. return scale, zero
  521. @staticmethod
  522. def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits):
  523. if pack_tensor.shape[0] == ori_int_tensor.shape[0]:
  524. ori_int_tensor = ori_int_tensor.T
  525. pack_tensor = pack_tensor.T
  526. if bits in [2, 4, 8]:
  527. compress_ratio = pack_tensor.element_size() * 8 // bits
  528. for j in range(compress_ratio):
  529. pack_tensor[0:] |= ori_int_tensor[j::compress_ratio] << (bits * (j))
  530. else:
  531. raise NotImplementedError("Only 2,4,8 bits are supported.")
  532. # from Official implementation of Half-Quadratic Quantization (HQQ)
  533. def quantize_internal(
  534. self, tensor, bits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True, axis=1
  535. ):
  536. import torch # noqa: PLC0415
  537. weight = tensor.float()
  538. ori_shape = weight.shape
  539. pad_len = (group_size - ori_shape[axis] % group_size) % group_size
  540. if axis == 1:
  541. weight = torch.nn.functional.pad(weight, (0, pad_len), "constant", 0)
  542. else:
  543. weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_len), "constant", 0)
  544. shape = weight.shape
  545. # Reshape for grouping
  546. if (group_size is not None) and channel_wise:
  547. weight = weight.reshape([-1, group_size]) if (axis == 1) else weight.reshape([group_size, -1])
  548. # Get min/max values
  549. if channel_wise is False:
  550. _min, _max = weight.min(), weight.max()
  551. optimize = False
  552. else:
  553. _min = weight.min(axis=axis, keepdim=True)[0]
  554. _max = weight.max(axis=axis, keepdim=True)[0]
  555. max_v = 2**bits - 1
  556. min_v = 0
  557. min_max = [min_v, max_v]
  558. # Note: here we work with the inverse of the scale to avoid division and quantize instead via weight*scale + zero, the scale is inverted later on.
  559. # clamp to avoid half-precision problems
  560. scale = (max_v / (_max - _min)).clamp(max=2e4)
  561. #!!!!!!!!!!!!!!!
  562. min_max_axis = _max - _min
  563. if (min_max_axis == 0).sum().item() > 0:
  564. min_max_axis[min_max_axis == 0] = max_v
  565. scale = (max_v / min_max_axis).clamp(max=2e4)
  566. zero = -_min * scale
  567. if round_zero:
  568. zero = torch.round(zero)
  569. # Fine-tune weights
  570. if optimize:
  571. scale, zero = self.optimize_weights(tensor=weight, scale=scale, zero=zero, min_max=min_max, axis=axis)
  572. # Quantize
  573. # Necessary for fake quantization backprop
  574. w_q = torch.round(weight * scale + zero).clamp(min_max[0], min_max[1])
  575. w_q = w_q.reshape(shape).int()
  576. scale = 1.0 / scale
  577. if axis == 1:
  578. scale = scale.reshape(shape[0], -1)
  579. zero = zero.reshape(shape[0], -1)
  580. else:
  581. scale = scale.reshape(-1, shape[-1])
  582. zero = zero.reshape(-1, shape[-1])
  583. # cleanup
  584. del weight, _min, _max
  585. return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype)
  586. def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
  587. """
  588. Target node: QOperator node: QDQ nodes:
  589. MatMul MatMulNBits DeQuantizeLinear -> MatMul
  590. Gather GatherBlockQuantized Gather, Gather, Gather (optional) -> DequantizeLinear
  591. If the node is target node with fp32 or fp16 const weight, quantize the weight to int4 and
  592. return the new nodes.
  593. If QOperator format, return the corresponding QOperator nodes.
  594. If QDQ format, return the corresdponging QDQ nodes.
  595. Gather (quantized data) + Gather (scales) + Gather (optional, zero points) -> DequantizeLinear is
  596. not supported yet because Gather does not support int4 data.
  597. """
  598. # With HQQ, zero points are in float. Current GatherBlockQuantized does not support float zero points.
  599. if node.op_type == "Gather":
  600. raise NotImplementedError("Gather quantization is not supported yet in HQQ")
  601. import torch # noqa: PLC0415
  602. logger.info(f"start to quantize {node.name} ...")
  603. input_b = node.input[1]
  604. b_pb, bs_graph = get_initializer(input_b, graph_stack)
  605. if b_pb is None:
  606. logger.info("MatMul doesn't have const weight. Skip to quantize")
  607. return [node] # only care about constant weight
  608. b_array = onnx.numpy_helper.to_array(b_pb)
  609. if len(b_array.shape) != 2:
  610. logger.info("MatMul weight is not 2D. Skip to quantize")
  611. return [node] # can only process 2-D matrix
  612. b_array_torch = torch.from_numpy(b_array)
  613. if torch.cuda.is_available():
  614. b_array_torch = b_array_torch.cuda()
  615. bits = self.config.bits
  616. quant_weight_torch, scales_torch, zero_points_torch = self.quantize_internal(
  617. b_array_torch.T, bits=bits, group_size=self.config.block_size
  618. )
  619. quant_weight_torch = quant_weight_torch.contiguous()
  620. scales_torch = scales_torch.contiguous()
  621. zero_points_torch = zero_points_torch.contiguous()
  622. packed_size = 8 // bits # number of elements packed into one byte
  623. packed_torch = torch.zeros(
  624. (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // packed_size),
  625. dtype=torch.uint8,
  626. device=quant_weight_torch.device,
  627. )
  628. self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, bits)
  629. scales = scales_torch.cpu().numpy()
  630. zero_points = zero_points_torch.cpu().numpy()
  631. # reshape to the predefined shape in MatmulNbits
  632. scales = scales.reshape(-1)
  633. zero_points = zero_points.reshape(-1)
  634. rows, cols = b_array_torch.shape
  635. block_size = self.config.block_size
  636. blob_size = block_size // packed_size
  637. k_blocks = (rows + block_size - 1) // block_size
  638. packed_torch = packed_torch.reshape(cols, k_blocks, blob_size)
  639. b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy())
  640. b_quant.name = b_pb.name + "_Q" + str(bits)
  641. for input in bs_graph.input:
  642. if input.name == input_b:
  643. bs_graph.input.remove(input)
  644. break
  645. scales_tensor = onnx.numpy_helper.from_array(scales)
  646. scales_tensor.name = b_pb.name + "_scales"
  647. bs_graph.initializer.extend([b_quant, scales_tensor])
  648. input_names = [node.input[0], b_quant.name, scales_tensor.name]
  649. zp_tensor = onnx.numpy_helper.from_array(zero_points)
  650. zp_tensor.name = b_pb.name + "_zero_points"
  651. bs_graph.initializer.extend([zp_tensor])
  652. input_names.append(zp_tensor.name)
  653. kwargs = {}
  654. rows, cols = b_array.shape
  655. kwargs["K"] = rows
  656. kwargs["N"] = cols
  657. kwargs["bits"] = bits
  658. kwargs["block_size"] = self.config.block_size
  659. matmul_q_node = onnx.helper.make_node(
  660. "MatMulNBits",
  661. inputs=input_names,
  662. outputs=[node.output[0]],
  663. name=node.name + "_Q" + str(bits) if node.name else "",
  664. domain="com.microsoft",
  665. **kwargs,
  666. )
  667. logger.info(f"complete quantization of {node.name} ...")
  668. return [matmul_q_node]
  669. def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
  670. for gid in range(len(graph_path) - 1, -1, -1):
  671. graph = graph_path[gid]
  672. for tensor in graph.initializer:
  673. if tensor.name == name:
  674. return tensor, graph
  675. return None, None
  676. # transpose int4 matrix (packed as uint8)
  677. def transpose_packed_int4_matrix(packed, rows, cols):
  678. # unpack to int4 matrix
  679. total = rows * cols
  680. high = (packed >> 4) & 0x0F
  681. low = packed & 0x0F
  682. int4_vals = np.empty(total, dtype=np.uint8)
  683. int4_vals[0::2] = low
  684. int4_vals[1::2] = high
  685. int4_matrix = int4_vals.reshape((rows, cols))
  686. # transpose int4 matrix
  687. int4_matrix_transposed = int4_matrix.T
  688. # pack to uint8
  689. flat = int4_matrix_transposed.reshape(-1)
  690. packed = ((flat[1::2] << 4) & 0xF0) | (flat[0::2] & 0x0F)
  691. return packed.astype(np.uint8)
  692. class DefaultWeightOnlyQuantizer:
  693. def __init__(self, config: DefaultWeightOnlyQuantConfig):
  694. self.config = config
  695. def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
  696. """4b/8b quantize fp32 weight to int4 using C++ kernels."""
  697. qbits = self.config.bits
  698. kpack = 8 // qbits
  699. if len(fp32weight.shape) != 2:
  700. raise ValueError("Current int4 block quantization only supports 2D tensors!")
  701. rows, cols = fp32weight.shape
  702. block_size = self.config.block_size
  703. k_blocks = (rows + block_size - 1) // block_size
  704. if self.config.quant_format == QuantFormat.QOperator:
  705. blob_size = (block_size + kpack - 1) // kpack
  706. padded_rows = k_blocks * block_size
  707. pad_len = padded_rows - rows
  708. if pad_len > 0:
  709. fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant")
  710. # block wise quantization, each block comes from a single column
  711. packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
  712. zero_point = np.zeros((cols, ((k_blocks + kpack - 1) // kpack)), dtype="uint8")
  713. scales = np.zeros((cols, k_blocks), dtype=fp32weight.dtype)
  714. if qbits == 2:
  715. quantize_matmul_2bits(
  716. packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
  717. )
  718. elif qbits == 8:
  719. quantize_matmul_8bits(
  720. packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
  721. )
  722. else:
  723. quantize_matmul_4bits(
  724. packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
  725. )
  726. else:
  727. # block size equal to rows (K) if channel wised quantize enabled
  728. block_size = rows if self.config.channel_wised_quantize else self.config.block_size
  729. k_blocks = (rows + block_size - 1) // block_size
  730. assert qbits == 4, "QDQ format only support 4 bits quantization"
  731. packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
  732. zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
  733. scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype)
  734. quantize_qdq_matmul_4bits(
  735. packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
  736. )
  737. return (packed, scales, zero_point)
  738. def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
  739. """
  740. Quantize weight B of MatMul node to int4 or int8.
  741. Currently only support 2D constant matrix and axis 0 blockwise quantization.
  742. """
  743. bits = self.config.bits
  744. if bits == 8:
  745. qtype = TensorProto.INT8 if self.config.is_symmetric else TensorProto.UINT8
  746. else:
  747. qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4
  748. input_b = node.input[1]
  749. b_tensor, b_graph = get_initializer(input_b, graph_stack)
  750. if b_tensor is None:
  751. logger.info("MatMul doesn't have const weight. Skip to quantize")
  752. return [node] # only care about constant weight
  753. b_ndarray = ir.from_proto(b_tensor).numpy()
  754. if len(b_ndarray.shape) != 2:
  755. logger.info("MatMul weight is not 2D. Skip to quantize")
  756. return [node] # can only process 2-D matrix
  757. bfloat16 = b_ndarray.dtype == "bfloat16"
  758. if bfloat16:
  759. b_ndarray = b_ndarray.astype(np.float32)
  760. packed, scales, zero_points = self.qbits_block_quant(b_ndarray)
  761. if bfloat16:
  762. scales = scales.astype(ml_dtypes.bfloat16)
  763. if self.config.quant_format == QuantFormat.QOperator:
  764. b_quant = ir.serde.serialize_tensor(ir.Tensor(packed, name=b_tensor.name + f"_Q{bits}"))
  765. scales_tensor = ir.serde.serialize_tensor(ir.Tensor(scales, name=b_tensor.name + "_scales"))
  766. else:
  767. b_quant = onnx.helper.make_tensor(
  768. b_tensor.name + f"_DQ_Q{bits}", qtype, b_ndarray.shape, packed.tobytes(), True
  769. )
  770. scales_tensor = ir.serde.serialize_tensor(ir.Tensor(scales, name=b_tensor.name + "_DQ_scales"))
  771. # if QDQ, CW and SYM enabled, optimize for Intel NPU, tranpose the weight to NHWC format will increase performance
  772. qdq_opt_for_intel_npu_enabled = (
  773. self.config.quant_format == QuantFormat.QDQ
  774. and self.config.channel_wised_quantize
  775. and self.config.is_symmetric
  776. )
  777. if qdq_opt_for_intel_npu_enabled:
  778. rows, cols = b_ndarray.shape
  779. packed = transpose_packed_int4_matrix(packed, rows, cols)
  780. scales = scales.reshape((cols, 1)) # (cols, 1)
  781. b_quant = onnx.helper.make_tensor(
  782. b_tensor.name + f"_DQ_Q{bits}", qtype, [cols, rows], packed.tobytes(), True
  783. )
  784. scales_tensor = ir.serde.serialize_tensor(ir.Tensor(scales, name=b_tensor.name + "_DQ_scales"))
  785. for input in b_graph.input:
  786. if input.name == input_b:
  787. b_graph.input.remove(input)
  788. break
  789. b_graph.initializer.extend([b_quant, scales_tensor])
  790. output_nodes = []
  791. if self.config.quant_format == QuantFormat.QOperator:
  792. input_names = [node.input[0], b_quant.name, scales_tensor.name]
  793. if not self.config.is_symmetric:
  794. zp_tensor = onnx.numpy_helper.from_array(zero_points, b_tensor.name + "_zero_points")
  795. input_names.append(zp_tensor.name)
  796. b_graph.initializer.extend([zp_tensor])
  797. kwargs = {}
  798. rows, cols = b_ndarray.shape
  799. kwargs["K"] = rows
  800. kwargs["N"] = cols
  801. kwargs["bits"] = bits
  802. kwargs["block_size"] = self.config.block_size
  803. # Do not output accuracy_level if it is 0 since the attribute is optional and is not supported by most EPs.
  804. if self.config.accuracy_level:
  805. kwargs["accuracy_level"] = self.config.accuracy_level
  806. matmul_qbit_node = onnx.helper.make_node(
  807. "MatMulNBits",
  808. inputs=input_names,
  809. outputs=[node.output[0]],
  810. name=node.name + f"_Q{bits}" if node.name else "",
  811. domain="com.microsoft",
  812. **kwargs,
  813. )
  814. output_nodes.append(matmul_qbit_node)
  815. else:
  816. dq_input_names = [b_quant.name, scales_tensor.name]
  817. dq_output_names = [b_quant.name + "_output"]
  818. tp_input_names = [dq_output_names[0]]
  819. tp_output_names = [dq_output_names[0] + "_transposed"]
  820. matmul_input_names = [
  821. node.input[0],
  822. tp_output_names[0] if qdq_opt_for_intel_npu_enabled else dq_output_names[0],
  823. ]
  824. matmul_output_names = [node.output[0]]
  825. if not self.config.is_symmetric:
  826. zp_tensor = onnx.helper.make_tensor(
  827. b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True
  828. )
  829. dq_input_names.append(zp_tensor.name)
  830. b_graph.initializer.extend([zp_tensor])
  831. rows, cols = b_ndarray.shape
  832. dq_kwargs = {
  833. "axis": 1 if qdq_opt_for_intel_npu_enabled else 0,
  834. "block_size": rows if self.config.channel_wised_quantize else self.config.block_size,
  835. }
  836. dq_node = onnx.helper.make_node(
  837. "DequantizeLinear",
  838. inputs=dq_input_names,
  839. outputs=dq_output_names,
  840. name=node.name + f"_DQ_Q{bits}" if node.name else "",
  841. **dq_kwargs,
  842. )
  843. matmul_node = onnx.helper.make_node(
  844. "MatMul",
  845. inputs=matmul_input_names,
  846. outputs=matmul_output_names,
  847. name=node.name + f"_matmul_Q{bits}" if node.name else "",
  848. )
  849. if qdq_opt_for_intel_npu_enabled:
  850. tp_node = onnx.helper.make_node(
  851. "Transpose",
  852. inputs=tp_input_names,
  853. outputs=tp_output_names,
  854. perm=[1, 0],
  855. )
  856. output_nodes.extend([dq_node, tp_node, matmul_node])
  857. else:
  858. output_nodes.extend([dq_node, matmul_node])
  859. return output_nodes
  860. @staticmethod
  861. def quant_slice_symmetric(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
  862. max_val = np.max(data, axis=1, keepdims=True)
  863. min_val = np.min(data, axis=1, keepdims=True)
  864. abs_max = np.where(np.abs(max_val) > np.abs(min_val), max_val, min_val)
  865. scale = abs_max / -8.0 # if max == min, max may be clipped
  866. quantized_slice = np.where(scale == 0, 0, data / scale).round().clip(-8, 7).astype(np.int8)
  867. return quantized_slice, scale
  868. @staticmethod
  869. def quant_slice_asymmetric(data: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
  870. min_val = np.minimum(data.min(axis=1, keepdims=True), 0)
  871. max_val = np.maximum(data.max(axis=1, keepdims=True), 0)
  872. scale = (max_val - min_val) / 15.0
  873. zero_point = np.where(scale == 0, 8, -min_val / scale).round().clip(0, 15).astype(np.uint8)
  874. quantized_slice = np.where(scale == 0, 8, data / scale + zero_point).round().clip(0, 15).astype(np.uint8)
  875. return quantized_slice, scale, zero_point
  876. @staticmethod
  877. def pack_int8_to_int4(data: np.ndarray) -> np.ndarray:
  878. """Pack int8 data to int4 and store in uint8 ndarray."""
  879. data_flat = data.reshape(-1)
  880. if len(data_flat) % 2 != 0:
  881. data_flat = np.append(data_flat, 0)
  882. quant_data_int4 = (data_flat[::2] & 0xF) | ((data_flat[1::2] & 0xF) << 4)
  883. return quant_data_int4.astype("uint8")
  884. @staticmethod
  885. def quantize_ndarray(
  886. data: np.ndarray,
  887. quantize_axis: int,
  888. block_size: int,
  889. is_symmetric: bool,
  890. ) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
  891. """Quantize ndarray data to int4 using numpy, return (quantized data, scales, zero points)."""
  892. # Get the shape of the matrix
  893. m = 1 # dimension of the matrix before the quantize axis
  894. k = data.shape[quantize_axis] # dimension of the matrix along the quantize axis
  895. n = 1 # dimension of the matrix after the quantize axis
  896. for i, dim in enumerate(data.shape):
  897. if i < quantize_axis:
  898. m *= dim
  899. elif i > quantize_axis:
  900. n *= dim
  901. k_blocks = (k + block_size - 1) // block_size
  902. scales_shape = list(data.shape)
  903. scales_shape[quantize_axis] = k_blocks
  904. data_reshape = data.reshape((m, k, n))
  905. scales = np.zeros((m, k_blocks, n), dtype=data.dtype)
  906. if is_symmetric:
  907. quant_data_int8 = np.zeros((m, k, n), dtype="int8")
  908. else:
  909. quant_data_int8 = np.zeros((m, k, n), dtype="uint8")
  910. zero_point_int8 = np.zeros((m, k_blocks, n), dtype="uint8")
  911. # slice and quantize
  912. for i in range(0, k, block_size):
  913. end_idx = min(i + block_size, k)
  914. slice = data_reshape[:, i:end_idx, :]
  915. if is_symmetric:
  916. quantized_slice_int8, scale_slice = DefaultWeightOnlyQuantizer.quant_slice_symmetric(slice)
  917. else:
  918. quantized_slice_int8, scale_slice, zero_point_slice_int8 = (
  919. DefaultWeightOnlyQuantizer.quant_slice_asymmetric(slice)
  920. )
  921. quant_data_int8[:, i:end_idx, :] = quantized_slice_int8
  922. j = i // block_size
  923. scales[:, j : (j + 1), :] = scale_slice
  924. if not is_symmetric:
  925. zero_point_int8[:, j : (j + 1), :] = zero_point_slice_int8
  926. # pack int8 to int4
  927. quant_data_int4 = DefaultWeightOnlyQuantizer.pack_int8_to_int4(quant_data_int8)
  928. zero_point_int4 = None
  929. if not is_symmetric:
  930. zero_point_int4 = DefaultWeightOnlyQuantizer.pack_int8_to_int4(zero_point_int8)
  931. scales = scales.reshape(scales_shape)
  932. return quant_data_int4, scales, zero_point_int4
  933. def quantize_gather(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
  934. """Quantize weight data of Gather node to int4."""
  935. assert self.config.quant_format == QuantFormat.QOperator, "Gather only supports QOperator format currently."
  936. qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4
  937. data_arg = node.input[0]
  938. data_tensorproto, data_graphproto = get_initializer(data_arg, graph_stack)
  939. if data_tensorproto is None:
  940. logger.info("Gather doesn't have const weight. Skip quantization.")
  941. return [node] # only care about constant weight
  942. data_ndarray = onnx.numpy_helper.to_array(data_tensorproto)
  943. data_rank = len(data_ndarray.shape)
  944. quantize_axis = self.config.quant_axes.get("Gather", 1)
  945. block_size = self.config.block_size
  946. assert quantize_axis < data_rank and quantize_axis >= -data_rank, "Invalid quantize axis for Gather node."
  947. assert block_size >= 16 and ((block_size - 1) & block_size == 0), "Invalid block size for Gather node."
  948. quantize_axis = (quantize_axis + data_rank) % data_rank
  949. quantized_data, scales, zero_points = self.quantize_ndarray(
  950. data_ndarray, quantize_axis, block_size, self.config.is_symmetric
  951. )
  952. for input in data_graphproto.input:
  953. if input.name == data_arg:
  954. data_graphproto.input.remove(input)
  955. break
  956. quantized_data_tensorproto = onnx.helper.make_tensor(
  957. data_tensorproto.name + "_Q4", qtype, data_ndarray.shape, quantized_data.tobytes(), True
  958. )
  959. scales_tensorproto = onnx.numpy_helper.from_array(scales, data_tensorproto.name + "_scales")
  960. input_names = [quantized_data_tensorproto.name, node.input[1], scales_tensorproto.name]
  961. data_graphproto.initializer.extend([quantized_data_tensorproto, scales_tensorproto])
  962. if not self.config.is_symmetric:
  963. zp_tensorproto = onnx.helper.make_tensor(
  964. data_tensorproto.name + "_zero_points", qtype, scales.shape, zero_points.tobytes(), True
  965. )
  966. input_names.append(zp_tensorproto.name)
  967. data_graphproto.initializer.extend([zp_tensorproto])
  968. try:
  969. gather_axis = onnx.helper.get_node_attr_value(node, "axis")
  970. except ValueError:
  971. gather_axis = 0
  972. kwargs = {
  973. "gather_axis": gather_axis,
  974. "quantize_axis": quantize_axis,
  975. "block_size": block_size,
  976. }
  977. gather_q4_node = onnx.helper.make_node(
  978. "GatherBlockQuantized",
  979. inputs=input_names,
  980. outputs=[node.output[0]],
  981. name=node.name + "_Q4" if node.name else "",
  982. domain="com.microsoft",
  983. **kwargs,
  984. )
  985. return [gather_q4_node]
  986. def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
  987. """
  988. Target node: QOperator node: QDQ nodes:
  989. MatMul MatMulNBits DeQuantizeLinear -> MatMul
  990. Gather GatherBlockQuantized Gather, Gather, Gather (optional) -> DequantizeLinear
  991. If the node is target node with fp32 or fp16 const weight, quantize the weight to int4 and
  992. return the new nodes.
  993. If QOperator format, return the corresponding QOperator nodes.
  994. If QDQ format, return the corresdponging QDQ nodes.
  995. Gather (quantized data) + Gather (scales) + Gather (optional, zero points) -> DequantizeLinear is
  996. not supported yet because Gather does not support int4 data.
  997. """
  998. logger.info(f"start to quantize {node.name} ...")
  999. bits = self.config.bits
  1000. if node.op_type == "MatMul":
  1001. if bits == 8 and self.config.quant_format == QuantFormat.QDQ:
  1002. logger.error("MatMul only supports QOperator format for 8 bits quantization.")
  1003. return [node]
  1004. results = self.quantize_matmul(node, graph_stack)
  1005. elif node.op_type == "Gather":
  1006. if self.config.bits != 4:
  1007. logger.error("Gather only supports 4 bits quantization.")
  1008. return [node]
  1009. results = self.quantize_gather(node, graph_stack)
  1010. else:
  1011. logger.error(f"Unsupported operator {node.op_type} for weight only quantization. Skip quantization.")
  1012. return [node]
  1013. logger.info(f"complete quantization of {node.name} with {self.config.bits} bits ...")
  1014. return results
  1015. class NVAWQWeightOnlyQuantizer:
  1016. def __init__(
  1017. self,
  1018. config: NVAWQWeightOnlyQuantConfig,
  1019. ):
  1020. self.config = config
  1021. def quantize_awq(self, model: ModelProto | str) -> ModelProto:
  1022. """
  1023. Perform nvidia_awq quantization using ModelOpt's int4 quantize function.
  1024. Args:
  1025. model (ModelProto): The ONNX model to quantize.
  1026. Returns:
  1027. ModelProto: The quantized ONNX model.
  1028. """
  1029. try:
  1030. from modelopt.onnx.quantization.int4 import quantize as quantize_int4 # noqa: PLC0415
  1031. except ImportError:
  1032. print(
  1033. "Please ensure that the 'modelopt' package is installed. Please install it using pip install nvidia_modelopt."
  1034. )
  1035. raise ImportError(
  1036. "modelopt is not installed. Please install it using pip install nvidia_modelopt. Exiting."
  1037. ) from None
  1038. logger.info("Starting nvidia_awq quantization...")
  1039. # Prepare calibration inputs
  1040. calib_inputs = self.config.calibration_data_reader
  1041. # Perform quantization using ModelOpt's int4 quantize function
  1042. quantized_model = quantize_int4(
  1043. model,
  1044. calibration_method=self.config.calibration_method,
  1045. calibration_data_reader=calib_inputs,
  1046. )
  1047. logger.info("Completed nvidia_awq quantization.")
  1048. return quantized_model
  1049. class MatMulNBitsQuantizer:
  1050. """
  1051. Target node: QOperator node: QDQ nodes:
  1052. MatMul MatMulNBits DeQuantizeLinear -> MatMul
  1053. Gather GatherBlockQuantized Gather, Gather, Gather (optional) -> DequantizeLinear
  1054. Perform 2/4/8 bits quantization of constant weights for target nodes.
  1055. If algo_config.quant_format is QOperator:
  1056. - nodes are replaced by the corresponding QOperator nodes.
  1057. - quantized weights are stored in the contrib ops.
  1058. If algo_config.quant_format is QDQ:
  1059. - the quantized weight is stored in a standard onnx node. For MatMul, it is DequantizeLinear. For Gather,
  1060. it is the three Gathers, one for quantized data, one for scales and one for optional zero points.
  1061. - The nodes are replaced by the corresponding QDQ nodes.
  1062. - currently Gather is not supported in QDQ because Gather does not support int4 yet.
  1063. Note:
  1064. - for quantized gather, the memory usage of "DequantizeLinear + Gather" is the same as the original Gather
  1065. during runtime. Therefor it is not recommended.
  1066. - when a node is in nodes_to_exclude, and the node configuration in algo_config.customized_weight_config will be ignored.
  1067. """
  1068. def __init__(
  1069. self,
  1070. model: ModelProto | str,
  1071. bits: int = 4, # default to 4bit
  1072. block_size: int = 128,
  1073. is_symmetric: bool = False,
  1074. accuracy_level: int | None = None,
  1075. nodes_to_exclude=None,
  1076. nodes_to_include: list[str] | None = None,
  1077. quant_format=QuantFormat.QOperator,
  1078. op_types_to_quantize: tuple[str, ...] | None = None,
  1079. quant_axes: tuple[tuple[str, int], ...] | None = None,
  1080. channel_wised_quantize: bool = False,
  1081. algo_config: WeightOnlyQuantConfig | None = None,
  1082. ):
  1083. if nodes_to_exclude is None:
  1084. nodes_to_exclude = []
  1085. self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model)
  1086. self.model_path = model if isinstance(model, str) else None
  1087. self.bits = bits
  1088. self.block_size = block_size
  1089. self.is_symmetric = is_symmetric
  1090. self.accuracy_level = accuracy_level
  1091. self.nodes_to_exclude = set(nodes_to_exclude)
  1092. self.nodes_to_include = set(nodes_to_include) if nodes_to_include else None
  1093. self.node_quantizer = None
  1094. if algo_config is None:
  1095. algo_config = DefaultWeightOnlyQuantConfig(
  1096. block_size=block_size,
  1097. is_symmetric=is_symmetric,
  1098. accuracy_level=accuracy_level,
  1099. quant_format=quant_format,
  1100. op_types_to_quantize=op_types_to_quantize,
  1101. quant_axes=quant_axes,
  1102. bits=bits,
  1103. channel_wised_quantize=channel_wised_quantize,
  1104. )
  1105. self.algo_config = algo_config
  1106. if hasattr(self.algo_config, "bits"):
  1107. assert self.algo_config.bits in [2, 4, 8], "Only support 2, 4 or 8 bits quantization"
  1108. if algo_config.algorithm == "HQQ":
  1109. self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config)
  1110. elif algo_config.algorithm == "DEFAULT":
  1111. self.node_quantizer = DefaultWeightOnlyQuantizer(self.algo_config)
  1112. elif algo_config.algorithm == "nvidia_awq":
  1113. self.node_quantizer = NVAWQWeightOnlyQuantizer(self.algo_config)
  1114. def _process_subgraph(self, graph_stack: list[GraphProto]):
  1115. new_nodes = []
  1116. graph = graph_stack[-1]
  1117. for node in graph.node:
  1118. graph_attrs = [
  1119. attr
  1120. for attr in node.attribute
  1121. if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
  1122. ]
  1123. if graph_attrs:
  1124. kwargs = {}
  1125. for attr in node.attribute:
  1126. if attr.type == onnx.AttributeProto.GRAPH:
  1127. # recursive call to take care of sub-graph
  1128. graph_stack.append(attr.g)
  1129. kv = {attr.name: self._process_subgraph(graph_stack)}
  1130. elif attr.type == onnx.AttributeProto.GRAPHS:
  1131. value = []
  1132. for subgraph in attr.graphs:
  1133. # recursive call to take care of sub-graph
  1134. graph_stack.append(subgraph)
  1135. value.extend([self._process_subgraph(graph_stack)])
  1136. kv = {attr.name: value}
  1137. else:
  1138. kv = attribute_to_kwarg(attr)
  1139. kwargs.update(kv)
  1140. node = onnx.helper.make_node( # noqa: PLW2901
  1141. node.op_type, node.input, node.output, name=node.name, **kwargs
  1142. )
  1143. out_nodes = []
  1144. if node.name in self.nodes_to_exclude:
  1145. logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
  1146. out_nodes = [node]
  1147. elif (self.nodes_to_include and node.name in self.nodes_to_include) or (
  1148. node.op_type in self.algo_config.op_types_to_quantize
  1149. ):
  1150. out_nodes = self.node_quantizer.quantize(node, graph_stack)
  1151. else:
  1152. logger.info(f"skip to quantize {node.name} ...")
  1153. out_nodes = [node]
  1154. new_nodes.extend(out_nodes)
  1155. graph.ClearField("node")
  1156. graph.node.extend(new_nodes)
  1157. graph_stack.pop()
  1158. return graph
  1159. def _generate_q4_node_config(self):
  1160. """Generate weight only quant configuration for nodes."""
  1161. q4_node_config = {}
  1162. for node in self.model.model.graph.node:
  1163. if node.op_type in ["MatMul"]:
  1164. if not all(self.model.get_initializer(i) is None for i in node.input):
  1165. template_config_q4 = {
  1166. "bits": 4,
  1167. "group_size": self.block_size,
  1168. "scheme": "sym" if self.is_symmetric else "asym",
  1169. }
  1170. if (
  1171. self.algo_config.customized_weight_config
  1172. and node.name in self.algo_config.customized_weight_config
  1173. ):
  1174. for key, value in self.algo_config.customized_weight_config[node.name].items():
  1175. if key in template_config_q4:
  1176. template_config_q4[key] = value
  1177. q4_node_config[node.name] = template_config_q4
  1178. return q4_node_config
  1179. def int4_quant_algo(self):
  1180. """4b quantize a model with RTN or GPTQ algorithm. Please refer to
  1181. https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md
  1182. for more details on weight only quantization using Intel® Neural Compressor.
  1183. """
  1184. def inc_dataloader():
  1185. data_reader = copy.deepcopy(self.algo_config.calibration_data_reader)
  1186. for data in data_reader:
  1187. yield data, None
  1188. kwargs = {}
  1189. if self.accuracy_level is not None:
  1190. kwargs["accuracy_level"] = self.accuracy_level
  1191. weight_only_node_config = self._generate_q4_node_config()
  1192. algorithm = self.algo_config.algorithm
  1193. logger.info(f"start to quantize model with {algorithm} algorithm...")
  1194. if algorithm in ["RTN", "k_quant"]:
  1195. kwargs["ratios"] = self.algo_config.ratios
  1196. kwargs["algorithm"] = algorithm
  1197. """
  1198. We uses fp32 to represent the node that skip quantization, it does not mean this node is fp32 type though.
  1199. """
  1200. for n in self.nodes_to_exclude:
  1201. weight_only_node_config[n] = "fp32"
  1202. self.model = rtn_quantize(
  1203. model=self.model_path if self.model_path is not None else self.model.model,
  1204. weight_config=weight_only_node_config,
  1205. **kwargs,
  1206. )
  1207. elif algorithm == "GPTQ":
  1208. kwargs["percdamp"] = self.algo_config.percdamp
  1209. kwargs["blocksize"] = self.algo_config.block_size
  1210. kwargs["actorder"] = self.algo_config.actorder
  1211. kwargs["mse"] = self.algo_config.mse
  1212. kwargs["perchannel"] = self.algo_config.perchannel
  1213. kwargs["n_samples"] = -1
  1214. dataloader = inc_dataloader()
  1215. self.model = gptq_quantize(
  1216. model=self.model_path if self.model_path is not None else self.model.model,
  1217. weight_config=weight_only_node_config,
  1218. dataloader=dataloader,
  1219. **kwargs,
  1220. )
  1221. logger.info(f"complete quantization of model with {algorithm} algorithm.")
  1222. def process(self):
  1223. if self.algo_config.algorithm in ["HQQ", "DEFAULT"]:
  1224. # use a stack to keep track of sub-graphs
  1225. graph_stack = [self.model.graph()]
  1226. # Update domain opset
  1227. if self.algo_config.quant_format == QuantFormat.QOperator:
  1228. self.model.set_opset_import("com.microsoft", 1)
  1229. if self.algo_config.quant_format == QuantFormat.QDQ or "Gather" in self.algo_config.op_types_to_quantize:
  1230. opset_import = self.model.opset_import()
  1231. for opset in opset_import:
  1232. if opset.domain in [None, "ai.onnx", ""] and opset.version < 21:
  1233. logger.warning(
  1234. "The opset of the input model is under 21 and doesn't support int4 data type. "
  1235. "Force to update it to opset 21, but the generated model may not be a valid model."
  1236. )
  1237. self.model.set_opset_import(opset.domain, 21)
  1238. self._process_subgraph(graph_stack)
  1239. self.model.clean_initializers()
  1240. elif self.algo_config.algorithm == "nvidia_awq":
  1241. # Handle nvidia_awq quantization
  1242. logger.info("Processing nvidia_awq quantization...")
  1243. self.model = self.node_quantizer.quantize_awq(
  1244. self.model.model if self.model_path is None else self.model_path
  1245. )
  1246. logger.info("Completed nvidia_awq quantization.")
  1247. self.model = ONNXModel(self.model) # Ensure the model is wrapped back into ONNXModel
  1248. self.model.clean_initializers()
  1249. else:
  1250. # RTN or GPTQ weight-only quantize algorithm
  1251. self.int4_quant_algo()
  1252. def ort_convert_str_to_bool(value):
  1253. return value.lower() in ("true", "1")
  1254. # Custom function to parse str:int pairs
  1255. def parse_key_value_pair(s):
  1256. key, value = s.split(":")
  1257. return key, int(value)
  1258. def parse_args():
  1259. parser = argparse.ArgumentParser(
  1260. description="""Blockwise int4 quantization for MatMul 2D weight matrices.
  1261. A weight matrix is partitioned into into blocks, where each block is a
  1262. continguous subset inside each column. Each block is quantized into a
  1263. set of 4b integers with a scaling factor and an optional offset.
  1264. """
  1265. )
  1266. parser.add_argument("--input_model", required=True, help="Path to the input model file")
  1267. parser.add_argument("--output_model", required=True, help="Path to the output model file")
  1268. parser.add_argument("--block_size", required=False, default=32, type=int, help="Block size for quantization")
  1269. parser.add_argument(
  1270. "--quant_method",
  1271. default="default",
  1272. type=str,
  1273. choices=["default", "hqq", "rtn", "k_quant", "gptq", "nvidia_awq"],
  1274. help="the algorithm used to quantize weight, \nrtn and gptq leverage Intel® Neural Compressor",
  1275. )
  1276. parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight")
  1277. parser.add_argument(
  1278. "--symmetric",
  1279. required=False,
  1280. default=True,
  1281. const=True,
  1282. nargs="?",
  1283. type=ort_convert_str_to_bool,
  1284. choices=[True, False],
  1285. help="Indicate whether to quantize the model symmetrically, symmetric is not supported by hqq",
  1286. )
  1287. parser.add_argument(
  1288. "--accuracy_level",
  1289. required=False,
  1290. type=int,
  1291. help="Accuracy level of the 4-bit quantized MatMul computation. "
  1292. "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
  1293. "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
  1294. )
  1295. parser.add_argument("-v", "--verbose", required=False, action="store_true")
  1296. parser.set_defaults(verbose=False)
  1297. parser.add_argument(
  1298. "--nodes_to_exclude",
  1299. nargs="+",
  1300. type=str,
  1301. required=False,
  1302. default=[],
  1303. help="Specify the nodes to be excluded from quantization with node names",
  1304. )
  1305. parser.add_argument(
  1306. "--nodes_to_include",
  1307. nargs="+",
  1308. type=str,
  1309. required=False,
  1310. help="Specify the specific nodes to be included from quantization with node names",
  1311. )
  1312. parser.add_argument(
  1313. "--quant_format",
  1314. default="QOperator",
  1315. type=str,
  1316. choices=["QOperator", "QDQ"],
  1317. help="QuantFormat {QOperator, QDQ}"
  1318. "QOperator format quantizes the model with quantized operators directly."
  1319. "QDQ format quantize the model by inserting DeQuantizeLinear before the MatMul.",
  1320. )
  1321. parser.add_argument(
  1322. "--op_types_to_quantize",
  1323. type=str,
  1324. nargs="+",
  1325. choices=["MatMul", "Gather"],
  1326. help="op_types_to_quantize {MatMul, Gather}. Operators to quantize. Default is MatMul.",
  1327. )
  1328. parser.add_argument(
  1329. "--quant_axes",
  1330. type=parse_key_value_pair,
  1331. nargs="+",
  1332. required=False,
  1333. help="Key-value pairs in op_type:axis_to_quantize separated by space."
  1334. "Specify the axis to quantize for an op. Default {MatMul:0, Gather:1}"
  1335. "Example: --quant_axes MatMul:0 Gather:1",
  1336. )
  1337. # Group arguments specific to nvidia_awq
  1338. nv_awq_config = parser.add_argument_group("nvidia_awq", "Arguments specific to nvidia_awq quantization")
  1339. nv_awq_config.add_argument(
  1340. "--calib_dataset_name",
  1341. type=str,
  1342. default="cnn",
  1343. help="Name of the calibration dataset for nvidia_awq.",
  1344. )
  1345. nv_awq_config.add_argument(
  1346. "--tokenizer_dir",
  1347. type=str,
  1348. required=False,
  1349. help="Path of the tokenizer dir.",
  1350. )
  1351. nv_awq_config.add_argument(
  1352. "--calibration_method",
  1353. type=str,
  1354. required=False,
  1355. choices=["awq", "awq_clip"],
  1356. help="Support two options, awq implementation and weight clipping.",
  1357. )
  1358. nv_awq_config.add_argument(
  1359. "--cache_dir",
  1360. type=str,
  1361. default="./cache",
  1362. help="Cache directory for calibration data.",
  1363. )
  1364. return parser.parse_args()
  1365. if __name__ == "__main__":
  1366. args = parse_args()
  1367. if args.verbose:
  1368. logger.setLevel(logging.DEBUG)
  1369. input_model_path = args.input_model
  1370. output_model_path = args.output_model
  1371. quant_format = QuantFormat[args.quant_format]
  1372. op_types_to_quantize = tuple(args.op_types_to_quantize) if args.op_types_to_quantize else ("MatMul",)
  1373. quant_axes = tuple(args.quant_axes) if args.quant_axes else None
  1374. if os.path.exists(output_model_path):
  1375. logger.error(f"file {output_model_path} already exists")
  1376. raise Exception(f"file {output_model_path} already exists")
  1377. if args.symmetric and args.quant_method == "hqq":
  1378. logger.warning("symmetric is not supportted by hqq, will force to symmetric=False")
  1379. args.symmetric = False
  1380. model = onnx.load(input_model_path)
  1381. if args.quant_method == "hqq":
  1382. quant_config = HQQWeightOnlyQuantConfig(
  1383. block_size=args.block_size, bits=args.bits, op_types_to_quantize=op_types_to_quantize, quant_axes=quant_axes
  1384. )
  1385. elif args.quant_method == "default":
  1386. quant_config = DefaultWeightOnlyQuantConfig(
  1387. block_size=args.block_size,
  1388. is_symmetric=args.symmetric,
  1389. accuracy_level=args.accuracy_level,
  1390. quant_format=quant_format,
  1391. op_types_to_quantize=op_types_to_quantize,
  1392. quant_axes=quant_axes,
  1393. bits=args.bits,
  1394. )
  1395. elif args.quant_method == "rtn":
  1396. quant_config = RTNWeightOnlyQuantConfig(op_types_to_quantize=op_types_to_quantize)
  1397. elif args.quant_method == "k_quant":
  1398. quant_config = KQuantWeightOnlyQuantConfig(op_types_to_quantize=op_types_to_quantize)
  1399. elif args.quant_method == "gptq":
  1400. quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size, op_types_to_quantize=op_types_to_quantize)
  1401. elif args.quant_method == "nvidia_awq":
  1402. if quant_format == QuantFormat.QOperator:
  1403. logger.warning("QOperator is not applicable to nvidia_awq. overriding the value to QDQ")
  1404. quant_format = QuantFormat.QDQ
  1405. model = input_model_path
  1406. if args.calibration_method is not None:
  1407. if args.calibration_method == "awq":
  1408. calibration_method = "awq_lite"
  1409. else:
  1410. calibration_method = "awq_clip"
  1411. else:
  1412. calibration_method = "awq_lite"
  1413. quant_config = NVAWQWeightOnlyQuantConfig(
  1414. dataset_name=args.calib_dataset_name,
  1415. tokenizer_dir=args.tokenizer_dir,
  1416. cache_dir=args.cache_dir,
  1417. calibration_method=calibration_method,
  1418. )
  1419. else:
  1420. raise ValueError(f"Unsupported quantization method: {args.quant_method}")
  1421. quant = MatMulNBitsQuantizer(
  1422. model=model,
  1423. bits=args.bits,
  1424. accuracy_level=args.accuracy_level,
  1425. nodes_to_exclude=args.nodes_to_exclude,
  1426. nodes_to_include=args.nodes_to_include,
  1427. algo_config=quant_config,
  1428. )
  1429. quant.process()
  1430. quant.model.save_model_to_file(output_model_path, True)