quanto.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ..core_model_loading import ConversionOps
  15. from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
  16. from ..utils import is_torch_available, logging
  17. if is_torch_available():
  18. import torch
  19. import torch.nn as nn
  20. logger = logging.get_logger(__name__)
  21. class QuantoQuantize(ConversionOps):
  22. def __init__(self, hf_quantizer):
  23. self.hf_quantizer = hf_quantizer
  24. def convert(
  25. self,
  26. input_dict: dict[str, list[torch.Tensor]],
  27. model: torch.nn.Module | None = None,
  28. full_layer_name: str | None = None,
  29. missing_keys: list[str] | None = None,
  30. **kwargs,
  31. ) -> dict[str, torch.Tensor]:
  32. _, value = tuple(input_dict.items())[0]
  33. value = value[0]
  34. from ..modeling_utils import _load_parameter_into_model
  35. _load_parameter_into_model(model, full_layer_name, value)
  36. module, _ = get_module_from_name(model, full_layer_name)
  37. # Need to set those to a specific value, otherwise they will remain on meta device ...
  38. module.input_scale = torch.ones(module.input_scale.shape)
  39. module.output_scale = torch.ones(module.output_scale.shape)
  40. # quantize
  41. module.freeze()
  42. module.weight.requires_grad = False
  43. module._is_hf_initialized = True
  44. # need to discard some missing keys we already updated the module in freeze.
  45. module_name = full_layer_name.rsplit(".", 1)[0]
  46. missing_keys.discard(f"{module_name}.weight")
  47. missing_keys.discard(f"{module_name}.input_scale")
  48. missing_keys.discard(f"{module_name}.output_scale")
  49. return {}
  50. def replace_with_quanto_layers(
  51. model,
  52. quantization_config=None,
  53. modules_to_not_convert: list[str] | None = None,
  54. ):
  55. """
  56. Public method that recursively replaces the Linear layers of the given model with Quanto quantized layers.
  57. Returns the converted model and a boolean that indicates if the conversion has been successful or not.
  58. Args:
  59. model (`torch.nn.Module`):
  60. The model to convert, can be any `torch.nn.Module` instance.
  61. quantization_config (`QuantoConfig`, defaults to `None`):
  62. The quantization config object that contains the quantization parameters.
  63. modules_to_not_convert (`list`, *optional*, defaults to `None`):
  64. A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
  65. converted.
  66. """
  67. from optimum.quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8
  68. w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
  69. a_mapping = {None: None, "float8": qfloat8, "int8": qint8}
  70. has_been_replaced = False
  71. for module_name, module in model.named_modules():
  72. if not should_convert_module(module_name, modules_to_not_convert):
  73. continue
  74. with torch.device("meta"):
  75. new_module = None
  76. if isinstance(module, nn.Linear):
  77. new_module = QLinear(
  78. in_features=module.in_features,
  79. out_features=module.out_features,
  80. bias=module.bias is not None,
  81. dtype=module.weight.dtype,
  82. weights=w_mapping[quantization_config.weights],
  83. activations=a_mapping[quantization_config.activations],
  84. )
  85. elif isinstance(module, torch.nn.LayerNorm) and quantization_config.activations is not None:
  86. new_module = QLayerNorm(
  87. module.normalized_shape,
  88. module.eps,
  89. module.elementwise_affine,
  90. module.bias is not None,
  91. activations=a_mapping[quantization_config.activations],
  92. )
  93. if new_module is not None:
  94. has_been_replaced = True
  95. model.set_submodule(module_name, new_module)
  96. if not has_been_replaced:
  97. logger.warning(
  98. "You are loading your model using quanto but no linear modules were found in your model."
  99. " Please double check your model architecture, or submit an issue on github if you think this is"
  100. " a bug."
  101. )
  102. return model