| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909 |
- # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import torch
- import torch.nn as nn
- import triton
- from torch.nn import functional as F
- from ..activations import ACT2FN
- from ..core_model_loading import ConversionOps, _IdentityOp
- from ..quantizers.quantizers_utils import should_convert_module
- from ..utils import logging
- from ..utils.import_utils import get_cuda_runtime_version, resolve_internal_import
- from .hub_kernels import lazy_load_kernel
- from .moe import ExpertsInterface, use_experts_implementation
logger = logging.get_logger(__name__)

# FP8 storage dtype used throughout this module, together with its finite range.
# The min/max bounds are used to clamp values before casting to FP8 in the static activation path.
_FP8_DTYPE = torch.float8_e4m3fn
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
_FP8_MAX = torch.finfo(_FP8_DTYPE).max

# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM.
# TMA is an H100 hardware addition that allows applications to asynchronously and
# bi-directionally transfer 1D-5D tensors between GPU global and shared memory.
_DEEPGEMM_M_ALIGNMENT = 128

# Lazily-loaded finegrained-fp8 Triton kernel functions (populated by _load_triton_kernel).
# They stay None until the first successful load.
triton_fp8_matmul = None
triton_fp8_act_quant = None
triton_batched_fp8_matmul = None
triton_grouped_fp8_matmul = None
# _triton_available: None = not yet attempted, True = loaded, False = failed (won't retry)
_triton_available = None

# Lazily-loaded DeepGEMM kernel functions (populated by _load_deepgemm_kernel).
# They stay None until the first successful load.
deepgemm_fp8_matmul = None
deepgemm_grouped_fp8_matmul = None
deepgemm_per_token_cast_to_fp8 = None
# _deepgemm_available: None = not yet attempted, True = loaded, False = failed (won't retry)
_deepgemm_available = None
def _load_triton_kernel():
    """Lazily load the finegrained-fp8 Triton kernel and bind its entry points to module globals.

    Uses the hub kernels lazy loading pattern. Loading is attempted only once per process:
    the outcome is cached in `_triton_available`, and subsequent calls either return
    immediately (success) or raise again (failure).

    Raises:
        ImportError: if a previous load attempt failed, or if the loaded kernel does not
            expose every required function.
    """
    global \
        _triton_available, \
        triton_fp8_act_quant, \
        triton_fp8_matmul, \
        triton_batched_fp8_matmul, \
        triton_grouped_fp8_matmul

    if _triton_available is not None:
        if not _triton_available:
            raise ImportError("finegrained-fp8 kernel is not available (previous load attempt failed).")
        return

    _triton_available = False  # mark attempted before any early exit

    kernel = lazy_load_kernel("finegrained-fp8")

    # Look every symbol up with a `None` default: a bare `getattr(kernel, name)` raises
    # AttributeError on the first absent symbol, which would bypass the aggregated
    # ImportError below (and any caller that only handles ImportError).
    triton_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul", None)
    triton_fp8_act_quant = getattr(kernel, "fp8_act_quant", None)
    triton_batched_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_batched", None)
    triton_grouped_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_grouped", None)

    missing = [
        name
        for name, attr in [
            ("w8a8_fp8_matmul", triton_fp8_matmul),
            ("fp8_act_quant", triton_fp8_act_quant),
            ("w8a8_fp8_matmul_batched", triton_batched_fp8_matmul),
            ("w8a8_fp8_matmul_grouped", triton_grouped_fp8_matmul),
        ]
        if attr is None
    ]
    if missing:
        raise ImportError(
            f"finegrained-fp8 kernel is missing required functions: {', '.join(missing)}. "
            "Please update the `kernels` package (`pip install -U kernels`)."
        )

    _triton_available = True
def _load_deepgemm_kernel():
    """Lazily load the DeepGEMM kernel and bind its entry points to module globals.

    Uses the hub kernels lazy loading pattern. Loading is attempted only once per process:
    the outcome is cached in `_deepgemm_available`, and subsequent calls either return
    immediately (success) or raise again (failure).

    Raises:
        ImportError: if a previous load attempt failed, CUDA is unavailable, the GPU has
            compute capability below 9, the CUDA runtime is older than 12.3, or the loaded
            kernel does not expose every required function.
    """
    global _deepgemm_available, deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8

    if _deepgemm_available is not None:
        if not _deepgemm_available:
            raise ImportError("DeepGEMM kernel is not available (previous load attempt failed).")
        return

    _deepgemm_available = False  # mark attempted before any early exit

    # DeepGEMM requires CUDA and a compatible GPU
    if not torch.cuda.is_available():
        raise ImportError(
            "DeepGEMM kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`."
        )

    # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions
    major = torch.cuda.get_device_capability()[0]
    if major < 9:
        raise ImportError(
            f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device "
            f"has compute capability {major}.x. Use a different `experts_implementation`."
        )

    # DeepGEMM requires CUDA runtime ≥ 12.3.
    cuda_major, cuda_minor = get_cuda_runtime_version()
    if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3):
        raise ImportError(
            f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. "
            "Please upgrade your CUDA toolkit or use a different `experts_implementation`."
        )

    kernel = lazy_load_kernel("deep-gemm")

    # Look symbols up with a `None` default: a bare `getattr(kernel, name)` raises
    # AttributeError on an absent symbol, which would bypass the aggregated ImportError
    # below and escape callers that fall back to Triton on ImportError only.
    deepgemm_fp8_matmul = getattr(kernel, "fp8_gemm_nt", None)
    deepgemm_grouped_fp8_matmul = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None)
    deepgemm_per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8")

    missing = [
        name
        for name, attr in [
            ("fp8_gemm_nt", deepgemm_fp8_matmul),
            ("m_grouped_fp8_gemm_nt_contiguous", deepgemm_grouped_fp8_matmul),
            ("utils.per_token_cast_to_fp8", deepgemm_per_token_cast_to_fp8),
        ]
        if attr is None
    ]
    if missing:
        raise ImportError(
            f"DeepGEMM kernel is missing required functions: {', '.join(missing)}. "
            "Please update the `kernels` package (`pip install -U kernels`)."
        )

    _deepgemm_available = True
