torchao.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. import types
  16. import torch
  17. from transformers.utils import logging
  18. from transformers.utils.import_utils import is_torch_available, is_torchao_available
  19. if is_torch_available():
  20. from ..core_model_loading import ConversionOps
  21. from ..quantizers.quantizers_utils import get_module_from_name
  22. if is_torchao_available():
  23. from torchao.prototype.safetensors.safetensors_support import (
  24. unflatten_tensor_state_dict,
  25. )
  26. from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
  27. logger = logging.get_logger(__name__)
  28. def _quantization_type(weight):
  29. from torchao.dtypes import AffineQuantizedTensor
  30. from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
  31. if isinstance(weight, AffineQuantizedTensor):
  32. return f"{weight.__class__.__name__}({weight._quantization_type()})"
  33. if isinstance(weight, LinearActivationQuantizedTensor):
  34. return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
  35. def _linear_extra_repr(self):
  36. weight = _quantization_type(self.weight)
  37. if weight is None:
  38. return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
  39. else:
  40. return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
  41. class TorchAoQuantize(ConversionOps):
  42. def __init__(self, hf_quantizer):
  43. self.hf_quantizer = hf_quantizer
  44. def _quantize(self, module, config, *args, **kwargs):
  45. """Run quantize_, moving to CUDA first if CPU offloading is active.
  46. Some torchao quantization ops (e.g. int4 packing) only have CUDA kernels.
  47. When a layer is destined for CPU (e.g. CPU offloading), we temporarily move
  48. it to CUDA for quantization, then move the result back to CPU.
  49. """
  50. from torchao.quantization import quantize_
  51. target_device = next(module.parameters()).device
  52. if self.hf_quantizer.offload_to_cpu and target_device.type == "cpu":
  53. module.to("cuda")
  54. quantize_(module, config, *args, **kwargs)
  55. module.to("cpu")
  56. else:
  57. quantize_(module, config, *args, **kwargs)
  58. def convert(
  59. self,
  60. input_dict: dict[str, torch.Tensor],
  61. model: torch.nn.Module | None = None,
  62. full_layer_name: str | None = None,
  63. missing_keys=None,
  64. **kwargs,
  65. ) -> dict[str, torch.Tensor]:
  66. _, value = tuple(input_dict.items())[0]
  67. value = value[0] if isinstance(value, list) else value
  68. module, tensor_name = get_module_from_name(model, full_layer_name)
  69. module._parameters[tensor_name] = torch.nn.Parameter(value, requires_grad=value.requires_grad)
  70. # if we are quantizing tied parameters, to avoid tying the quantized weights
  71. # the correct order to do it is
  72. # 1. load the weight to model
  73. # 2. run tie_weights to populate the weights
  74. # 3. quantize
  75. input_embed = model.get_input_embeddings()
  76. is_embedding_param = id(module) == id(input_embed)
  77. untie_embedding_weights = self.hf_quantizer.quantization_config.untie_embedding_weights
  78. if untie_embedding_weights and is_embedding_param:
  79. setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
  80. from torchao.quantization import FqnToConfig
  81. config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass()
  82. if isinstance(config, FqnToConfig):
  83. module_fqn, top_level_param_name = full_layer_name.rsplit(".", 1)
  84. c = None
  85. if full_layer_name in config.fqn_to_config:
  86. assert not module_fqn.startswith("re:"), (
  87. "param fqn should not start with`re:`, which is used for specifying regex"
  88. )
  89. c = config.module_fqn_to_config[full_layer_name]
  90. elif module_fqn in config.fqn_to_config:
  91. assert not module_fqn.startswith("re:"), (
  92. "module fqn should not start with`re:`, which is used for specifying regex"
  93. )
  94. c = config.module_fqn_to_config[module_fqn]
  95. # regex match module and param
  96. else:
  97. for maybe_module_fqn_pattern in config.fqn_to_config:
  98. # if key doesn't start with re, it is an exact fqn key, so we don't regex match
  99. if not maybe_module_fqn_pattern.startswith("re:"):
  100. continue
  101. # see if param matches first
  102. elif re.fullmatch(maybe_module_fqn_pattern[3:], full_layer_name):
  103. c = config.module_fqn_to_config[maybe_module_fqn_pattern]
  104. break
  105. elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
  106. # we'll apply the config for first fully matched pattern
  107. c = config.module_fqn_to_config[maybe_module_fqn_pattern]
  108. break
  109. else:
  110. c = config.module_fqn_to_config.get("_default", None)
  111. if c is not None:
  112. if top_level_param_name == "weight":
  113. if is_embedding_param and untie_embedding_weights:
  114. lm_head = module.weight.clone()
  115. # we can apply the module config directly
  116. self._quantize(module, c, (lambda x, fqn: True))
  117. missing_keys.discard(full_layer_name)
  118. module._is_hf_initialized = True
  119. # torchao quantizes weights into a module but some models access the weight directly
  120. # (e.g. module.o_proj.weight). The _is_hf_initialized flag is set at the module
  121. # level only, so we also set it on each parameter to prevent _init_weights from
  122. # calling normal_() on already-quantized Float8Tensors.
  123. for param in module.parameters(recurse=False):
  124. param._is_hf_initialized = True
  125. return {"lm_head.weight": lm_head} if is_embedding_param and untie_embedding_weights else {}
  126. else:
  127. # need to apply to custom param name
  128. custom_param_fqn_config = FqnToConfig({top_level_param_name: c})
  129. self._quantize(module, custom_param_fqn_config, filter_fn=None)
  130. missing_keys.discard(full_layer_name)
  131. module._is_hf_initialized = True
  132. for param in module.parameters(recurse=False):
  133. param._is_hf_initialized = True
  134. return {}
  135. return {full_layer_name: value}
  136. if is_embedding_param and untie_embedding_weights:
  137. lm_head = module.weight.clone()
  138. self._quantize(module, self.hf_quantizer.quantization_config.get_apply_tensor_subclass())
  139. missing_keys.discard(full_layer_name)
  140. module._is_hf_initialized = True
  141. for param in module.parameters(recurse=False):
  142. param._is_hf_initialized = True
  143. return {"lm_head.weight": lm_head} if is_embedding_param and untie_embedding_weights else {}
  144. class TorchAoDeserialize(ConversionOps):
  145. def __init__(self, hf_quantizer):
  146. self.hf_quantizer = hf_quantizer
  147. def convert(
  148. self,
  149. input_dict: dict[str, torch.Tensor],
  150. source_patterns: list[str] | None = None,
  151. model: torch.nn.Module | None = None,
  152. full_layer_name: str | None = None,
  153. missing_keys=None,
  154. **kwargs,
  155. ) -> dict[str, torch.Tensor]:
  156. """
  157. Consolidates tensor subclass components before reconstructing the object
  158. For example:
  159. input_dict: {
  160. "_weight_qdata": torch.Tensor,
  161. "_weight_scale": torch.Tensor,
  162. }
  163. full_layer_name: "model.layers.0.self_attn.k_proj.weight"
  164. Given this, we reconstruct a Float8Tensor instance using the qdata and scale
  165. and return it as a dictionary with the full_layer_name as the key and the recovered
  166. Float8Tensor instance as the value.
  167. """
  168. is_unsafe_serialization = list(input_dict.keys())[0] not in source_patterns
  169. param_data = {}
  170. layer_name = ".".join(full_layer_name.split(".")[:-1])
  171. if is_unsafe_serialization:
  172. if isinstance(input_dict["weight"], list):
  173. weight = input_dict["weight"][0]
  174. else:
  175. weight = input_dict["weight"]
  176. else:
  177. for suffix in input_dict.keys():
  178. if len(input_dict[suffix]) != 1:
  179. raise ValueError(
  180. f"Expected a single tensor for {suffix} but got {len(input_dict[suffix])} tensors instead"
  181. )
  182. param_data[f"{layer_name}.{suffix}"] = input_dict[suffix][0]
  183. # If it's unsafe-serialized (i.e. not safetensors), no need for anything
  184. if is_unsafe_serialization:
  185. return {full_layer_name: weight}
  186. elif not is_metadata_torchao(self.hf_quantizer.metadata):
  187. raise ValueError("Invalid torchao safetensors metadata")
  188. unflattened_state_dict, leftover_state_dict = unflatten_tensor_state_dict(
  189. param_data, self.hf_quantizer.metadata
  190. )
  191. assert not leftover_state_dict # there should be no unprocessed tensors
  192. new_param = unflattened_state_dict[full_layer_name]
  193. module, _ = get_module_from_name(model, full_layer_name)
  194. # Add repr to the module
  195. if isinstance(module, torch.nn.Linear):
  196. module.extra_repr = types.MethodType(_linear_extra_repr, module)
  197. return {full_layer_name: new_param}