sinq.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import annotations
  15. from typing import Any
  16. from transformers.utils import is_torch_available, logging
  17. from ..core_model_loading import ConversionOps
  18. from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
  19. logger = logging.get_logger(__name__)
  20. if is_torch_available():
  21. import torch
  22. import torch.nn as nn
  23. def replace_with_sinq_linear(
  24. model: torch.nn.Module,
  25. modules_to_not_convert: list[str] | None = None,
  26. quant_config: dict | None = None,
  27. compute_dtype: torch.dtype = None,
  28. device: str = "cuda:0",
  29. pre_quantized: bool = False,
  30. ) -> torch.nn.Module:
  31. """
  32. Replace nn.Linear modules with empty SINQLinear modules.
  33. Args:
  34. model: The model to modify
  35. modules_to_not_convert: List of module names to skip
  36. quant_config: SINQ quantization config dict (None for pre-quantized models)
  37. compute_dtype: Computation dtype for the quantized layers
  38. device: Device string for the quantized layers
  39. pre_quantized: Whether loading a pre-quantized checkpoint
  40. Returns:
  41. The modified model with SINQLinear modules
  42. """
  43. from sinq.sinqlinear_hf import SINQLinear
  44. if modules_to_not_convert is None:
  45. modules_to_not_convert = []
  46. for full_name, module in list(model.named_modules()):
  47. if not isinstance(module, nn.Linear):
  48. continue
  49. if not should_convert_module(full_name, modules_to_not_convert):
  50. continue
  51. parent_path, _, child_name = full_name.rpartition(".")
  52. parent = model.get_submodule(parent_path) if parent_path else model
  53. sinq_layer = SINQLinear(
  54. in_features=module.in_features if not pre_quantized else None,
  55. out_features=module.out_features if not pre_quantized else None,
  56. bias=(module.bias is not None) if not pre_quantized else False,
  57. quant_config=quant_config,
  58. compute_dtype=compute_dtype,
  59. device=device,
  60. use_unpack_kernel=True,
  61. )
  62. setattr(parent, child_name, sinq_layer)
  63. return model
  64. class SinqQuantize(ConversionOps):
  65. """
  66. Param-level ConversionOp for SINQ (from FP weights).
  67. At load time, for each `Linear.weight` that should be quantized:
  68. - The SINQLinear module already exists (created in _process_model_before_weight_loading)
  69. - We just call quantize() on it with the loaded weight tensor
  70. """
  71. def __init__(self, hf_quantizer):
  72. self.hf_quantizer = hf_quantizer
  73. def convert(
  74. self,
  75. input_dict: dict[str, Any],
  76. model: torch.nn.Module | None = None,
  77. full_layer_name: str | None = None,
  78. missing_keys=None,
  79. **kwargs,
  80. ) -> dict[str, torch.Tensor]:
  81. _, values = next(iter(input_dict.items()))
  82. weight_tensor = values[0] if isinstance(values, list) else values
  83. module, tensor_name = get_module_from_name(model, full_layer_name)
  84. module.quantize(weight_tensor)
  85. if missing_keys is not None:
  86. missing_keys.discard(full_layer_name)
  87. module._is_hf_initialized = True
  88. return {}
  89. class SinqDeserialize(ConversionOps):
  90. """
  91. ConversionOp for loading *pre-quantized* SINQ checkpoints.
  92. Checkpoint layout (what `SINQLinear.state_dict` produces) is, per module:
  93. <prefix>.W_q
  94. <prefix>.bias
  95. <prefix>.meta
  96. WeightConverter in the quantizer is configured so that:
  97. - we group ".W_q", ".meta", ".bias" as input_dict
  98. - conceptually treat them as belonging to "<prefix>.weight"
  99. - and call this SinqDeserialize.convert to load the state into the existing SINQLinear.
  100. The returned dict is {} because we load directly into the module.
  101. """
  102. def __init__(self, hf_quantizer):
  103. self.hf_quantizer = hf_quantizer
  104. def convert(
  105. self,
  106. input_dict: dict[str, Any],
  107. model: torch.nn.Module | None = None,
  108. full_layer_name: str | None = None,
  109. **kwargs,
  110. ) -> dict[str, torch.Tensor]:
  111. for k, v in list(input_dict.items()):
  112. if isinstance(v, list):
  113. input_dict[k] = v[0]
  114. W_q = input_dict.get(".W_q")
  115. meta = input_dict.get(".meta")
  116. bias = input_dict.get(".bias")
  117. # Fallback path: if W_q or meta is missing, this is not a valid SINQ checkpoint.
  118. # Return the tensor as-is so standard HF weight loading can handle it.
  119. if W_q is None or meta is None:
  120. v = next(iter(input_dict.values()))
  121. if isinstance(v, list):
  122. v = v[0]
  123. return {full_layer_name: v}
  124. module, _ = get_module_from_name(model, full_layer_name)
  125. state = {
  126. "W_q": W_q,
  127. "meta": meta,
  128. }
  129. if bias is not None:
  130. state["bias"] = bias
  131. module.load_state_dict(state)
  132. module._is_hf_initialized = True
  133. return {}