finegrained_fp8.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import torch.nn as nn
  16. import triton
  17. from torch.nn import functional as F
  18. from ..activations import ACT2FN
  19. from ..core_model_loading import ConversionOps, _IdentityOp
  20. from ..quantizers.quantizers_utils import should_convert_module
  21. from ..utils import logging
  22. from ..utils.import_utils import get_cuda_runtime_version, resolve_internal_import
  23. from .hub_kernels import lazy_load_kernel
  24. from .moe import ExpertsInterface, use_experts_implementation
  25. logger = logging.get_logger(__name__)
  26. _FP8_DTYPE = torch.float8_e4m3fn
  27. _FP8_MIN = torch.finfo(_FP8_DTYPE).min
  28. _FP8_MAX = torch.finfo(_FP8_DTYPE).max
  29. # DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM
  30. # TMA is an H100 hardware addition that allows applications to asynchronously and
  31. # bi-directionally transfer 1D-5D tensors between GPU global and shared memory
  32. _DEEPGEMM_M_ALIGNMENT = 128
  33. # Lazily-loaded finegrained-fp8 Triton kernel functions (populated by _load_triton_kernel)
  34. triton_fp8_matmul = None
  35. triton_fp8_act_quant = None
  36. triton_batched_fp8_matmul = None
  37. triton_grouped_fp8_matmul = None
  38. # _triton_available: None = not yet attempted, True = loaded, False = failed (won't retry)
  39. _triton_available = None
  40. # Lazily-loaded DeepGEMM kernel functions (populated by _load_deepgemm_kernel)
  41. deepgemm_fp8_matmul = None
  42. deepgemm_grouped_fp8_matmul = None
  43. deepgemm_per_token_cast_to_fp8 = None
  44. # _deepgemm_available: None = not yet attempted, True = loaded, False = failed (won't retry)
  45. _deepgemm_available = None
  46. def _load_triton_kernel():
  47. """Lazily load the finegrained-fp8 Triton kernel and extract functions.
  48. Uses the hub kernels lazy loading pattern. Raises an error if the kernel
  49. cannot be loaded or required functions are missing. Only attempts loading once.
  50. """
  51. global \
  52. _triton_available, \
  53. triton_fp8_act_quant, \
  54. triton_fp8_matmul, \
  55. triton_batched_fp8_matmul, \
  56. triton_grouped_fp8_matmul
  57. if _triton_available is not None:
  58. if not _triton_available:
  59. raise ImportError("finegrained-fp8 kernel is not available (previous load attempt failed).")
  60. return
  61. _triton_available = False # mark attempted before any early exit
  62. kernel = lazy_load_kernel("finegrained-fp8")
  63. triton_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul")
  64. triton_fp8_act_quant = getattr(kernel, "fp8_act_quant")
  65. triton_batched_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_batched")
  66. triton_grouped_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_grouped")
  67. missing = [
  68. name
  69. for name, attr in [
  70. ("w8a8_fp8_matmul", triton_fp8_matmul),
  71. ("fp8_act_quant", triton_fp8_act_quant),
  72. ("w8a8_fp8_matmul_batched", triton_batched_fp8_matmul),
  73. ("w8a8_fp8_matmul_grouped", triton_grouped_fp8_matmul),
  74. ]
  75. if attr is None
  76. ]
  77. if missing:
  78. raise ImportError(
  79. f"finegrained-fp8 kernel is missing required functions: {', '.join(missing)}. "
  80. "Please update the `kernels` package (`pip install -U kernels`)."
  81. )
  82. _triton_available = True
  83. def _load_deepgemm_kernel():
  84. """Lazily load the DeepGEMM kernel and extract functions with proper names.
  85. Uses the hub kernels lazy loading pattern. Raises an error if the kernel
  86. cannot be loaded, required functions are missing, or the hardware is insufficient.
  87. Only attempts loading once.
  88. """
  89. global _deepgemm_available, deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8
  90. if _deepgemm_available is not None:
  91. if not _deepgemm_available:
  92. raise ImportError("DeepGEMM kernel is not available (previous load attempt failed).")
  93. return
  94. _deepgemm_available = False # mark attempted before any early exit
  95. # DeepGEMM requires CUDA and a compatible GPU
  96. if not torch.cuda.is_available():
  97. raise ImportError(
  98. "DeepGEMM kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`."
  99. )
  100. # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions
  101. major = torch.cuda.get_device_capability()[0]
  102. if major < 9:
  103. raise ImportError(
  104. f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device "
  105. f"has compute capability {major}.x. Use a different `experts_implementation`."
  106. )
  107. # DeepGEMM requires CUDA runtime ≥ 12.3.
  108. cuda_major, cuda_minor = get_cuda_runtime_version()
  109. if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3):
  110. raise ImportError(
  111. f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. "
  112. "Please upgrade your CUDA toolkit or use a different `experts_implementation`."
  113. )
  114. kernel = lazy_load_kernel("deep-gemm")
  115. deepgemm_fp8_matmul = getattr(kernel, "fp8_gemm_nt")
  116. deepgemm_grouped_fp8_matmul = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous")
  117. deepgemm_per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8")
  118. missing = [
  119. name
  120. for name, attr in [
  121. ("fp8_gemm_nt", deepgemm_fp8_matmul),
  122. ("m_grouped_fp8_gemm_nt_contiguous", deepgemm_grouped_fp8_matmul),
  123. ("utils.per_token_cast_to_fp8", deepgemm_per_token_cast_to_fp8),
  124. ]
  125. if attr is None
  126. ]
  127. if missing:
  128. raise ImportError(
  129. f"DeepGEMM kernel is missing required functions: {', '.join(missing)}. "
  130. "Please update the `kernels` package (`pip install -U kernels`)."
  131. )
  132. _deepgemm_available = True
  133. def w8a8_fp8_matmul(
  134. A: torch.Tensor,
  135. B: torch.Tensor,
  136. As: torch.Tensor,
  137. Bs: torch.Tensor,
  138. block_size: list[int],
  139. output_dtype: torch.dtype = torch.float32,
  140. ) -> torch.Tensor:
  141. """FP8 matmul: C = dequant(A, As) @ dequant(B, Bs)^T.
  142. Supports both per-tensor and block-wise quantization:
  143. - block_size=None or block_size=[N, K]: per-tensor mode (As is scalar/per-row, Bs is scalar)
  144. - block_size=[block_n, block_k]: block-wise mode (As and Bs are per-block scale grids)
  145. Dispatch order:
  146. 1. DeepGEMM (Hopper+, block_size 128x128) if available
  147. 2. Triton finegrained-fp8 kernel (universal fallback)
  148. Args:
  149. A: (M, K) float8_e4m3fn — quantized activations
  150. B: (N, K) float8_e4m3fn — quantized weights
  151. As: block-wise: (M, K//block_k) float32; per-tensor: (M,) per-row scales
  152. Bs: block-wise: (N//block_n, K//block_k) float32; per-tensor: scalar or (1,) single weight scale
  153. block_size: [block_n, block_k] for block-wise quantization, or None/[N, K] for per-tensor
  154. output_dtype: desired output dtype
  155. """
  156. if block_size is not None and block_size[0] == block_size[1] == 128:
  157. try:
  158. _load_deepgemm_kernel()
  159. global deepgemm_fp8_matmul
  160. except ImportError:
  161. logger.warning_once(
  162. "DeepGEMM kernel is not available or compatible, falling back to Triton finegrained-fp8 kernel. "
  163. "To use DeepGEMM FP8 matmul, ensure you have a Hopper (SM90+) or newer GPU with CUDA runtime 12.3+, "
  164. "and that the `kernels` package is installed and up to date (`pip install -U kernels`)."
  165. )
  166. else:
  167. # 3-6x faster than Triton
  168. A_2d = A.view(-1, A.shape[-1])
  169. As_2d = As.view(-1, As.shape[-1])
  170. output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype)
  171. deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output)
  172. return output.view(A.shape[:-1] + (B.shape[0],))
  173. _load_triton_kernel()
  174. global triton_fp8_matmul
  175. return triton_fp8_matmul(A, B, As, Bs, block_size, output_dtype)
  176. class FP8Linear(nn.Linear):
  177. def __init__(
  178. self,
  179. in_features: int,
  180. out_features: int,
  181. block_size: tuple[int, int] | None = None,
  182. activation_scheme: str = "dynamic",
  183. has_bias: bool = False,
  184. dtype=_FP8_DTYPE,
  185. ):
  186. super().__init__(in_features, out_features)
  187. self.has_bias = has_bias
  188. self.block_size = block_size
  189. self.activation_scheme = activation_scheme
  190. self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
  191. if self.block_size is None:
  192. # If block size is None, it means that we are doing per-tensor quantization
  193. self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
  194. else:
  195. scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0]
  196. scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1]
  197. self.weight_scale_inv = nn.Parameter(
  198. torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
  199. )
  200. if self.activation_scheme == "static":
  201. self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
  202. else:
  203. self.register_parameter("activation_scale", None)
  204. if self.has_bias:
  205. self.bias = nn.Parameter(torch.empty(self.out_features))
  206. else:
  207. self.register_parameter("bias", None)
  208. def forward(self, input: torch.Tensor) -> torch.Tensor:
  209. if self.weight.element_size() > 1:
  210. return F.linear(input, self.weight, self.bias)
  211. if isinstance(self.weight, torch.distributed.tensor.DTensor):
  212. weight = self.weight._local_tensor.contiguous()
  213. scale_inv = self.weight_scale_inv._local_tensor.contiguous()
  214. else:
  215. # why wouldn't it be contiguous?
  216. weight = self.weight.contiguous()
  217. scale_inv = self.weight_scale_inv.contiguous()
  218. if self.activation_scheme == "dynamic":
  219. _load_triton_kernel()
  220. global triton_fp8_act_quant
  221. qinput, scale = triton_fp8_act_quant(
  222. input, self.block_size[1] if self.block_size is not None else input.shape[-1]
  223. )
  224. elif self.activation_scheme == "static":
  225. scale = self.activation_scale.to(torch.float32)
  226. qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
  227. else:
  228. raise NotImplementedError(f"Unsupported activation scheme: {self.activation_scheme}")
  229. output = w8a8_fp8_matmul(
  230. qinput,
  231. weight,
  232. scale,
  233. scale_inv,
  234. self.block_size,
  235. output_dtype=input.dtype,
  236. )
  237. if self.bias is not None:
  238. output = output + self.bias
  239. return output.to(dtype=input.dtype)
  240. def fp8_batched_mm_experts_forward(
  241. self: torch.nn.Module,
  242. hidden_states: torch.Tensor,
  243. top_k_index: torch.Tensor,
  244. top_k_weights: torch.Tensor,
  245. ) -> torch.Tensor:
  246. if self.activation_scheme == "static":
  247. raise NotImplementedError(
  248. "batched_mm experts dispatch does not support activation_scheme='static'. "
  249. "Use the default eager dispatch or switch to activation_scheme='dynamic'."
  250. )
  251. _load_triton_kernel()
  252. global triton_batched_fp8_matmul
  253. device = hidden_states.device
  254. num_top_k = top_k_index.size(-1)
  255. num_tokens = hidden_states.size(0)
  256. hidden_dim = hidden_states.size(-1)
  257. # S is the number of selected tokens-experts pairs (S = num_tokens * num_top_k)
  258. token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,)
  259. sample_weights = top_k_weights.reshape(-1) # (S,)
  260. expert_ids = top_k_index.reshape(-1) # (S,)
  261. # Get current hidden states for selected samples
  262. selected_hidden_states = hidden_states[token_idx]
  263. # --- Up projection per expert (FP8 batched) ---
  264. proj_out = triton_batched_fp8_matmul(
  265. selected_hidden_states,
  266. self.gate_up_proj if self.has_gate else self.up_proj,
  267. self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv,
  268. block_size=self.block_size,
  269. expert_ids=expert_ids,
  270. ) # (S, 2 * intermediate_dim) or (S, intermediate_dim) depending on gating
  271. # Apply gating or activation
  272. if self.has_gate:
  273. # for gated experts we apply the custom/default gating mechanism
  274. proj_out = self._apply_gate(proj_out) # (S, intermediate_dim)
  275. else:
  276. # for non-gated experts we just apply the activation function
  277. proj_out = self.act_fn(proj_out) # (S, intermediate_dim)
  278. # --- Down projection per expert (FP8 batched) ---
  279. proj_out = triton_batched_fp8_matmul(
  280. proj_out,
  281. self.down_proj,
  282. self.down_proj_scale_inv,
  283. block_size=self.block_size,
  284. expert_ids=expert_ids,
  285. ) # (S, hidden_dim)
  286. # Apply routing weights
  287. weighted_out = proj_out * sample_weights.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim)
  288. # Accumulate results using deterministic reshape+sum instead of index_add_
  289. # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd)
  290. final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1)
  291. return final_hidden_states.to(hidden_states.dtype)
  292. def fp8_grouped_mm_experts_forward(
  293. self: torch.nn.Module,
  294. hidden_states: torch.Tensor,
  295. top_k_index: torch.Tensor,
  296. top_k_weights: torch.Tensor,
  297. ) -> torch.Tensor:
  298. if self.activation_scheme == "static":
  299. raise NotImplementedError(
  300. "grouped_mm experts dispatch does not support activation_scheme='static'. "
  301. "Use the default eager dispatch or switch to activation_scheme='dynamic'."
  302. )
  303. _load_triton_kernel()
  304. global triton_grouped_fp8_matmul
  305. device = hidden_states.device
  306. num_top_k = top_k_index.size(-1)
  307. num_tokens = hidden_states.size(0)
  308. hidden_dim = hidden_states.size(-1)
  309. # S is the number of selected token-expert pairs (S = num_tokens * num_top_k)
  310. token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,)
  311. sample_weights = top_k_weights.reshape(-1) # (S,)
  312. expert_ids = top_k_index.reshape(-1) # (S,)
  313. # Sort by expert for grouped processing
  314. perm = torch.argsort(expert_ids)
  315. inv_perm = torch.empty_like(perm)
  316. inv_perm[perm] = torch.arange(perm.size(0), device=device)
  317. expert_ids_g = expert_ids[perm]
  318. sample_weights_g = sample_weights[perm]
  319. selected_hidden_states_g = hidden_states[token_idx[perm]]
  320. # Compute offsets for grouped processing.
  321. # histc instead of bincount avoids cuda-graph issues;
  322. # CPU requires float input, CUDA requires int input (deterministic mode).
  323. histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int()
  324. tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1)
  325. offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32)
  326. # --- Up projection per expert (FP8 grouped) ---
  327. proj_out = triton_grouped_fp8_matmul(
  328. selected_hidden_states_g,
  329. self.gate_up_proj if self.has_gate else self.up_proj,
  330. self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv,
  331. tokens_per_expert=tokens_per_expert,
  332. block_size=self.block_size,
  333. offsets=offsets,
  334. ) # (S, 2 * intermediate_dim)
  335. # Apply gating or activation
  336. if self.has_gate:
  337. # for gated experts we apply the custom/default gating mechanism
  338. proj_out = self._apply_gate(proj_out) # (S, intermediate_dim)
  339. else:
  340. # for non-gated experts we just apply the activation function
  341. proj_out = self.act_fn(proj_out) # (S, intermediate_dim)
  342. # --- Down projection per expert (FP8 grouped) ---
  343. proj_out = triton_grouped_fp8_matmul(
  344. proj_out,
  345. self.down_proj,
  346. self.down_proj_scale_inv,
  347. tokens_per_expert=tokens_per_expert,
  348. block_size=self.block_size,
  349. offsets=offsets,
  350. ) # (S, hidden_dim)
  351. # Apply routing weights
  352. weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim)
  353. # Restore original order
  354. weighted_out = weighted_out[inv_perm]
  355. # Accumulate results using deterministic reshape+sum instead of index_add_
  356. # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd)
  357. final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1)
  358. return final_hidden_states.to(hidden_states.dtype)
  359. def _build_deepgemm_contiguous_layout(expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int) -> tuple:
  360. """Build a TMA-aligned contiguous layout for DeepGEMM grouped GEMM.
  361. DeepGEMM requires M-dimension alignment per expert for TMA. This computes
  362. the mapping from sorted token positions to padded row positions, and the
  363. layout tensor that DeepGEMM uses to identify expert boundaries.
  364. Returns:
  365. sorted_to_padded: (num_tokens,) index map from sorted position to padded row
  366. grouped_layout: expert layout tensor (format depends on GPU architecture)
  367. total_padded_rows: total number of rows including alignment padding
  368. """
  369. device = expert_ids_sorted.device
  370. num_tokens = expert_ids_sorted.size(0)
  371. tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long()
  372. aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment
  373. # Upper bound avoids GPU→CPU sync; padding rows are skipped by DeepGEMM.
  374. total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1)
  375. padding_per_expert = aligned_tokens_per_expert - tokens_per_expert
  376. cumulative_padding = padding_per_expert.cumsum(0) - padding_per_expert
  377. sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted]
  378. if torch.cuda.get_device_capability(device)[0] >= 10: # Blackwell (SM100+)
  379. grouped_layout = tokens_per_expert.cumsum(0).int()
  380. else:
  381. # Hopper: per-row expert id, -1 for padding rows
  382. grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32)
  383. grouped_layout[sorted_to_padded] = expert_ids_sorted.int()
  384. return sorted_to_padded, grouped_layout, total_padded_rows
  385. def _pad_to_deepgemm_contiguous_layout(
  386. hidden_states: torch.Tensor,
  387. scales: torch.Tensor,
  388. sorted_to_padded: torch.Tensor,
  389. total_padded_rows: int,
  390. ) -> tuple[torch.Tensor, torch.Tensor]:
  391. """Pad sorted hidden states and scales into the TMA-aligned contiguous layout."""
  392. hidden_padded = torch.zeros(
  393. total_padded_rows, hidden_states.shape[1], device=hidden_states.device, dtype=hidden_states.dtype
  394. )
  395. hidden_padded[sorted_to_padded] = hidden_states
  396. scales_padded = torch.zeros(total_padded_rows, scales.shape[1], device=hidden_states.device, dtype=torch.float32)
  397. scales_padded[sorted_to_padded] = scales
  398. return hidden_padded, scales_padded
  399. def _unpad_from_deepgemm_contiguous_layout(
  400. hidden_states_padded: torch.Tensor, sorted_to_padded: torch.Tensor
  401. ) -> torch.Tensor:
  402. """Remove padding rows from the TMA-aligned contiguous layout."""
  403. return hidden_states_padded[sorted_to_padded]
  404. def fp8_deepgemm_experts_forward(
  405. self: torch.nn.Module,
  406. hidden_states: torch.Tensor,
  407. top_k_index: torch.Tensor,
  408. top_k_weights: torch.Tensor,
  409. ) -> torch.Tensor:
  410. if self.activation_scheme == "static":
  411. raise NotImplementedError(
  412. "deepgemm experts dispatch does not support activation_scheme='static'. "
  413. "Use the default eager dispatch or switch to activation_scheme='dynamic'."
  414. )
  415. if self.block_size is None:
  416. raise ValueError(
  417. "DeepGEMM requires block-wise quantization (block_size=[128, 128]), "
  418. "but got per-tensor quantization (block_size=None)."
  419. )
  420. if self.block_size[0] != 128 or self.block_size[1] != 128:
  421. raise ValueError(f"DeepGEMM requires block_size=(128, 128), got {self.block_size}")
  422. _load_deepgemm_kernel()
  423. global deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8
  424. device = hidden_states.device
  425. num_top_k = top_k_index.size(-1)
  426. num_tokens = hidden_states.size(0)
  427. hidden_dim = hidden_states.size(-1)
  428. # S is the number of selected token-expert pairs (S = num_tokens * num_top_k)
  429. token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,)
  430. sample_weights = top_k_weights.reshape(-1) # (S,)
  431. expert_ids = top_k_index.reshape(-1) # (S,)
  432. # Sort by expert for grouped processing
  433. perm = torch.argsort(expert_ids)
  434. inv_perm = torch.empty_like(perm)
  435. inv_perm[perm] = torch.arange(perm.size(0), device=device)
  436. expert_ids_g = expert_ids[perm]
  437. sample_weights_g = sample_weights[perm]
  438. selected_hidden_states_g = hidden_states[token_idx[perm]]
  439. # Build TMA-aligned contiguous layout for DeepGEMM
  440. sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout(
  441. expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT
  442. )
  443. # --- Up projection per expert (DeepGEMM grouped contiguous) ---
  444. w_up = self.gate_up_proj if self.has_gate else self.up_proj
  445. ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv
  446. act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False)
  447. act_fp8, act_scales = _pad_to_deepgemm_contiguous_layout(act_fp8, act_scales, sorted_to_padded, total_padded_rows)
  448. proj_out = torch.zeros(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16)
  449. use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10
  450. deepgemm_grouped_fp8_matmul(
  451. (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout
  452. )
  453. # Apply gating or activation
  454. if self.has_gate:
  455. proj_out = self._apply_gate(proj_out)
  456. else:
  457. proj_out = self.act_fn(proj_out)
  458. # --- Down projection per expert (DeepGEMM grouped contiguous) ---
  459. w_down = self.down_proj
  460. ws_down = self.down_proj_scale_inv
  461. proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False)
  462. proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16)
  463. deepgemm_grouped_fp8_matmul(
  464. (proj_fp8, proj_scales), (w_down, ws_down.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout
  465. )
  466. # Remove padding rows
  467. proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded)
  468. # Apply routing weights
  469. weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim)
  470. # Restore original order
  471. weighted_out = weighted_out[inv_perm]
  472. # Accumulate results using deterministic reshape+sum instead of index_add_
  473. # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd)
  474. final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1)
  475. return final_hidden_states.to(hidden_states.dtype)
  476. class FP8Experts(nn.Module):
  477. def __init__(
  478. self,
  479. config,
  480. block_size: tuple[int, int] | None = None,
  481. activation_scheme: str = "dynamic",
  482. has_bias: bool = False,
  483. has_gate: bool = True,
  484. dtype=_FP8_DTYPE,
  485. ):
  486. super().__init__()
  487. assert has_bias is False, (
  488. "FP8Experts does not support bias for now, please open an issue if you want this feature"
  489. )
  490. self.config = config
  491. self.has_bias = has_bias
  492. self.has_gate = has_gate
  493. self.block_size = block_size
  494. self.hidden_dim = config.hidden_size
  495. self.activation_scheme = activation_scheme
  496. self.num_experts = getattr(config, "num_local_experts", config.num_experts)
  497. self.intermediate_dim = getattr(config, "moe_intermediate_size", config.intermediate_size)
  498. self.act_fn = ACT2FN[getattr(config, "hidden_activation", config.hidden_act)]
  499. if self.has_gate:
  500. gu_proj_out, gu_proj_in = 2 * self.intermediate_dim, self.hidden_dim
  501. self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, gu_proj_out, gu_proj_in, dtype=dtype))
  502. gu_scale_out = triton.cdiv(gu_proj_out, self.block_size[0]) if self.block_size is not None else 1
  503. gu_scale_in = triton.cdiv(gu_proj_in, self.block_size[1]) if self.block_size is not None else 1
  504. self.gate_up_proj_scale_inv = nn.Parameter(
  505. torch.empty(self.num_experts, gu_scale_out, gu_scale_in, dtype=torch.float32)
  506. )
  507. self.register_parameter("gate_up_proj_bias", None)
  508. else:
  509. u_proj_out, u_proj_in = self.intermediate_dim, self.hidden_dim
  510. self.up_proj = nn.Parameter(torch.empty(self.num_experts, u_proj_out, u_proj_in, dtype=dtype))
  511. u_scale_out = triton.cdiv(u_proj_out, self.block_size[0]) if self.block_size is not None else 1
  512. u_scale_in = triton.cdiv(u_proj_in, self.block_size[1]) if self.block_size is not None else 1
  513. self.up_proj_scale_inv = nn.Parameter(
  514. torch.empty(self.num_experts, u_scale_out, u_scale_in, dtype=torch.float32)
  515. )
  516. self.register_parameter("up_proj_bias", None)
  517. d_proj_out, d_proj_in = self.hidden_dim, self.intermediate_dim
  518. self.down_proj = nn.Parameter(torch.empty(self.num_experts, d_proj_out, d_proj_in, dtype=dtype))
  519. d_scale_out = triton.cdiv(d_proj_out, self.block_size[0]) if self.block_size is not None else 1
  520. d_scale_in = triton.cdiv(d_proj_in, self.block_size[1]) if self.block_size is not None else 1
  521. self.down_proj_scale_inv = nn.Parameter(
  522. torch.empty(self.num_experts, d_scale_out, d_scale_in, dtype=torch.float32)
  523. )
  524. self.register_parameter("down_proj_bias", None)
  525. if self.activation_scheme == "static":
  526. self.gate_up_proj_activation_scale = nn.Parameter(torch.ones(self.num_experts, dtype=torch.float32))
  527. self.down_proj_activation_scale = nn.Parameter(torch.ones(self.num_experts, dtype=torch.float32))
  528. def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
  529. gate, up = gate_up.chunk(2, dim=-1)
  530. return self.act_fn(gate) * up
  531. def forward(
  532. self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
  533. ) -> torch.Tensor:
  534. # index_add_ will accumulate using the dtype of the tensor we write into
  535. # so we use float32 for the accumulation to avoid numerical issues in bf16/fp16
  536. final_hidden_states = torch.zeros_like(hidden_states, dtype=torch.float32)
  537. with torch.no_grad():
  538. expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
  539. expert_mask = expert_mask.permute(2, 1, 0)
  540. expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).view(-1)
  541. for expert_idx in expert_hit:
  542. if expert_idx == self.num_experts:
  543. continue
  544. top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
  545. current_state = hidden_states[token_idx]
  546. gate_up_act_scale = (
  547. self.gate_up_proj_activation_scale[expert_idx] if self.activation_scheme == "static" else None
  548. )
  549. proj_out = self.linear(
  550. current_state,
  551. self.gate_up_proj[expert_idx] if self.has_gate else self.up_proj[expert_idx],
  552. self.gate_up_proj_scale_inv[expert_idx] if self.has_gate else self.up_proj_scale_inv[expert_idx],
  553. activation_scale=gate_up_act_scale,
  554. )
  555. proj_out = self._apply_gate(proj_out) if self.has_gate else self.act_fn(proj_out)
  556. down_act_scale = (
  557. self.down_proj_activation_scale[expert_idx] if self.activation_scheme == "static" else None
  558. )
  559. proj_out = self.linear(
  560. proj_out,
  561. self.down_proj[expert_idx],
  562. self.down_proj_scale_inv[expert_idx],
  563. activation_scale=down_act_scale,
  564. )
  565. routing_weights = top_k_weights[token_idx, top_k_pos, None]
  566. weighted_out = proj_out * routing_weights.to(proj_out.dtype)
  567. final_hidden_states.index_add_(0, token_idx, weighted_out.to(final_hidden_states.dtype))
  568. return final_hidden_states.to(hidden_states.dtype)
  569. def linear(
  570. self,
  571. input: torch.Tensor,
  572. weight: torch.Tensor,
  573. weight_scale_inv: torch.Tensor,
  574. activation_scale: torch.Tensor | None = None,
  575. ) -> torch.Tensor:
  576. if weight.element_size() > 1:
  577. return F.linear(input, weight, None)
  578. if self.activation_scheme == "static" and activation_scale is not None:
  579. scale = activation_scale.to(torch.float32)
  580. qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
  581. else:
  582. _load_triton_kernel()
  583. global triton_fp8_act_quant
  584. qinput, scale = triton_fp8_act_quant(
  585. input, self.block_size[1] if self.block_size is not None else input.shape[-1]
  586. )
  587. output = w8a8_fp8_matmul(
  588. qinput,
  589. weight,
  590. scale,
  591. weight_scale_inv,
  592. self.block_size,
  593. output_dtype=input.dtype,
  594. )
  595. return output.to(dtype=input.dtype)
