quantizer_hqq.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING
  15. from ..integrations import prepare_for_hqq_linear
  16. from ..utils import is_hqq_available, is_torch_available, logging
  17. from .base import HfQuantizer
  18. from .quantizers_utils import get_module_from_name
  19. if TYPE_CHECKING:
  20. from ..modeling_utils import PreTrainedModel
  21. from ..utils.quantization_config import HqqConfig
  22. if is_torch_available():
  23. import torch
  24. if is_hqq_available():
  25. from hqq.core.quantize import HQQLinear
  26. # This is a compatibility hack. HQQ-quantized linear layers do not have a `weight` attribute,
  27. # but some models attempt to access `weight.dtype` during the forward pass. To prevent runtime errors,
  28. # we patch HQQLinear with a dummy `weight` property that returns an empty tensor with the correct dtype and device.
  29. @property
  30. def weight(self):
  31. return torch.empty(0, dtype=self.compute_dtype, device=self.device)
  32. HQQLinear.weight = weight
  33. logger = logging.get_logger(__name__)
  34. class HqqHfQuantizer(HfQuantizer):
  35. """
  36. HQQ quantizer base HF class.
  37. nn.Linear modules are first tagged with quant_config in _process_model_before_weight_loading().
  38. """
  39. requires_calibration = False
  40. quantization_config: "HqqConfig"
  41. def __init__(self, quantization_config, **kwargs):
  42. if not is_hqq_available():
  43. raise ImportError(
  44. "A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
  45. )
  46. super().__init__(quantization_config, **kwargs)
  47. self.dtype = None
  48. self.using_multi_gpu = False
  49. # Keys that are serialized specifically by hqq
  50. self.hqq_keys = HQQLinear(None, None).state_dict_keys() - {"bias"}
  51. def validate_environment(self, *args, **kwargs):
  52. if self.dtype is None:
  53. if "dtype" in kwargs:
  54. self.dtype = kwargs["dtype"]
  55. else:
  56. self.dtype = torch.float32
  57. logger.info("Setting dtype to torch.float32 as the default value since it was not specified.")
  58. device_map = kwargs.get("device_map")
  59. if isinstance(device_map, dict):
  60. if "cpu" in device_map.values() or "disk" in device_map.values():
  61. raise ValueError(
  62. "You are attempting to use an HQQ model with a device_map that contains a CPU or disk device."
  63. " This is not supported. Please remove the CPU or disk device from the device_map."
  64. )
  65. else:
  66. self.using_multi_gpu = len(set(device_map.values())) > 1
  67. # TODO: to remove
  68. # Kept here in case we see some interest in adding support for it
  69. # # Adds missing keys for HQQLinear modules that are loaded but the model with initialized with torch.nn.Linear
  70. # def update_expected_keys(
  71. # self, model: "PreTrainedModel", expected_keys: list[str], loaded_keys: list[str]
  72. # ) -> list[str]:
  73. # if not self.pre_quantized:
  74. # return expected_keys
  75. # # Collects all quantizable (linear) layers
  76. # def _find_hqq_quantizable_layers(model, layers):
  77. # for name, module in model.named_children():
  78. # if isinstance(module, (torch.nn.Linear)):
  79. # layers.add(module.name)
  80. # _find_hqq_quantizable_layers(module, layers)
  81. # new_keys = set(expected_keys)
  82. # # Name modules
  83. # for name, module in model.named_modules():
  84. # module.name = name
  85. # # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
  86. # _valid_modules = set()
  87. # _find_hqq_quantizable_layers(model, _valid_modules)
  88. # # Remove skipped modules
  89. # _skipped_modules = set()
  90. # for _module in _valid_modules:
  91. # for _skip_module in model.config.quantization_config["skip_modules"]:
  92. # if _skip_module in _module:
  93. # _skipped_modules.add(_module)
  94. # _valid_modules -= _skipped_modules
  95. # # Append new expected layers based on _ref_keys
  96. # _ref_keys = HQQLinear(
  97. # linear_layer=None,
  98. # quant_config=None,
  99. # compute_dtype=torch.float16,
  100. # device="cpu",
  101. # del_orig=False,
  102. # ).state_dict_keys() - {"bias"}
  103. # # Clean-up
  104. # _rm_keys = set()
  105. # for key in new_keys:
  106. # if any(_module in key for _module in _valid_modules):
  107. # _rm_keys.add(key)
  108. # new_keys -= _rm_keys
  109. # # At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
  110. # # Re-populate Linear/HQQLinear
  111. # for _module in _valid_modules:
  112. # if _module + ".weight" in loaded_keys:
  113. # new_keys.add(_module + ".weight")
  114. # else:
  115. # new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
  116. # if _module + ".bias" in loaded_keys:
  117. # new_keys.add(_module + ".bias")
  118. # return list(new_keys)
  119. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  120. module, _ = get_module_from_name(model, param_name)
  121. # Since we do not prepare the modules in advance, we need every param of the Linear layer to go through
  122. # `create_quantized_param`, even when `self.is_quantized == True`
  123. return isinstance(module, torch.nn.Linear)
  124. # TODO: to remove
  125. # def create_quantized_param(
  126. # self,
  127. # model: "PreTrainedModel",
  128. # param_value: "torch.Tensor",
  129. # param_name: str,
  130. # target_device: "torch.device",
  131. # **kwargs,
  132. # ):
  133. # module, tensor_name = get_module_from_name(model, param_name)
  134. # module_name = param_name.rsplit(".", 1)[0]
  135. # parent_module, node = get_module_from_name(model, module_name)
  136. # quant_config = model.config.quantization_config["quant_config"]
  137. # skip_modules = model.config.quantization_config["skip_modules"]
  138. # # In this case we do not quantize this layer (it's explicitly skipped) -> simply load param
  139. # if any(skip_module in module.name for skip_module in skip_modules):
  140. # module.load_state_dict(
  141. # {tensor_name: param_value.to(device=target_device, dtype=self.dtype)}, strict=False, assign=True
  142. # )
  143. # return
  144. # # We need this hack as the model is not pre-prepared as an empty skeleton on meta device
  145. # if self.pre_quantized:
  146. # # Save them for later
  147. # if not hasattr(self, "hqq_params"):
  148. # self.hqq_params = defaultdict(dict)
  149. # self.hqq_params[module_name].update({tensor_name: param_value})
  150. # hqq_params = self.hqq_params[module_name]
  151. # # If they are all present and saved, make it a HQQLinear layer! (we cannot do it param after param because
  152. # # hqq does not support it...)
  153. # if all(k in hqq_params for k in self.hqq_keys) and ("bias" in hqq_params or module.bias is None):
  154. # hqq_layer = HQQLinear(
  155. # linear_layer=None,
  156. # quant_config=None,
  157. # compute_dtype=self.dtype,
  158. # device=target_device,
  159. # del_orig=False,
  160. # )
  161. # hqq_layer.load_state_dict(hqq_params)
  162. # if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
  163. # hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
  164. # if self.using_multi_gpu:
  165. # hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
  166. # setattr(parent_module, node, hqq_layer)
  167. # del self.hqq_params[module_name], module
  168. # return
  169. # # Load param in the module (without caring about device or dtype, it will be changed later)
  170. # module.load_state_dict({tensor_name: param_value}, strict=False, assign=True)
  171. # # If both the weight and bias have already been loaded, time to quantize!
  172. # module_is_ready = module.weight.device.type != "meta" and (
  173. # module.bias is None or module.bias.device.type != "meta"
  174. # )
  175. # if module_is_ready:
  176. # module_tag = ".".join(module.name.split(".")[-2:])
  177. # if "weight_quant_params" in quant_config:
  178. # module_quant_config = quant_config
  179. # elif module_tag in quant_config:
  180. # module_quant_config = quant_config[module_tag]
  181. # hqq_layer = HQQLinear(
  182. # module,
  183. # quant_config=module_quant_config,
  184. # compute_dtype=self.dtype,
  185. # device=target_device,
  186. # del_orig=True,
  187. # )
  188. # if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
  189. # hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
  190. # if self.using_multi_gpu:
  191. # hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
  192. # setattr(parent_module, node, hqq_layer)
  193. def _patch_layer_for_multigpu(self, hqq_layer):
  194. def forward_with_device(self, x):
  195. out = torch.matmul(x.to(self.device), self.dequantize().t())
  196. if self.bias is not None:
  197. out += self.bias
  198. return out
  199. hqq_layer.forward = lambda x: forward_with_device(hqq_layer, x)
  200. return hqq_layer
  201. def _process_model_before_weight_loading(
  202. self,
  203. model: "PreTrainedModel",
  204. **kwargs,
  205. ):
  206. # Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param().
  207. # prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config)
  208. model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)
  209. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  210. setattr(model, "is_hqq_quantized", True)
  211. setattr(model, "is_hqq_serializable", self.is_serializable())
  212. return model
  213. def is_serializable(self):
  214. return True
  215. @property
  216. def is_trainable(self) -> bool:
  217. return True