quantizer_sinq.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from ..utils import is_torch_available, logging
from ..utils.quantization_config import SinqConfig
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name


# torch is imported only when the environment reports it as available.
if is_torch_available():
    import torch

# PreTrainedModel is only needed for type annotations (TYPE_CHECKING is False at
# runtime, so modeling_utils is not imported when this module loads).
if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

# Module-level logger used for the device_map / CUDA availability messages below.
logger = logging.get_logger(__name__)
  25. class SinqHfQuantizer(HfQuantizer):
  26. """
  27. HF v5 quantizer for SINQ.
  28. Modes:
  29. - method="sinq" (default):
  30. * weight-only SINQ
  31. * param-level ConversionOps (`SinqQuantize`) during load for pure language models
  32. (each Linear.weight is turned into a SINQLinear module)
  33. * module-level quantization after load for multimodal models
  34. - method="asinq":
  35. * A-SINQ (activation-aware) SINQ quantization
  36. """
  37. requires_parameters_quantization: bool = True
  38. quantization_config: SinqConfig
  39. def __init__(self, quantization_config: SinqConfig, **kwargs):
  40. super().__init__(quantization_config, **kwargs)
  41. self._normalized_device_str: str | None = None
  42. self._do_param_level_sinq: bool = False
  43. def is_serializable(self) -> bool:
  44. return True
  45. @property
  46. def is_trainable(self) -> bool:
  47. return True
  48. def update_device_map(self, device_map):
  49. if device_map is None:
  50. if torch.cuda.is_available():
  51. device_map = {"": torch.cuda.current_device()}
  52. else:
  53. device_map = {"": "cpu"}
  54. logger.info(
  55. "The device_map was not initialized. "
  56. f"Setting device_map to {device_map}. "
  57. "If you want to use the model for inference, please set device_map='auto'"
  58. )
  59. return device_map
  60. def update_dtype(self, dtype: torch.dtype) -> torch.dtype:
  61. if dtype is None:
  62. dtype = torch.bfloat16
  63. self.dtype = dtype
  64. return dtype
  65. def validate_environment(self, *args, **kwargs) -> None:
  66. from ..utils import is_sinq_available
  67. if not is_sinq_available():
  68. raise ImportError("The 'sinq' package is not installed. Please install it with: pip install sinq")
  69. if not torch.cuda.is_available():
  70. logger.warning(
  71. "No CUDA device is available. Quantization and inference will run on the CPU. Please note that this will significantly slow down inference speed and increase quantization time."
  72. )
  73. device_map = kwargs.get("device_map")
  74. if isinstance(device_map, dict):
  75. device_map_values = set(device_map.values())
  76. if len(device_map_values) > 1:
  77. raise RuntimeError(
  78. "SinqHfQuantizer: multi-GPU device_map detected, but SINQ currently supports only a single CUDA "
  79. f"device. Got {sorted(device_map_values)}. Please use device_map=None."
  80. )
  81. if self.quantization_config.method == "asinq" and not self.pre_quantized:
  82. raise ValueError(
  83. "You are using `method='asinq'` in the quantization config. Right now the calibrated version of SINQ"
  84. " is not supported in Hugging Face, please refer and use the official SINQ repository "
  85. "`to quantize a model with this method. "
  86. )
  87. def _build_sinq_quant_dict(self, cfg: SinqConfig) -> dict:
  88. """
  89. Build the dict that SINQLinear expects as quant_config.
  90. """
  91. from sinq.sinqlinear_hf import sinq_base_quant_config as sinq_base_quant_config_fn
  92. method = cfg.method
  93. return sinq_base_quant_config_fn(
  94. nbits=int(cfg.nbits),
  95. group_size=int(cfg.group_size) if cfg.group_size is not None else None,
  96. quant_zero=False,
  97. quant_scale=False,
  98. view_as_float=False,
  99. axis=1,
  100. tiling_mode=str(cfg.tiling_mode),
  101. method=method,
  102. )
  103. def param_needs_quantization(self, model: PreTrainedModel, param_name: str, **kwargs) -> bool:
  104. """
  105. Called per-parameter to decide whether to run `SinqQuantize` on it.
  106. - If `self.pre_quantized`, we do *not* quantize again (handled by SinqDeserialize instead).
  107. - For method="asinq": return False (ASINQ is not supported in Hugging Face).
  108. - For method="sinq": True only for SINQLinear.weight not in modules_to_not_convert.
  109. Note: After _process_model_before_weight_loading(), the modules are already SINQLinear,
  110. not nn.Linear. We check for SINQLinear modules that are not yet quantized (ready=False).
  111. """
  112. from sinq.sinqlinear_hf import SINQLinear
  113. if self.pre_quantized:
  114. return False
  115. if self.quantization_config.method == "asinq":
  116. return False
  117. # SINQ param-level only if deemed safe
  118. if not self._do_param_level_sinq:
  119. return False
  120. module, tensor_name = get_module_from_name(model, param_name)
  121. if tensor_name != "weight":
  122. return False
  123. # Check if it's an unquantized SINQLinear
  124. is_sinq = isinstance(module, SINQLinear)
  125. is_ready = getattr(module, "ready", True)
  126. result = is_sinq and not is_ready
  127. return result
  128. def get_quantize_ops(self):
  129. """
  130. Return the ConversionOps used for param-level quantization (Sinq).
  131. The actual SINQLinear construction is in integrations/sinq.py.
  132. """
  133. from ..integrations.sinq import SinqQuantize
  134. return SinqQuantize(self)
  135. def get_weight_conversions(self):
  136. """
  137. If `pre_quantized=True`, interpret a checkpoint produced by SINQLinear.state_dict:
  138. <prefix>.W_q
  139. <prefix>.bias
  140. <prefix>.meta
  141. via a WeightConverter + SinqDeserialize so that we reconstruct a SINQLinear
  142. module instead of a plain nn.Linear.
  143. """
  144. from ..core_model_loading import WeightConverter
  145. if self.pre_quantized:
  146. from ..integrations.sinq import SinqDeserialize
  147. return [
  148. WeightConverter(
  149. source_patterns=[
  150. ".W_q",
  151. ".meta",
  152. ".bias",
  153. ],
  154. target_patterns=[".weight"],
  155. operations=[SinqDeserialize(self)],
  156. )
  157. ]
  158. return []
  159. def _process_model_before_weight_loading(
  160. self,
  161. model: PreTrainedModel,
  162. device_map,
  163. keep_in_fp32_modules: list[str] | None = None,
  164. **kwargs,
  165. ):
  166. """
  167. Called on meta-initialized model, before loading any weights.
  168. For SINQ, we replace nn.Linear modules with empty SINQLinear modules here.
  169. The actual quantization happens later in SinqQuantize.convert() when weights are loaded.
  170. """
  171. from ..integrations.sinq import replace_with_sinq_linear
  172. self.modules_to_not_convert = self.get_modules_to_not_convert(
  173. model, (self.quantization_config.modules_to_not_convert or []), keep_in_fp32_modules
  174. )
  175. # Enable param-level quantization for SINQ method
  176. self._do_param_level_sinq = self.quantization_config.method == "sinq" and not self.pre_quantized
  177. sinq_quant_dict = None if self.pre_quantized else self._build_sinq_quant_dict(self.quantization_config)
  178. # Extract device from device_map (guaranteed to be set by update_device_map)
  179. if isinstance(device_map, dict):
  180. first_device = next(iter(device_map.values()), 0)
  181. if isinstance(first_device, int):
  182. device_str = f"cuda:{first_device}"
  183. else:
  184. device_str = str(first_device)
  185. else:
  186. device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
  187. model = replace_with_sinq_linear(
  188. model,
  189. modules_to_not_convert=self.modules_to_not_convert,
  190. quant_config=sinq_quant_dict,
  191. compute_dtype=self.dtype,
  192. device=device_str,
  193. pre_quantized=self.pre_quantized,
  194. )
  195. def _process_model_after_weight_loading(
  196. self,
  197. model: PreTrainedModel,
  198. **kwargs,
  199. ):
  200. """
  201. Called after *all* weights have been loaded.
  202. For SINQ:
  203. 1. Move non-SINQLinear modules to GPU (embeddings, norms, lm_head, etc.)
  204. - SINQLinear modules already have GemLite buffers on GPU
  205. - We skip moving SINQLinear's W_q/meta to avoid memory duplication
  206. 2. Patch HF save/load methods for SINQ serialization
  207. """
  208. from sinq.hf_io import patch_hf_pretrained_io
  209. # Patch HF save/load methods for SINQ serialization
  210. patch_hf_pretrained_io()
  211. return model