class FP8ExpertsInterface(ExpertsInterface):
    """Interface for registering custom FP8 experts forward functions."""

    # Built-in FP8 forward implementations, keyed by implementation name.
    # NOTE(review): how `ExpertsInterface` consumes `_global_mapping` is defined outside this chunk.
    _global_mapping = {
        "batched_mm": fp8_batched_mm_experts_forward,
        "grouped_mm": fp8_grouped_mm_experts_forward,
        "deepgemm": fp8_deepgemm_experts_forward,
    }
# Module-level interface instance; `replace_with_fp8_linear` passes it to
# `use_experts_implementation` as `experts_interface`.
ALL_FP8_EXPERTS_FUNCTIONS = FP8ExpertsInterface()
  604. def replace_with_fp8_linear(
  605. model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False
  606. ):
  607. """
  608. A helper function to replace all `torch.nn.Linear` modules by `FP8Linear` modules.
  609. Parameters:
  610. model (`torch.nn.Module`):
  611. Input model or `torch.nn.Module` as the function is run recursively.
  612. modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`):
  613. Names of the modules to not convert. In practice we keep the `lm_head` in full precision for numerical stability reasons.
  614. quantization_config (`FbgemmFp8Config`):
  615. The quantization config object that contains the quantization parameters.
  616. pre_quantized (`book`, defaults to `False`):
  617. Whether the model is pre-quantized or not
  618. """
  619. if quantization_config.dequantize:
  620. return model
  621. has_been_replaced = False
  622. for module_name, module in model.named_modules():
  623. if not should_convert_module(module_name, modules_to_not_convert):
  624. continue
  625. # we need this to correctly materialize the weights during quantization
  626. module_kwargs = {} if pre_quantized else {"dtype": None}
  627. new_module = None
  628. with torch.device("meta"):
  629. if module_name.endswith(".experts"):
  630. has_gate = getattr(module, "has_gate", True)
  631. has_bias = getattr(module, "has_bias", False)
  632. config = getattr(module, "config", model.config.get_text_config())
  633. new_class = use_experts_implementation(
  634. experts_class=FP8Experts,
  635. experts_interface=ALL_FP8_EXPERTS_FUNCTIONS,
  636. has_bias=has_bias,
  637. has_gate=has_gate,
  638. )
  639. new_module = new_class(
  640. config=config,
  641. block_size=quantization_config.weight_block_size,
  642. activation_scheme=quantization_config.activation_scheme,
  643. has_bias=has_bias,
  644. has_gate=has_gate,
  645. **module_kwargs,
  646. )
  647. elif isinstance(module, nn.Linear):
  648. new_module = FP8Linear(
  649. in_features=module.in_features,
  650. out_features=module.out_features,
  651. block_size=quantization_config.weight_block_size,
  652. activation_scheme=quantization_config.activation_scheme,
  653. has_bias=module.bias is not None,
  654. **module_kwargs,
  655. )
  656. if new_module is not None:
  657. model.set_submodule(module_name, new_module)
  658. has_been_replaced = True
  659. if not has_been_replaced:
  660. logger.warning(
  661. "You are loading your model using fp8 but no linear modules were found in your model."
  662. " Please double check your model architecture."
  663. )
  664. return model
  665. class Fp8Quantize(ConversionOps):
  666. """
  667. A quantization operation that creates two tensors, weight and scale out of a weight.
  668. """
  669. def __init__(self, hf_quantizer):
  670. self.hf_quantizer = hf_quantizer
  671. def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]:
  672. # Unpack single key/value (value may be wrapped in a list)
  673. target_keys, value = tuple(input_dict.items())[0]
  674. value = value[0]
  675. # Resolve block size (support dict-like or attr-like quant_config)
  676. block_size = None
  677. if self.hf_quantizer.quantization_config is not None:
  678. if isinstance(self.hf_quantizer.quantization_config, dict):
  679. block_size = self.hf_quantizer.quantization_config.get("weight_block_size")
  680. else:
  681. block_size = getattr(self.hf_quantizer.quantization_config, "weight_block_size", None)
  682. if block_size is None:
  683. block_size = (value.shape[-2], value.shape[-1])
  684. block_m, block_n = block_size
  685. rows, cols = value.shape[-2], value.shape[-1]
  686. # Enforce exact tiling like your original
  687. if rows % block_m != 0 or cols % block_n != 0:
  688. raise ValueError(
  689. f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}). for {target_keys}"
  690. )
  691. # Leading dims can be empty (2D) or include num_experts/... (3D+)
  692. leading_shape = value.shape[:-2]
  693. rows_tiles = rows // block_m
  694. cols_tiles = cols // block_n
  695. original_shape = value.shape
  696. value_fp32 = value.to(torch.float32)
  697. # Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
  698. reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)
  699. # Per-tile max-abs over the block dims
  700. # dims: block_m is at -3, block_n is at -1 after the reshape
  701. max_abs = reshaped.abs().amax(dim=(-3, -1))
  702. safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))
  703. # Tile scale (we store inverse scale like your Linear: weight_scale_inv)
  704. scales = _FP8_MAX / safe_max_abs
  705. scales = torch.where(max_abs > 0, scales, torch.ones_like(scales)) # keep zeros stable
  706. # Broadcast scales back over the block dims and quantize
  707. # max_abs/scales shape: (..., rows_tiles, cols_tiles)
  708. scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # -> (..., rows_tiles, 1, cols_tiles, 1)
  709. scaled = reshaped * scales_broadcast
  710. quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
  711. quantized = quantized.reshape(original_shape)
  712. inv_scales = (1.0 / scales).to(torch.float32) # shape: (*leading, rows_tiles, cols_tiles)
  713. if target_keys.endswith("weight"):
  714. scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv"
  715. else:
  716. scale_key = target_keys + "_scale_inv"
  717. # Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
  718. return {
  719. target_keys: quantized,
  720. scale_key: inv_scales,
  721. }
  722. class Fp8Dequantize(ConversionOps):
  723. """Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""
  724. def __init__(self, hf_quantizer):
  725. self.hf_quantizer = hf_quantizer
  726. def convert(
  727. self,
  728. input_dict: dict[str, torch.Tensor],
  729. full_layer_name: str | None = None,
  730. **kwargs,
  731. ) -> dict[str, torch.Tensor]:
  732. if len(input_dict) < 2:
  733. # case where we only got weights, need to check for "weight$"
  734. return {full_layer_name: input_dict["weight$"]}
  735. quantized = input_dict["weight$"][0]
  736. scales = input_dict["weight_scale_inv"][0]
  737. rows, cols = quantized.shape[-2:]
  738. block_size = self.hf_quantizer.quantization_config.weight_block_size
  739. if block_size is None:
  740. block_size = (quantized.shape[-2], quantized.shape[-1])
  741. block_m, block_n = block_size
  742. if rows % block_m != 0 or cols % block_n != 0:
  743. raise ValueError(
  744. f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
  745. )
  746. quantized = quantized.to(scales.dtype)
  747. reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
  748. expanded_scales = scales.reshape(-1, rows // block_m, cols // block_n)
  749. expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
  750. dequantized = reshaped * expanded_scales
  751. return {
  752. full_layer_name: dequantized.reshape(quantized.shape),
  753. }
  754. @property
  755. def reverse_op(self) -> "ConversionOps":
  756. return _IdentityOp()