fbgemm_fp8.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from functools import lru_cache
  15. from ..activations import ACT2FN
  16. from ..core_model_loading import ConversionOps
  17. from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
  18. from ..utils import (
  19. is_accelerate_available,
  20. is_fbgemm_gpu_available,
  21. is_torch_available,
  22. is_torch_xpu_available,
  23. logging,
  24. )
# Optional heavy dependencies are imported only when present so this module can be imported
# without them.
if is_torch_available():
    import torch
    from torch import nn

# NOTE(review): `init_empty_weights` is used unconditionally by `replace_with_fbgemm_fp8_linear`,
# so that function will raise NameError if accelerate is not installed.
if is_accelerate_available():
    from accelerate import init_empty_weights

# Evaluated once at import time; every XPU-vs-FBGEMM branch in this module reads this cached flag.
_is_torch_xpu_available = is_torch_xpu_available()

if is_fbgemm_gpu_available() and not _is_torch_xpu_available:
    # Imported only for its side effects (the name is never used, hence `noqa: F401`);
    # presumably this registers the `torch.ops.fbgemm.*` operators called below — confirm.
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401

logger = logging.get_logger(__name__)
  34. class FbgemmFp8Quantize(ConversionOps):
  35. def __init__(self, hf_quantizer):
  36. self.hf_quantizer = hf_quantizer
  37. def convert(
  38. self,
  39. input_dict: dict[str, torch.Tensor | list[torch.Tensor]],
  40. model: torch.nn.Module | None = None,
  41. **kwargs,
  42. ) -> dict[str, torch.Tensor]:
  43. target_key, value = tuple(input_dict.items())[0]
  44. value = value[0]
  45. from ..integrations import FbgemmFp8Llama4TextExperts
  46. module, tensor_name = get_module_from_name(model, target_key)
  47. if isinstance(module, FbgemmFp8Llama4TextExperts):
  48. if tensor_name == "gate_up_proj":
  49. # Process each expert separately
  50. # Transpose the second and third dimension
  51. transposed_param = value.transpose(1, 2)
  52. # Reshape to 2D for quantization
  53. original_shape = transposed_param.shape
  54. flattened_param = transposed_param.reshape(-1, original_shape[-1])
  55. # Quantize using per row instead of per column
  56. new_value_flat, weight_scale_flat = quantize_fp8_per_row(flattened_param)
  57. # Reshape back to original dimensions
  58. new_value = new_value_flat.reshape(original_shape)
  59. new_value = new_value.transpose(1, 2)
  60. weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
  61. elif tensor_name == "down_proj":
  62. # Process each expert separately
  63. # Transpose the weights for proper quantization
  64. transposed_param = value.transpose(1, 2)
  65. # Reshape to 2D for quantization
  66. original_shape = transposed_param.shape
  67. flattened_param = transposed_param.reshape(-1, original_shape[-1])
  68. # Quantize using per column
  69. new_value_flat, weight_scale_flat = quantize_fp8_per_row(flattened_param)
  70. # Reshape back to original dimensions
  71. new_value = new_value_flat.reshape(original_shape)
  72. new_value = new_value.transpose(1, 2)
  73. weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
  74. else:
  75. new_value, weight_scale = quantize_fp8_per_row(value)
  76. weight_scale = torch.nn.Parameter(weight_scale.view(weight_scale.shape[0], 1))
  77. return {target_key: torch.nn.Parameter(new_value), f"{target_key}_scale": weight_scale}
  78. class FbgemmFp8Linear(torch.nn.Linear):
  79. def __init__(self, in_features, out_features, bias, dtype=torch.float8_e4m3fn):
  80. super().__init__(in_features, out_features, bias)
  81. self.in_features = in_features
  82. self.out_features = out_features
  83. self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=dtype))
  84. self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=torch.float32))
  85. self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
  86. if bias:
  87. self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=torch.float32))
  88. else:
  89. self.bias = None
  90. def forward(self, x):
  91. # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
  92. output_shape = (*x.shape[:-1], -1)
  93. # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
  94. # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
  95. x_quantized, x_scale = quantize_fp8_per_row(x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub)
  96. # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
  97. # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
  98. # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
  99. weight_scale_float32 = self.weight_scale.to(torch.float32)
  100. if _is_torch_xpu_available:
  101. output = torch._scaled_mm(
  102. x_quantized,
  103. self.weight.t(),
  104. scale_a=x_scale.unsqueeze(-1),
  105. scale_b=weight_scale_float32.t(),
  106. out_dtype=x.dtype,
  107. bias=self.bias,
  108. )
  109. else:
  110. output = torch.ops.fbgemm.f8f8bf16_rowwise(
  111. x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
  112. )
  113. output = output + self.bias if self.bias is not None else output
  114. # Hacky for now, we have the output to the device of x
  115. output = output.to(x.device)
  116. output = output.reshape(output_shape)
  117. del x_quantized, x_scale
  118. return output
  119. class FbgemmFp8Llama4TextExperts(nn.Module):
  120. def __init__(self, config, dtype=torch.float32):
  121. super().__init__()
  122. self.num_experts = config.num_local_experts
  123. self.intermediate_size = config.intermediate_size
  124. self.hidden_size = config.hidden_size
  125. self.expert_dim = self.intermediate_size
  126. self.act_fn = ACT2FN[config.hidden_act]
  127. # Register FP8 buffers for gate_up_proj
  128. self.gate_up_proj = torch.nn.Parameter(
  129. torch.zeros((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=torch.float8_e4m3fn)
  130. )
  131. self.gate_up_proj_scale = torch.nn.Parameter(
  132. torch.zeros((self.num_experts, 1, self.expert_dim * 2), dtype=torch.float32)
  133. )
  134. # Register FP8 buffers for down_proj
  135. self.down_proj = torch.nn.Parameter(
  136. torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e4m3fn)
  137. )
  138. self.down_proj_scale = torch.nn.Parameter(
  139. torch.zeros((self.num_experts, self.hidden_size, 1), dtype=torch.float32)
  140. )
  141. # Register input scale upper bound
  142. self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
  143. def forward(self, hidden_states):
  144. """
  145. Args:
  146. hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
  147. Returns:
  148. torch.Tensor: (batch_size * token_num, hidden_size)
  149. """
  150. # Reshape hidden states for expert computation
  151. hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
  152. num_tokens = None
  153. # Pre-allocate tensor for all expert outputs with same shape as hidden_states
  154. next_states = torch.empty_like(hidden_states)
  155. for i in range(self.num_experts):
  156. # Extract expert's hidden states
  157. expert_hidden = hidden_states[i]
  158. expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
  159. # Quantize for this expert
  160. expert_quantized, expert_scale = quantize_fp8_per_row(
  161. expert_hidden_reshaped, num_tokens, self.input_scale_ub
  162. )
  163. sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
  164. gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
  165. if _is_torch_xpu_available:
  166. gate = torch._scaled_mm(
  167. expert_quantized,
  168. self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous().t(),
  169. scale_a=expert_scale.unsqueeze(-1),
  170. scale_b=gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous().t(),
  171. out_dtype=hidden_states.dtype,
  172. )
  173. up = torch._scaled_mm(
  174. expert_quantized,
  175. self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous().t(),
  176. scale_a=expert_scale.unsqueeze(-1),
  177. scale_b=gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous().t(),
  178. out_dtype=hidden_states.dtype,
  179. )
  180. else:
  181. gate = torch.ops.fbgemm.f8f8bf16_rowwise(
  182. expert_quantized,
  183. self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
  184. expert_scale,
  185. gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
  186. use_fast_accum=True,
  187. )
  188. up = torch.ops.fbgemm.f8f8bf16_rowwise(
  189. expert_quantized,
  190. self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
  191. expert_scale,
  192. gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
  193. use_fast_accum=True,
  194. )
  195. activated = up * self.act_fn(gate)
  196. activated_quantized, activated_scale = quantize_fp8_per_row(activated, num_tokens, self.input_scale_ub)
  197. down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
  198. if _is_torch_xpu_available:
  199. expert_output = torch._scaled_mm(
  200. activated_quantized,
  201. self.down_proj[i].transpose(0, 1).contiguous(),
  202. scale_a=activated_scale.unsqueeze(-1),
  203. scale_b=down_proj_scale_float32[i].view(-1, 1).contiguous().t(),
  204. out_dtype=hidden_states.dtype,
  205. )
  206. else:
  207. expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
  208. activated_quantized,
  209. self.down_proj[i].transpose(0, 1).contiguous(),
  210. activated_scale,
  211. down_proj_scale_float32[i].view(-1, 1).contiguous(),
  212. use_fast_accum=True,
  213. )
  214. next_states[i] = expert_output
  215. next_states = next_states.to(hidden_states.device)
  216. return next_states.view(-1, self.hidden_size)
  217. @lru_cache(maxsize=1)
  218. def get_quantize_fp8_per_row():
  219. if _is_torch_xpu_available:
  220. from .hub_kernels import get_kernel
  221. return get_kernel("kernels-community/fp8-fbgemm").quantize_fp8_per_row
  222. return torch.ops.fbgemm.quantize_fp8_per_row
  223. def replace_with_fbgemm_fp8_linear(
  224. model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False, tp_plan=None
  225. ):
  226. """
  227. A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
  228. This will enable running your models using high performance fp8 kernel from FBGEMM library.
  229. Parameters:
  230. model (`torch.nn.Module`):
  231. Input model or `torch.nn.Module` as the function is run recursively.
  232. modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`):
  233. Names of the modules to not convert. In practice we keep the `lm_head` in full precision for numerical stability reasons.
  234. quantization_config (`FbgemmFp8Config`):
  235. The quantization config object that contains the quantization parameters.
  236. pre_quantized (`book`, defaults to `False`):
  237. Whether the model is pre-quantized or not
  238. """
  239. global quantize_fp8_per_row
  240. quantize_fp8_per_row = get_quantize_fp8_per_row()
  241. has_been_replaced = False
  242. module_kwargs = {} if pre_quantized else {"dtype": None}
  243. for module_name, module in model.named_modules():
  244. if not should_convert_module(module_name, modules_to_not_convert):
  245. continue
  246. new_module = None
  247. with init_empty_weights(include_buffers=True):
  248. if module.__class__.__name__ == "Llama4TextExperts":
  249. # TODO: make sure tp works later
  250. # if tp_plan is not None:
  251. # tp_key = re.sub(r"\d+", "*", f"{module_name}.down_proj_scale")
  252. # tp_plan[tp_key] = None
  253. text_config = getattr(model.config, "text_config", model.config)
  254. new_module = FbgemmFp8Llama4TextExperts(text_config or model.config)
  255. elif isinstance(module, nn.Linear):
  256. new_module = FbgemmFp8Linear(
  257. module.in_features,
  258. module.out_features,
  259. module.bias is not None,
  260. **module_kwargs,
  261. )
  262. new_module.requires_grad_(False)
  263. if new_module is None:
  264. continue
  265. model.set_submodule(module_name, new_module)
  266. has_been_replaced = True
  267. if not has_been_replaced:
  268. logger.warning(
  269. "You are loading your model using FP8 quantization but no linear modules were found in your model."
  270. " Please double check your model architecture, or submit an issue on github if you think this is"
  271. " a bug."
  272. )
  273. return model