def w8a8_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Compute the FP8 product C = dequant(A, As) @ dequant(B, Bs)^T.

    Two quantization layouts are handled:
        - per-tensor: `block_size` is None or [N, K]; `As` is scalar/per-row and `Bs` is scalar.
        - block-wise: `block_size` is [block_n, block_k]; `As` and `Bs` are per-block scale grids.

    Backend selection:
        1. DeepGEMM, tried only when `block_size` is exactly 128x128 and the kernel loads.
        2. The Triton finegrained-fp8 kernel otherwise (also the fallback when DeepGEMM
           cannot be loaded, in which case a warning is logged once).

    Args:
        A: (M, K) float8_e4m3fn quantized activations.
        B: (N, K) float8_e4m3fn quantized weights.
        As: block-wise: (M, K//block_k) float32; per-tensor: (M,) per-row scales.
        Bs: block-wise: (N//block_n, K//block_k) float32; per-tensor: scalar or (1,) weight scale.
        block_size: [block_n, block_k] for block-wise quantization, or None/[N, K] for per-tensor.
        output_dtype: dtype of the returned tensor.

    Returns:
        The matmul result with the leading dimensions of `A` and last dimension `B.shape[0]`
        on the DeepGEMM path, or whatever the Triton kernel returns otherwise.
    """
    use_deepgemm = block_size is not None and block_size[0] == block_size[1] == 128

    if use_deepgemm:
        try:
            _load_deepgemm_kernel()
        except ImportError:
            use_deepgemm = False
            logger.warning_once(
                "DeepGEMM kernel is not available or compatible, falling back to Triton finegrained-fp8 kernel. "
                "To use DeepGEMM FP8 matmul, ensure you have a Hopper (SM90+) or newer GPU with CUDA runtime 12.3+, "
                "and that the `kernels` package is installed and up to date (`pip install -U kernels`)."
            )

    if use_deepgemm:
        # 3-6x faster than Triton
        # Flatten any leading dimensions so the kernel sees 2D operands.
        flat_inputs = A.view(-1, A.shape[-1])
        flat_scales = As.view(-1, As.shape[-1])
        out_features = B.shape[0]
        result = torch.empty(flat_inputs.shape[0], out_features, device=A.device, dtype=output_dtype)
        deepgemm_fp8_matmul((flat_inputs, flat_scales.float()), (B, Bs.float()), result)
        # Restore the caller's leading dimensions.
        return result.view(A.shape[:-1] + (out_features,))

    _load_triton_kernel()
    return triton_fp8_matmul(A, B, As, Bs, block_size, output_dtype)
class FP8Linear(nn.Linear):
    """Linear layer holding an FP8 weight plus the inverse scale(s) needed to dequantize it.

    With `block_size=None` a single scalar scale is stored (per-tensor quantization);
    otherwise one scale is stored per `(block_size[0], block_size[1])` tile of the weight.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        block_size: tuple[int, int] | None = None,
        activation_scheme: str = "dynamic",
        has_bias: bool = False,
        dtype=_FP8_DTYPE,
    ):
        super().__init__(in_features, out_features)

        self.has_bias = has_bias
        self.block_size = block_size
        self.activation_scheme = activation_scheme

        # Replace the weight created by nn.Linear with one of the requested dtype.
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))

        if self.block_size is not None:
            # One scale per weight tile; round up so partial tiles get a scale as well.
            block_out, block_in = self.block_size[0], self.block_size[1]
            scale_rows = (out_features + block_out - 1) // block_out
            scale_cols = (in_features + block_in - 1) // block_in
            scale_storage = torch.empty(scale_rows, scale_cols, dtype=torch.float32)
        else:
            # If block size is None, it means that we are doing per-tensor quantization
            scale_storage = torch.tensor(1.0, dtype=torch.float32)
        self.weight_scale_inv = nn.Parameter(scale_storage)

        # A learned/loaded activation scale only exists for the static scheme.
        static_scale = None
        if self.activation_scheme == "static":
            static_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
        self.register_parameter("activation_scale", static_scale)

        bias_param = nn.Parameter(torch.empty(self.out_features)) if self.has_bias else None
        self.register_parameter("bias", bias_param)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the layer, quantizing `input` to FP8 unless the weight is not a 1-byte dtype.

        Raises:
            NotImplementedError: if `activation_scheme` is neither "dynamic" nor "static".
        """
        # A weight wider than one byte per element is not FP8: use the regular linear path.
        if self.weight.element_size() > 1:
            return F.linear(input, self.weight, self.bias)

        # For DTensor parameters operate on the local shard.
        sharded = isinstance(self.weight, torch.distributed.tensor.DTensor)
        weight_source = self.weight._local_tensor if sharded else self.weight
        scale_source = self.weight_scale_inv._local_tensor if sharded else self.weight_scale_inv
        weight = weight_source.contiguous()
        scale_inv = scale_source.contiguous()

        scheme = self.activation_scheme
        if scheme == "dynamic":
            _load_triton_kernel()
            # Quantization group size: the K block size, or the full last dim for per-tensor mode.
            group_size = input.shape[-1] if self.block_size is None else self.block_size[1]
            qinput, scale = triton_fp8_act_quant(input, group_size)
        elif scheme == "static":
            scale = self.activation_scale.to(torch.float32)
            # Clamp to the FP8 range before casting.
            scaled_input = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX)
            qinput = scaled_input.to(_FP8_DTYPE)
        else:
            raise NotImplementedError(f"Unsupported activation scheme: {self.activation_scheme}")

        output = w8a8_fp8_matmul(qinput, weight, scale, scale_inv, self.block_size, output_dtype=input.dtype)

        if self.bias is not None:
            output = output + self.bias

        return output.to(dtype=input.dtype)
def fp8_batched_mm_experts_forward(
    self: torch.nn.Module,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """Experts forward pass using the Triton batched FP8 matmul kernel.

    Every (token, selected expert) pair is processed as its own row; with
    S = num_tokens * num_top_k the intermediate tensors have S rows.

    Args:
        self: experts module providing the FP8 weights, scales, `block_size`, `has_gate`,
            `act_fn` and `_apply_gate`.
        hidden_states: token activations, indexed along dim 0 by token.
        top_k_index: selected expert ids per token (last dim is num_top_k).
        top_k_weights: routing weights matching `top_k_index`.

    Returns:
        Tensor of shape (num_tokens, hidden_dim) in the dtype of `hidden_states`.

    Raises:
        NotImplementedError: if `self.activation_scheme` is "static".
        ImportError: if the Triton kernel cannot be loaded.
    """
    if self.activation_scheme == "static":
        raise NotImplementedError(
            "batched_mm experts dispatch does not support activation_scheme='static'. "
            "Use the default eager dispatch or switch to activation_scheme='dynamic'."
        )

    _load_triton_kernel()

    num_tokens, hidden_dim = hidden_states.size(0), hidden_states.size(-1)
    num_top_k = top_k_index.size(-1)

    # Flatten routing: row i of the flattened view belongs to token i // num_top_k.
    row_ids = torch.arange(num_tokens, device=hidden_states.device)
    token_idx = row_ids.unsqueeze(1).expand(-1, num_top_k).reshape(-1)  # (S,)
    flat_weights = top_k_weights.reshape(-1)  # (S,)
    flat_expert_ids = top_k_index.reshape(-1)  # (S,)

    if self.has_gate:
        up_weight, up_scale_inv = self.gate_up_proj, self.gate_up_proj_scale_inv
    else:
        up_weight, up_scale_inv = self.up_proj, self.up_proj_scale_inv

    # --- Up projection per expert (FP8 batched) ---
    # (S, 2 * intermediate_dim) or (S, intermediate_dim) depending on gating
    up_out = triton_batched_fp8_matmul(
        hidden_states[token_idx],
        up_weight,
        up_scale_inv,
        block_size=self.block_size,
        expert_ids=flat_expert_ids,
    )

    # Gated experts use the gating mechanism, non-gated ones only the activation.
    activated = self._apply_gate(up_out) if self.has_gate else self.act_fn(up_out)  # (S, intermediate_dim)

    # --- Down projection per expert (FP8 batched) ---
    down_out = triton_batched_fp8_matmul(
        activated,
        self.down_proj,
        self.down_proj_scale_inv,
        block_size=self.block_size,
        expert_ids=flat_expert_ids,
    )  # (S, hidden_dim)

    # Apply routing weights
    scaled = down_out * flat_weights.to(down_out.dtype).unsqueeze(-1)  # (S, hidden_dim)

    # Deterministic reshape+sum instead of index_add_
    # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd)
    combined = scaled.view(num_tokens, num_top_k, hidden_dim).sum(dim=1)
    return combined.to(hidden_states.dtype)
def fp8_grouped_mm_experts_forward(
    self: torch.nn.Module,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """Experts forward pass using the Triton grouped FP8 matmul kernel.

    The (token, selected expert) pairs are sorted by expert id so each expert sees a
    contiguous run of rows, processed, then put back in the original order.
    With S = num_tokens * num_top_k the intermediate tensors have S rows.

    Args:
        self: experts module providing the FP8 weights, scales, `block_size`, `num_experts`,
            `has_gate`, `act_fn` and `_apply_gate`.
        hidden_states: token activations, indexed along dim 0 by token.
        top_k_index: selected expert ids per token (last dim is num_top_k).
        top_k_weights: routing weights matching `top_k_index`.

    Returns:
        Tensor of shape (num_tokens, hidden_dim) in the dtype of `hidden_states`.

    Raises:
        NotImplementedError: if `self.activation_scheme` is "static".
        ImportError: if the Triton kernel cannot be loaded.
    """
    if self.activation_scheme == "static":
        raise NotImplementedError(
            "grouped_mm experts dispatch does not support activation_scheme='static'. "
            "Use the default eager dispatch or switch to activation_scheme='dynamic'."
        )

    _load_triton_kernel()

    device = hidden_states.device
    num_tokens, hidden_dim = hidden_states.size(0), hidden_states.size(-1)
    num_top_k = top_k_index.size(-1)

    # Flatten routing: row i of the flattened view belongs to token i // num_top_k.
    row_ids = torch.arange(num_tokens, device=device)
    token_idx = row_ids.unsqueeze(1).expand(-1, num_top_k).reshape(-1)  # (S,)
    flat_weights = top_k_weights.reshape(-1)  # (S,)
    flat_expert_ids = top_k_index.reshape(-1)  # (S,)

    # Sort by expert for grouped processing, and keep the inverse permutation to undo it.
    order = torch.argsort(flat_expert_ids)
    restore = torch.empty_like(order)
    restore[order] = torch.arange(order.size(0), device=device)

    sorted_expert_ids = flat_expert_ids[order]
    sorted_weights = flat_weights[order]
    sorted_inputs = hidden_states[token_idx[order]]

    # Compute offsets for grouped processing.
    # histc instead of bincount avoids cuda-graph issues;
    # CPU requires float input, CUDA requires int input (deterministic mode).
    if device.type == "cpu":
        histc_input = sorted_expert_ids.float()
    else:
        histc_input = sorted_expert_ids.int()
    tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1)
    offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32)

    if self.has_gate:
        up_weight, up_scale_inv = self.gate_up_proj, self.gate_up_proj_scale_inv
    else:
        up_weight, up_scale_inv = self.up_proj, self.up_proj_scale_inv

    # --- Up projection per expert (FP8 grouped) ---
    up_out = triton_grouped_fp8_matmul(
        sorted_inputs,
        up_weight,
        up_scale_inv,
        tokens_per_expert=tokens_per_expert,
        block_size=self.block_size,
        offsets=offsets,
    )  # (S, 2 * intermediate_dim)

    # Gated experts use the gating mechanism, non-gated ones only the activation.
    activated = self._apply_gate(up_out) if self.has_gate else self.act_fn(up_out)  # (S, intermediate_dim)

    # --- Down projection per expert (FP8 grouped) ---
    down_out = triton_grouped_fp8_matmul(
        activated,
        self.down_proj,
        self.down_proj_scale_inv,
        tokens_per_expert=tokens_per_expert,
        block_size=self.block_size,
        offsets=offsets,
    )  # (S, hidden_dim)

    # Apply routing weights, then restore the original (unsorted) row order.
    scaled = down_out * sorted_weights.to(down_out.dtype).unsqueeze(-1)  # (S, hidden_dim)
    scaled = scaled[restore]

    # Deterministic reshape+sum instead of index_add_
    # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd)
    combined = scaled.view(num_tokens, num_top_k, hidden_dim).sum(dim=1)
    return combined.to(hidden_states.dtype)
def _build_deepgemm_contiguous_layout(
    expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Build a TMA-aligned contiguous layout for DeepGEMM grouped GEMM.

    DeepGEMM requires M-dimension alignment per expert for TMA. This computes
    the mapping from sorted token positions to padded row positions, and the
    layout tensor that DeepGEMM uses to identify expert boundaries.

    Args:
        expert_ids_sorted: 1D tensor of expert ids, already sorted by expert
            (callers in this file pass the argsort-ed ids). Must live on a CUDA device,
            since the device capability is queried below.
        num_experts: number of histogram bins, i.e. the number of experts.
        alignment: row multiple each expert's group is rounded up to.

    Returns:
        sorted_to_padded: (num_tokens,) index map from sorted position to padded row
        grouped_layout: expert layout tensor (format depends on GPU architecture)
        total_padded_rows: total number of rows including alignment padding
    """
    device = expert_ids_sorted.device
    num_tokens = expert_ids_sorted.size(0)
    # Number of rows routed to each expert.
    tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long()
    # Round every expert's row count up to the next multiple of `alignment`.
    aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment
    # Upper bound avoids GPU→CPU sync; padding rows are skipped by DeepGEMM.
    total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1)
    padding_per_expert = aligned_tokens_per_expert - tokens_per_expert
    # Exclusive prefix sum: padding inserted before each expert's group.
    cumulative_padding = padding_per_expert.cumsum(0) - padding_per_expert
    # Shift every sorted row by the padding that precedes its expert's group.
    sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted]
    if torch.cuda.get_device_capability(device)[0] >= 10:  # Blackwell (SM100+)
        # Cumulative (unpadded) token counts per expert.
        grouped_layout = tokens_per_expert.cumsum(0).int()
    else:
        # Hopper: per-row expert id, -1 for padding rows
        grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32)
        grouped_layout[sorted_to_padded] = expert_ids_sorted.int()
    return sorted_to_padded, grouped_layout, total_padded_rows
