eetq.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright 2024 NetEase, Inc. and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ..core_model_loading import ConversionOps
  15. from ..quantizers.quantizers_utils import should_convert_module
  16. from ..utils import is_torch_available, logging
  17. if is_torch_available():
  18. import torch
  19. import torch.nn as nn
  20. logger = logging.get_logger(__name__)
  21. class EetqQuantize(ConversionOps):
  22. def __init__(self, hf_quantizer):
  23. self.hf_quantizer = hf_quantizer
  24. def convert(
  25. self, input_dict: dict[str, list[torch.Tensor]], full_layer_name: str | None = None, **kwargs
  26. ) -> dict[str, torch.Tensor]:
  27. _, value = tuple(input_dict.items())[0]
  28. value = value[0]
  29. value_device = value.device
  30. int8_weight = torch.t(value).contiguous().cpu()
  31. int8_weight, scales = eetq_kernels_hub.quant_weights(int8_weight, torch.int8, False)
  32. int8_weight = int8_weight.to(value_device)
  33. scales = scales.to(value_device)
  34. return {full_layer_name: int8_weight, f"{full_layer_name}_scales": scales}
  35. class EetqLinearMMFunction(torch.autograd.Function):
  36. @staticmethod
  37. def forward(ctx, x, weight, scales, bias=None):
  38. # The forward pass can use ctx.
  39. ctx.save_for_backward(x, weight, scales, bias)
  40. output = eetq_kernels_hub.w8_a16_gemm(x, weight, scales)
  41. output = output + bias if bias is not None else output
  42. return output
  43. @staticmethod
  44. def backward(ctx, grad_output):
  45. input, weight, scales, bias = ctx.saved_tensors
  46. identity = torch.eye(weight.shape[0]).to(weight.device).to(input.dtype)
  47. # Dequantize the weight
  48. weight = eetq_kernels_hub.w8_a16_gemm(identity, weight, scales)
  49. if ctx.needs_input_grad[0]:
  50. # 2D matrix multiplication, unsqueeze to 3D
  51. grad_input = grad_output.squeeze(0).matmul(weight.transpose(0, 1)).unsqueeze(0)
  52. return grad_input, None, None, None
  53. class EetqLinear(nn.Module):
  54. def __init__(self, in_features, out_features, dtype=torch.int8, bias=False):
  55. super().__init__()
  56. self.weight = nn.Parameter(torch.empty((in_features, out_features), dtype=dtype), requires_grad=False)
  57. self.weight_scales = nn.Parameter(torch.empty((out_features), dtype=torch.float16))
  58. if bias:
  59. self.bias = nn.Parameter(torch.empty((out_features), dtype=torch.float16))
  60. else:
  61. self.bias = None
  62. def forward(self, input):
  63. output = EetqLinearMMFunction.apply(input, self.weight, self.weight_scales, self.bias)
  64. return output
  65. def replace_with_eetq_linear(model, modules_to_not_convert: list[str] | None = None, pre_quantized=False):
  66. """
  67. A helper function to replace all `torch.nn.Linear` modules by `EetqLinear` modules.
  68. Parameters:
  69. model (`torch.nn.Module`):
  70. Input model or `torch.nn.Module` as the function is run recursively.
  71. modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`):
  72. Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
  73. for numerical stability reasons.
  74. """
  75. from .hub_kernels import get_kernel
  76. global eetq_kernels_hub
  77. eetq_kernels_hub = get_kernel("kernels-community/quantization-eetq")
  78. has_been_replaced = False
  79. # we need this to correctly materialize the weights during quantization
  80. module_kwargs = {} if pre_quantized else {"dtype": None}
  81. for module_name, module in model.named_modules():
  82. if not should_convert_module(module_name, modules_to_not_convert):
  83. continue
  84. with torch.device("meta"):
  85. if isinstance(module, nn.Linear):
  86. new_module = EetqLinear(
  87. module.in_features, module.out_features, bias=module.bias is not None, **module_kwargs
  88. )
  89. model.set_submodule(module_name, new_module)
  90. has_been_replaced = True
  91. if not has_been_replaced:
  92. logger.warning(
  93. "You are loading your model using eetq but no linear modules were found in your model."
  94. " Please double check your model architecture, or submit an issue on github if you think this is"
  95. " a bug."
  96. )
  97. return model