base.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from abc import ABC, abstractmethod
  15. from typing import TYPE_CHECKING, Any
  16. from ..utils import is_torch_available, logging
  17. from ..utils.quantization_config import QuantizationConfigMixin, QuantizationMethod
  18. from .quantizers_utils import get_module_from_name
  19. if TYPE_CHECKING:
  20. from torch.nn import ModuleList
  21. from ..modeling_utils import PreTrainedModel
  22. if is_torch_available():
  23. import torch
  24. if not TYPE_CHECKING:
  25. from torch.nn import ModuleList
  26. else:
  27. ModuleList = str
  28. logger = logging.get_logger(__file__)
  29. def get_keys_to_not_convert(model) -> list:
  30. r"""
  31. Function to automatically detect keys to not convert for usage like quantization. For example for CausalLM modules
  32. we may want to keep the lm_head in full precision for numerical stability reasons.
  33. """
  34. # remove tied weights
  35. tied_keys = set()
  36. if len(model.all_tied_weights_keys) > 0:
  37. tied_keys = set(model.all_tied_weights_keys.values()) | set(model.all_tied_weights_keys.keys())
  38. # remove last module
  39. last_module_key = {list(model.named_parameters())[-1][0]}
  40. # remove output emb
  41. output_emb_module = model.get_output_embeddings()
  42. output_emb_keys = {
  43. name
  44. for name, module in model.named_modules()
  45. if output_emb_module is not None and id(module) == id(output_emb_module)
  46. }
  47. modules_to_not_convert = tied_keys | last_module_key | output_emb_keys
  48. modules_to_not_convert = list({k.removesuffix(".weight") for k in modules_to_not_convert})
  49. return list(modules_to_not_convert)
  50. def _assign_is_quantized(model):
  51. from ..modeling_utils import PreTrainedModel
  52. for module in model.modules():
  53. if isinstance(module, PreTrainedModel):
  54. module.config._is_quantized = True
  55. class HfQuantizer(ABC):
  56. """
  57. Abstract class of the HuggingFace quantizer. Supports for now quantizing HF transformers models for inference and/or quantization.
  58. This class is used only for transformers.PreTrainedModel.from_pretrained and cannot be easily used outside the scope of that method
  59. yet.
  60. Attributes
  61. quantization_config (`transformers.utils.quantization_config.QuantizationConfigMixin`):
  62. The quantization config that defines the quantization parameters of your model that you want to quantize.
  63. requires_calibration (`bool`):
  64. Whether the quantization method requires to calibrate the model before using it.
  65. """
  66. requires_calibration = False
  67. def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
  68. self.quantization_config = quantization_config
  69. self.pre_quantized = kwargs.pop("pre_quantized", True)
  70. if not self.pre_quantized and self.requires_calibration:
  71. raise ValueError(
  72. f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized."
  73. f" You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. Make sure to "
  74. f"pass `pre_quantized=True` while knowing what you are doing."
  75. )
  76. def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
  77. """
  78. Some quantization methods require to explicitly set the dtype of the model to a
  79. target dtype. You need to override this method in case you want to make sure that behavior is
  80. preserved
  81. Args:
  82. dtype (`torch.dtype`):
  83. The input dtype that is passed in `from_pretrained`
  84. """
  85. return dtype
  86. def update_device_map(self, device_map: dict[str, Any] | None) -> dict[str, Any] | None:
  87. """
  88. Override this method if you want to pass a override the existing device map with a new
  89. one. E.g. for bitsandbytes, since `accelerate` is a hard requirement, if no device_map is
  90. passed, the device_map is set to `"auto"``
  91. Args:
  92. device_map (`Union[dict, str]`, *optional*):
  93. The device_map that is passed through the `from_pretrained` method.
  94. """
  95. return device_map
  96. def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
  97. return param.element_size()
  98. def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
  99. """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
  100. return max_memory
  101. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  102. """
  103. Check whether a given param needs to be quantized.
  104. """
  105. return False
  106. def validate_environment(self, *args, **kwargs):
  107. """
  108. This method is used to potentially check for potential conflicts with arguments that are
  109. passed in `from_pretrained`. You need to define it for all future quantizers that are integrated with transformers.
  110. If no explicit check are needed, simply return nothing.
  111. """
  112. return
  113. def update_tp_plan(self, config):
  114. "updates the tp plan for the scales"
  115. return config
  116. def update_ep_plan(self, config):
  117. "updates the tp plan for the scales"
  118. return config
  119. def _process_model_before_weight_loading(self, model, **kwargs):
  120. return model
  121. def preprocess_model(self, model: "PreTrainedModel", dtype=None, **kwargs):
  122. """
  123. Setting model attributes and/or converting model before weights loading. At this point
  124. the model should be initialized on the meta device so you can freely manipulate the skeleton
  125. of the model in order to replace modules in-place. Make sure to override the abstract method `_process_model_before_weight_loading`.
  126. Args:
  127. model (`~transformers.PreTrainedModel`):
  128. The model to quantize
  129. kwargs (`dict`, *optional*):
  130. The keyword arguments that are passed along `_process_model_before_weight_loading`.
  131. """
  132. setattr(model, "is_quantized", True)
  133. setattr(model, "quantization_method", self.quantization_config.quant_method)
  134. if self.pre_quantized:
  135. self._convert_model_for_quantization(model)
  136. self._process_model_before_weight_loading(model, **kwargs)
  137. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  138. return model
  139. def postprocess_model(self, model: "PreTrainedModel", **kwargs):
  140. """
  141. Post-process the model post weights loading.
  142. Make sure to override the abstract method `_process_model_after_weight_loading`.
  143. Args:
  144. model (`~transformers.PreTrainedModel`):
  145. The model to quantize
  146. kwargs (`dict`, *optional*):
  147. The keyword arguments that are passed along `_process_model_after_weight_loading`.
  148. """
  149. model.config.quantization_config = self.quantization_config
  150. if self.pre_quantized and getattr(self.quantization_config, "dequantize", False):
  151. self.remove_quantization_config(model)
  152. else:
  153. _assign_is_quantized(model)
  154. return self._process_model_after_weight_loading(model, **kwargs)
  155. def remove_quantization_config(self, model):
  156. """
  157. Remove the quantization config from the model.
  158. """
  159. if hasattr(model, "hf_quantizer"):
  160. del model.hf_quantizer
  161. if hasattr(model.config, "quantization_config"):
  162. del model.config.quantization_config
  163. if hasattr(model, "quantization_method"):
  164. del model.quantization_method
  165. model.is_quantized = False
  166. def dequantize(self, model, dtype=None):
  167. """
  168. Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance.
  169. Note not all quantization schemes support this.
  170. """
  171. if dtype is None:
  172. # using the same dtype we used to load the model. If we don't do that, we might have issues with modules we didn't quantize.
  173. # or we need to upcast everything to the same dtype
  174. dtype = model.config.dtype
  175. model = self._dequantize(model, dtype=dtype)
  176. self.remove_quantization_config(model)
  177. return model
  178. def _dequantize(self, model, dtype=None):
  179. raise NotImplementedError(
  180. f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
  181. )
  182. def get_param_name(self, param_name: str) -> str:
  183. """
  184. Override this method if you want to adjust the `param_name`.
  185. """
  186. return param_name
  187. @staticmethod
  188. def get_modules_to_not_convert(
  189. model: "PreTrainedModel",
  190. skip_modules: list[str] | None = None,
  191. keep_in_fp32_modules: list[str] | None = None,
  192. add_default_skips: bool = False,
  193. ):
  194. if skip_modules is None or add_default_skips:
  195. modules_to_not_convert = get_keys_to_not_convert(model)
  196. else:
  197. modules_to_not_convert = []
  198. if skip_modules is not None:
  199. modules_to_not_convert.extend(skip_modules)
  200. if keep_in_fp32_modules is not None:
  201. modules_to_not_convert.extend(keep_in_fp32_modules)
  202. modules_to_not_convert = list(set(modules_to_not_convert))
  203. return modules_to_not_convert
  204. @property
  205. def is_qat_trainable(self) -> bool:
  206. """Flag indicating whether the quantized model can carry out quantization aware training"""
  207. return False
  208. @property
  209. def is_compileable(self) -> bool:
  210. """Flag indicating whether the quantized model can be compiled"""
  211. return False
  212. def get_state_dict_and_metadata(self, model):
  213. """Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization"""
  214. return None, {}
  215. @abstractmethod
  216. def is_serializable(self): ...
  217. @property
  218. @abstractmethod
  219. def is_trainable(self): ...
  220. def _convert_model_for_quantization(self, model):
  221. for name, module in model.named_modules():
  222. module_class_name = module.__class__.__name__
  223. if module_class_name in MODULES_TO_PATCH_FOR_QUANTIZATION and (
  224. self.quantization_config.quant_method
  225. in MODULES_TO_PATCH_FOR_QUANTIZATION[module_class_name]["quantization_methods"]
  226. ):
  227. with torch.device("meta"):
  228. parent_module, name = get_module_from_name(model, name)
  229. parent_module._modules[name] = MODULES_TO_PATCH_FOR_QUANTIZATION[module_class_name]["module_name"](
  230. model.config.get_text_config()
  231. )
  232. def get_quantize_ops(self):
  233. raise NotImplementedError(
  234. f"{self.quantization_config.quant_method} is not available yet and will be supported soon."
  235. )
  236. def get_weight_conversions(self):
  237. return []
  238. class SequentialLlama4TextExperts(ModuleList):
  239. """
  240. A module that implements a compressed version of a list of expert modules.
  241. This is specifically designed to work with Llama4TextExperts in MoE layers.
  242. """
  243. def __init__(self, config):
  244. from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
  245. super().__init__([Llama4TextMLP(config) for _ in range(config.num_local_experts)])
  246. self.num_experts = config.num_local_experts
  247. def forward(
  248. self,
  249. hidden_states: "torch.Tensor",
  250. ) -> "torch.Tensor":
  251. hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1])
  252. routed_out = torch.zeros_like(hidden_states)
  253. for expert_idx in range(self.num_experts):
  254. routed_out[expert_idx] = self[expert_idx](hidden_states[expert_idx])
  255. return routed_out
  256. MODULES_TO_PATCH_FOR_QUANTIZATION = {
  257. "Llama4TextExperts": {
  258. "module_name": SequentialLlama4TextExperts,
  259. "quantization_methods": [
  260. QuantizationMethod.COMPRESSED_TENSORS,
  261. QuantizationMethod.BITS_AND_BYTES,
  262. ],
  263. }
  264. }