- def _pad_to_deepgemm_contiguous_layout(
- hidden_states: torch.Tensor,
- scales: torch.Tensor,
- sorted_to_padded: torch.Tensor,
- total_padded_rows: int,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """Pad sorted hidden states and scales into the TMA-aligned contiguous layout."""
- hidden_padded = torch.zeros(
- total_padded_rows, hidden_states.shape[1], device=hidden_states.device, dtype=hidden_states.dtype
- )
- hidden_padded[sorted_to_padded] = hidden_states
- scales_padded = torch.zeros(total_padded_rows, scales.shape[1], device=hidden_states.device, dtype=torch.float32)
- scales_padded[sorted_to_padded] = scales
- return hidden_padded, scales_padded
- def _unpad_from_deepgemm_contiguous_layout(
- hidden_states_padded: torch.Tensor, sorted_to_padded: torch.Tensor
- ) -> torch.Tensor:
- """Remove padding rows from the TMA-aligned contiguous layout."""
- return hidden_states_padded[sorted_to_padded]
def fp8_deepgemm_experts_forward(
    self: torch.nn.Module,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """Experts forward pass using DeepGEMM's contiguous grouped FP8 GEMM.

    The (token, selected expert) pairs are sorted by expert id, scattered into a padded
    layout in which each expert's rows are aligned to `_DEEPGEMM_M_ALIGNMENT`, run through
    the up and down projections, then un-padded and restored to the original order.

    Args:
        self: experts module providing the FP8 weights, scales, `block_size`, `num_experts`,
            `has_gate`, `act_fn` and `_apply_gate`.
        hidden_states: token activations, indexed along dim 0 by token.
        top_k_index: selected expert ids per token (last dim is num_top_k).
        top_k_weights: routing weights matching `top_k_index`.

    Returns:
        Tensor of shape (num_tokens, hidden_dim) in the dtype of `hidden_states`.

    Raises:
        NotImplementedError: if `self.activation_scheme` is "static".
        ValueError: if `self.block_size` is None or is not 128x128.
        ImportError: if the DeepGEMM kernel cannot be loaded.
    """
    if self.activation_scheme == "static":
        raise NotImplementedError(
            "deepgemm experts dispatch does not support activation_scheme='static'. "
            "Use the default eager dispatch or switch to activation_scheme='dynamic'."
        )
    if self.block_size is None:
        raise ValueError(
            "DeepGEMM requires block-wise quantization (block_size=[128, 128]), "
            "but got per-tensor quantization (block_size=None)."
        )
    if not (self.block_size[0] == 128 and self.block_size[1] == 128):
        raise ValueError(f"DeepGEMM requires block_size=(128, 128), got {self.block_size}")

    _load_deepgemm_kernel()

    device = hidden_states.device
    num_tokens, hidden_dim = hidden_states.size(0), hidden_states.size(-1)
    num_top_k = top_k_index.size(-1)

    # Flatten routing: row i of the flattened view belongs to token i // num_top_k.
    row_ids = torch.arange(num_tokens, device=device)
    token_idx = row_ids.unsqueeze(1).expand(-1, num_top_k).reshape(-1)  # (S,)
    flat_weights = top_k_weights.reshape(-1)  # (S,)
    flat_expert_ids = top_k_index.reshape(-1)  # (S,)

    # Sort by expert for grouped processing, and keep the inverse permutation to undo it.
    order = torch.argsort(flat_expert_ids)
    restore = torch.empty_like(order)
    restore[order] = torch.arange(order.size(0), device=device)

    sorted_expert_ids = flat_expert_ids[order]
    sorted_weights = flat_weights[order]
    sorted_inputs = hidden_states[token_idx[order]]

    # Build TMA-aligned contiguous layout for DeepGEMM
    sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout(
        sorted_expert_ids, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT
    )

    # --- Up projection per expert (DeepGEMM grouped contiguous) ---
    if self.has_gate:
        up_weight, up_scale_inv = self.gate_up_proj, self.gate_up_proj_scale_inv
    else:
        up_weight, up_scale_inv = self.up_proj, self.up_proj_scale_inv

    quant_inputs, input_scales = deepgemm_per_token_cast_to_fp8(sorted_inputs, use_ue8m0=False)
    quant_inputs, input_scales = _pad_to_deepgemm_contiguous_layout(
        quant_inputs, input_scales, sorted_to_padded, total_padded_rows
    )
    up_out = torch.zeros(total_padded_rows, up_weight.shape[1], device=device, dtype=torch.bfloat16)
    # Same architecture switch as the layout builder: compute capability major >= 10.
    use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10
    deepgemm_grouped_fp8_matmul(
        (quant_inputs, input_scales),
        (up_weight, up_scale_inv.float()),
        up_out,
        grouped_layout,
        use_psum_layout=use_psum_layout,
    )

    # Apply gating or activation
    activated = self._apply_gate(up_out) if self.has_gate else self.act_fn(up_out)

    # --- Down projection per expert (DeepGEMM grouped contiguous) ---
    quant_activated, activated_scales = deepgemm_per_token_cast_to_fp8(activated, use_ue8m0=False)
    down_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16)
    deepgemm_grouped_fp8_matmul(
        (quant_activated, activated_scales),
        (self.down_proj, self.down_proj_scale_inv.float()),
        down_out,
        grouped_layout,
        use_psum_layout=use_psum_layout,
    )

    # Remove padding rows
    down_out = _unpad_from_deepgemm_contiguous_layout(down_out, sorted_to_padded)

    # Apply routing weights, then restore the original (unsorted) row order.
    scaled = down_out * sorted_weights.to(down_out.dtype).unsqueeze(-1)  # (S, hidden_dim)
    scaled = scaled[restore]

    # Deterministic reshape+sum instead of index_add_
    # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd)
    combined = scaled.view(num_tokens, num_top_k, hidden_dim).sum(dim=1)
    return combined.to(hidden_states.dtype)
class FP8Experts(nn.Module):
    """Mixture-of-experts feed-forward weights stored in FP8 with per-expert inverse scales.

    Holds stacked per-expert parameters (leading dimension `num_experts`): either a fused
    `gate_up_proj` (when `has_gate`) or a plain `up_proj`, plus `down_proj`, each with a
    matching `*_scale_inv` tensor. With `activation_scheme="static"`, per-expert activation
    scales are stored as well.
    """

    def __init__(
        self,
        config,
        block_size: tuple[int, int] | None = None,
        activation_scheme: str = "dynamic",
        has_bias: bool = False,
        has_gate: bool = True,
        dtype=_FP8_DTYPE,
    ):
        """Allocate (uninitialized) expert weights and scales.

        Args:
            config: model config; reads `hidden_size`, `num_local_experts` or `num_experts`,
                `moe_intermediate_size` or `intermediate_size`, and `hidden_activation`
                or `hidden_act`.
            block_size: (block_out, block_in) tile size for block-wise scales, or None for
                one scale per expert weight.
            activation_scheme: "static" additionally allocates per-expert activation scales.
            has_bias: must be False (biases are not supported).
            has_gate: whether the up projection is a fused gate+up projection.
            dtype: dtype of the weight tensors.
        """
        super().__init__()

        assert has_bias is False, (
            "FP8Experts does not support bias for now, please open an issue if you want this feature"
        )

        self.config = config
        self.has_bias = has_bias
        self.has_gate = has_gate
        self.block_size = block_size
        self.hidden_dim = config.hidden_size
        self.activation_scheme = activation_scheme

        # Resolve fallbacks lazily. `getattr(config, "a", config.b)` evaluates `config.b`
        # eagerly, so it raises AttributeError for configs that define only `a`.
        self.num_experts = config.num_local_experts if hasattr(config, "num_local_experts") else config.num_experts
        self.intermediate_dim = (
            config.moe_intermediate_size if hasattr(config, "moe_intermediate_size") else config.intermediate_size
        )
        self.act_fn = ACT2FN[config.hidden_activation if hasattr(config, "hidden_activation") else config.hidden_act]

        if self.has_gate:
            # Fused gate + up projection: twice the intermediate width.
            gu_proj_out, gu_proj_in = 2 * self.intermediate_dim, self.hidden_dim
            self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, gu_proj_out, gu_proj_in, dtype=dtype))
            gu_scale_out, gu_scale_in = self._scale_grid_shape(gu_proj_out, gu_proj_in)
            self.gate_up_proj_scale_inv = nn.Parameter(
                torch.empty(self.num_experts, gu_scale_out, gu_scale_in, dtype=torch.float32)
            )
            self.register_parameter("gate_up_proj_bias", None)
        else:
            u_proj_out, u_proj_in = self.intermediate_dim, self.hidden_dim
            self.up_proj = nn.Parameter(torch.empty(self.num_experts, u_proj_out, u_proj_in, dtype=dtype))
            u_scale_out, u_scale_in = self._scale_grid_shape(u_proj_out, u_proj_in)
            self.up_proj_scale_inv = nn.Parameter(
                torch.empty(self.num_experts, u_scale_out, u_scale_in, dtype=torch.float32)
            )
            self.register_parameter("up_proj_bias", None)

        d_proj_out, d_proj_in = self.hidden_dim, self.intermediate_dim
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, d_proj_out, d_proj_in, dtype=dtype))
        d_scale_out, d_scale_in = self._scale_grid_shape(d_proj_out, d_proj_in)
        self.down_proj_scale_inv = nn.Parameter(
            torch.empty(self.num_experts, d_scale_out, d_scale_in, dtype=torch.float32)
        )
        self.register_parameter("down_proj_bias", None)

        if self.activation_scheme == "static":
            self.gate_up_proj_activation_scale = nn.Parameter(torch.ones(self.num_experts, dtype=torch.float32))
            self.down_proj_activation_scale = nn.Parameter(torch.ones(self.num_experts, dtype=torch.float32))

    def _scale_grid_shape(self, out_features: int, in_features: int) -> tuple[int, int]:
        """Return the per-expert scale grid shape for an (out_features, in_features) weight."""
        if self.block_size is None:
            # Per-tensor quantization: a single scale per expert weight.
            return 1, 1
        return triton.cdiv(out_features, self.block_size[0]), triton.cdiv(in_features, self.block_size[1])

    def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
        """Split a fused projection in two halves along the last dim and return act(gate) * up."""
        gate, up = gate_up.chunk(2, dim=-1)
        return self.act_fn(gate) * up

    def forward(
        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
    ) -> torch.Tensor:
        """Eager experts dispatch: loop over the experts that received at least one token.

        Args:
            hidden_states: token activations, indexed along dim 0 by token.
            top_k_index: selected expert ids per token.
            top_k_weights: routing weights, indexed as `[token, top_k_position]`.

        Returns:
            Tensor shaped like `hidden_states`, in the dtype of `hidden_states`.
        """
        # index_add_ will accumulate using the dtype of the tensor we write into
        # so we use float32 for the accumulation to avoid numerical issues in bf16/fp16
        final_hidden_states = torch.zeros_like(hidden_states, dtype=torch.float32)

        with torch.no_grad():
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            # Put the expert axis first so `expert_mask[expert_idx]` selects one expert's routing.
            expert_mask = expert_mask.permute(2, 1, 0)
            # Ids of experts that have at least one routed token.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).view(-1)

        for expert_idx in expert_hit:
            # Skip an out-of-range sentinel id equal to `num_experts`.
            if expert_idx == self.num_experts:
                continue

            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]

            gate_up_act_scale = (
                self.gate_up_proj_activation_scale[expert_idx] if self.activation_scheme == "static" else None
            )
            proj_out = self.linear(
                current_state,
                self.gate_up_proj[expert_idx] if self.has_gate else self.up_proj[expert_idx],
                self.gate_up_proj_scale_inv[expert_idx] if self.has_gate else self.up_proj_scale_inv[expert_idx],
                activation_scale=gate_up_act_scale,
            )
            proj_out = self._apply_gate(proj_out) if self.has_gate else self.act_fn(proj_out)

            down_act_scale = (
                self.down_proj_activation_scale[expert_idx] if self.activation_scheme == "static" else None
            )
            proj_out = self.linear(
                proj_out,
                self.down_proj[expert_idx],
                self.down_proj_scale_inv[expert_idx],
                activation_scale=down_act_scale,
            )

            routing_weights = top_k_weights[token_idx, top_k_pos, None]
            weighted_out = proj_out * routing_weights.to(proj_out.dtype)
            final_hidden_states.index_add_(0, token_idx, weighted_out.to(final_hidden_states.dtype))

        return final_hidden_states.to(hidden_states.dtype)

    def linear(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale_inv: torch.Tensor,
        activation_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply one expert's projection, quantizing `input` to FP8 when the weight is 1-byte.

        The static path is used only when the scheme is "static" and an `activation_scale`
        is given; otherwise activations are quantized dynamically with the Triton kernel.

        Returns:
            The projection output in the dtype of `input`.
        """
        # A weight wider than one byte per element is not FP8: use the regular linear path.
        if weight.element_size() > 1:
            return F.linear(input, weight, None)

        if self.activation_scheme == "static" and activation_scale is not None:
            scale = activation_scale.to(torch.float32)
            # Clamp to the FP8 range before casting.
            qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
        else:
            _load_triton_kernel()
            # Quantization group size: the K block size, or the full last dim for per-tensor mode.
            qinput, scale = triton_fp8_act_quant(
                input, self.block_size[1] if self.block_size is not None else input.shape[-1]
            )

        output = w8a8_fp8_matmul(
            qinput,
            weight,
            scale,
            weight_scale_inv,
            self.block_size,
            output_dtype=input.dtype,
        )
        return output.to(dtype=input.dtype)
class FP8ExpertsInterface(ExpertsInterface):
    """Interface for registering custom FP8 experts forward functions."""

    # Maps an experts implementation name to the FP8 forward function defined in this module.
    _global_mapping = {
        "batched_mm": fp8_batched_mm_experts_forward,
        "grouped_mm": fp8_grouped_mm_experts_forward,
        "deepgemm": fp8_deepgemm_experts_forward,
    }
# Module-level interface instance handed to `use_experts_implementation` when FP8 experts modules are built.
ALL_FP8_EXPERTS_FUNCTIONS = FP8ExpertsInterface()
def replace_with_fp8_linear(
    model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False
):
    """
    Swap eligible submodules of `model` for their FP8 counterparts, in place.

    Modules whose name ends with `.experts` are replaced by an `FP8Experts`-based class and
    `torch.nn.Linear` modules by `FP8Linear`. Replacement modules are created on the meta device.

    Parameters:
        model (`torch.nn.Module`):
            The model whose submodules are inspected and replaced.
        modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
            Names of the modules to not convert. In practice we keep the `lm_head` in full precision
            for numerical stability reasons.
        quantization_config:
            The quantization config object; its `dequantize`, `weight_block_size` and
            `activation_scheme` attributes are read.
        pre_quantized (`bool`, defaults to `False`):
            Whether the model is pre-quantized or not. When it is not, the replacement modules
            are built with `dtype=None`.

    Returns:
        The same `model` object (unchanged when `quantization_config.dequantize` is truthy).
    """
    if quantization_config.dequantize:
        return model

    # we need this to correctly materialize the weights during quantization
    module_kwargs = {} if pre_quantized else {"dtype": None}
    block_size = quantization_config.weight_block_size
    activation_scheme = quantization_config.activation_scheme

    has_been_replaced = False
    for module_name, module in model.named_modules():
        if not should_convert_module(module_name, modules_to_not_convert):
            continue

        replacement = None
        with torch.device("meta"):
            if module_name.endswith(".experts"):
                has_gate = getattr(module, "has_gate", True)
                has_bias = getattr(module, "has_bias", False)
                config = getattr(module, "config", model.config.get_text_config())
                experts_cls = use_experts_implementation(
                    experts_class=FP8Experts,
                    experts_interface=ALL_FP8_EXPERTS_FUNCTIONS,
                    has_bias=has_bias,
                    has_gate=has_gate,
                )
                replacement = experts_cls(
                    config=config,
                    block_size=block_size,
                    activation_scheme=activation_scheme,
                    has_bias=has_bias,
                    has_gate=has_gate,
                    **module_kwargs,
                )
            elif isinstance(module, nn.Linear):
                replacement = FP8Linear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    block_size=block_size,
                    activation_scheme=activation_scheme,
                    has_bias=module.bias is not None,
                    **module_kwargs,
                )

            if replacement is None:
                continue
            model.set_submodule(module_name, replacement)
            has_been_replaced = True

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using fp8 but no linear modules were found in your model."
            " Please double check your model architecture."
        )
    return model
class Fp8Quantize(ConversionOps):
    """
    A quantization operation that creates two tensors, weight and scale out of a weight.

    The last two dimensions of the weight are cut into `(block_m, block_n)` tiles. Each tile is scaled so
    that its largest absolute value lands on `_FP8_MAX`, cast to `_FP8_DTYPE`, and the per-tile inverse
    scale is returned alongside the quantized weight.
    """

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self, input_dict: dict[str, torch.Tensor | list[torch.Tensor]], **kwargs
    ) -> dict[str, torch.Tensor]:
        """
        Quantize the single weight found in `input_dict`.

        Parameters:
            input_dict (`dict`):
                Mapping whose first item is `(target key, weight)`; the weight may be a tensor or a
                list/tuple whose first element is the tensor. Only the first item is used.

        Returns:
            `dict[str, torch.Tensor]`: the quantized weight under the original key, and the float32 per-tile
            inverse scales of shape `(*leading, rows_tiles, cols_tiles)` under the matching `*_scale_inv` key.

        Raises:
            ValueError: if the last two dimensions are not divisible by the block size.
        """
        # Unpack single key/value. The value may be wrapped in a list; only unwrap in that case so that a
        # bare tensor is not silently sliced along its first dimension.
        target_keys, value = tuple(input_dict.items())[0]
        if isinstance(value, (list, tuple)):
            value = value[0]

        # Resolve block size (support dict-like or attr-like quant_config)
        quant_config = self.hf_quantizer.quantization_config
        block_size = None
        if quant_config is not None:
            if isinstance(quant_config, dict):
                block_size = quant_config.get("weight_block_size")
            else:
                block_size = getattr(quant_config, "weight_block_size", None)
        if block_size is None:
            # No block size configured: treat the whole matrix as one tile.
            block_size = (value.shape[-2], value.shape[-1])

        block_m, block_n = block_size
        rows, cols = value.shape[-2], value.shape[-1]

        # Enforce exact tiling: partial tiles are not supported.
        if rows % block_m != 0 or cols % block_n != 0:
            raise ValueError(
                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}). for {target_keys}"
            )

        # Leading dims can be empty (2D) or include num_experts/... (3D+)
        leading_shape = value.shape[:-2]
        rows_tiles = rows // block_m
        cols_tiles = cols // block_n

        original_shape = value.shape
        value_fp32 = value.to(torch.float32)

        # Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
        reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)

        # Per-tile max-abs over the block dims
        # dims: block_m is at -3, block_n is at -1 after the reshape
        max_abs = reshaped.abs().amax(dim=(-3, -1))
        # Substitute 1 for all-zero tiles so the division below never divides by zero.
        safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))

        # Tile scale; the inverse is what gets stored (as `weight_scale_inv`)
        scales = _FP8_MAX / safe_max_abs
        scales = torch.where(max_abs > 0, scales, torch.ones_like(scales))  # keep zeros stable

        # Broadcast scales back over the block dims and quantize
        # max_abs/scales shape: (..., rows_tiles, cols_tiles)
        scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3)  # -> (..., rows_tiles, 1, cols_tiles, 1)
        scaled = reshaped * scales_broadcast

        quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
        quantized = quantized.reshape(original_shape)

        inv_scales = (1.0 / scales).to(torch.float32)  # shape: (*leading, rows_tiles, cols_tiles)

        if target_keys.endswith("weight"):
            # Swap the last dotted component for `weight_scale_inv`. `rpartition` yields an empty
            # prefix and separator for a key without any dot, so a bare "weight" maps to
            # "weight_scale_inv" instead of the malformed "weight.weight_scale_inv".
            prefix, sep, _ = target_keys.rpartition(".")
            scale_key = f"{prefix}{sep}weight_scale_inv"
        else:
            scale_key = target_keys + "_scale_inv"

        # Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
        return {
            target_keys: quantized,
            scale_key: inv_scales,
        }
class Fp8Dequantize(ConversionOps):
    """
    Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and rebuilds the dequantized
    tensor, whose dtype is the dtype of the scales.
    """

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self,
        input_dict: dict[str, torch.Tensor],
        full_layer_name: str | None = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """
        Multiply every weight tile by its inverse scale.

        Parameters:
            input_dict (`dict`):
                Holds the quantized weight under `"weight$"` and the per-tile inverse scales under
                `"weight_scale_inv"`; the first element of each entry is used.
            full_layer_name (`str`, *optional*):
                Key under which the result is returned.

        Returns:
            `dict`: `{full_layer_name: tensor}`. When `input_dict` has fewer than two entries, the
            `"weight$"` entry is passed through unchanged.

        Raises:
            ValueError: if the last two dimensions are not divisible by the block size.
        """
        if len(input_dict) < 2:
            # case where we only got weights, need to check for "weight$"
            return {full_layer_name: input_dict["weight$"]}

        weight = input_dict["weight$"][0]
        inv_scale = input_dict["weight_scale_inv"][0]

        rows, cols = weight.shape[-2:]
        configured_block = self.hf_quantizer.quantization_config.weight_block_size
        # Without a configured block size the whole matrix is a single tile.
        block_m, block_n = (rows, cols) if configured_block is None else configured_block

        if rows % block_m or cols % block_n:
            raise ValueError(
                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
            )

        row_tiles = rows // block_m
        col_tiles = cols // block_n

        weight = weight.to(inv_scale.dtype)
        tiled = weight.reshape(-1, row_tiles, block_m, col_tiles, block_n)
        # (N, row_tiles, col_tiles) -> (N, row_tiles, 1, col_tiles, 1) so one scale covers one tile
        tile_scales = inv_scale.reshape(-1, row_tiles, col_tiles)[:, :, None, :, None]

        return {full_layer_name: (tiled * tile_scales).reshape(weight.shape)}

    @property
    def reverse_op(self) -> "ConversionOps":
        return _IdentityOp()
|