| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712 |
- # Copyright 2025 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from ..utils import is_torch_available, logging
- if is_torch_available():
- import torch
- from torch import nn
- from contextlib import contextmanager
- from ..core_model_loading import ConversionOps, _IdentityOp
- from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
- logger = logging.get_logger(__name__)
# Lookup table from a 4-bit code (used as the list index) to the float it encodes.
# Entries 8-15 mirror entries 0-7 with the sign flipped, so index 8 is negative zero.
FP4_VALUES = [
    # codes 0-7: non-negative values
    +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
    # codes 8-15: negated counterparts
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]
@contextmanager
def on_device(dev):
    """
    Context manager that makes `dev` the current accelerator device while the body runs.

    `dev` may be a tensor (its `.device` is used), a device string, or any object exposing a `type` attribute.
    CUDA devices are entered through `torch.cuda.device` and XPU devices through `torch.xpu.device` (only when
    this torch build has `torch.xpu`). For every other device, or when torch is not installed, the body runs
    without any device guard.
    """
    device_guard = None

    if is_torch_available():
        import torch

        # Normalize the argument to something that carries a `.type`
        if isinstance(dev, torch.Tensor):
            dev = dev.device
        elif isinstance(dev, str):
            dev = torch.device(dev)

        backend = getattr(dev, "type", None)
        if backend == "cuda":
            device_guard = torch.cuda.device(dev)
        elif backend == "xpu" and hasattr(torch, "xpu"):
            device_guard = torch.xpu.device(dev)

    if device_guard is None:
        # other: CPU
        yield
    else:
        with device_guard:
            yield
class Mxfp4Quantize(ConversionOps):
    """
    Conversion op that quantizes a loaded expert weight to MXFP4, swizzles it, and attaches the result directly
    onto the target `Mxfp4GptOssExperts` module.
    """

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self,
        input_dict: dict[str, torch.Tensor],
        model: torch.nn.Module | None = None,
        missing_keys: list[str] | None = None,
        full_layer_name: str | None = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """
        Quantize the first tensor of `input_dict` and store it on the module that owns `full_layer_name`.

        Args:
            input_dict: Mapping whose first value is the weight to quantize (or a list holding it first).
            model: Model in which `full_layer_name` is resolved.
            missing_keys: Collection from which `full_layer_name` is discarded once the module is populated.
            full_layer_name: Name of the parameter being loaded; selects `gate_up_proj` vs `down_proj`.

        Returns:
            Always an empty dict: the module is updated in place. Modules that are not `Mxfp4GptOssExperts`
            are left untouched.
        """
        value = tuple(input_dict.values())[0]
        if isinstance(value, list):
            value = value[0]

        module, _ = get_module_from_name(model, full_layer_name)

        with torch.device(value.device):
            if not isinstance(module, Mxfp4GptOssExperts):
                return {}

            quantized, scale = quantize_to_mxfp4(value.transpose(-1, -2), triton_kernels_hub)
            matmul_ogs = triton_kernels_hub.matmul_ogs
            PrecisionConfig = matmul_ogs.PrecisionConfig
            FlexCtx = matmul_ogs.FlexCtx
            InFlexData = matmul_ogs.InFlexData
            quantized, scale = swizzle_mxfp4(quantized, scale, triton_kernels_hub)

            proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj"

            # Remove the nn.Parameter registration so we can attach the Triton tensor
            if proj in module._parameters:
                del module._parameters[proj]
            setattr(module, proj, quantized)

            precision_config = PrecisionConfig(weight_scale=scale, flex_ctx=FlexCtx(rhs_data=InFlexData()))
            setattr(module, f"{proj}_precision_config", precision_config)

            # NOTE(review): `missing_keys` is annotated as a list but `.discard` is a set method
            missing_keys.discard(f"{full_layer_name}")
            module._is_hf_initialized = True

        return {}
class Mxfp4Dequantize(ConversionOps):
    """
    Conversion op that turns the packed `*_blocks` / `*_scales` tensors of one expert projection into a single
    dequantized parameter.
    """

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self,
        input_dict: dict[str, torch.Tensor],
        model: torch.nn.Module | None = None,
        full_layer_name: str | None = None,
        missing_keys: list[str] | None = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """
        Dequantize the blocks/scales pair found in `input_dict`.

        Args:
            input_dict: Holds `"{proj}_blocks"` and `"{proj}_scales"`, each a tensor or a list whose first item
                is the tensor.
            full_layer_name: Target parameter name; selects `gate_up_proj` vs `down_proj` and is used as the
                key of the returned dict.

        Returns:
            `{full_layer_name: dequantized_parameter}`.

        Raises:
            KeyError: If the blocks or the scales entry is absent from `input_dict`.
        """
        proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj"

        param_data = {}
        for suffix in ("blocks", "scales"):
            key = f"{proj}_{suffix}"
            if key in input_dict:
                entry = input_dict[key]
                param_data[key] = entry[0] if isinstance(entry, list) else entry

        # Here we are dequantizing the weights
        dequantized = dequantize_convertops(param_data[f"{proj}_blocks"], param_data[f"{proj}_scales"])
        return {full_layer_name: dequantized}

    @property
    def reverse_op(self) -> "ConversionOps":
        """The reverse of this op is the identity op."""
        return _IdentityOp()
class Mxfp4Deserialize(ConversionOps):
    """
    Conversion op that loads already-quantized `*_blocks` / `*_scales` tensors, swizzles them, and attaches them
    in place to the owning experts module.
    """

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self,
        input_dict: dict[str, torch.Tensor],
        model: torch.nn.Module | None = None,
        full_layer_name: str | None = None,
        missing_keys: list[str] | None = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """
        Swizzle the blocks/scales pair from `input_dict` onto the module that owns `full_layer_name`.

        Args:
            input_dict: Holds `"{proj}_blocks"` and `"{proj}_scales"`, each a tensor or a list whose first item
                is the tensor.
            model: Model in which `full_layer_name` is resolved.
            full_layer_name: Name of the parameter being loaded; selects `gate_up_proj` vs `down_proj`.
            missing_keys: Collection from which `full_layer_name` is discarded after loading.

        Returns:
            An empty dict (see the comment above the `return`).

        Raises:
            KeyError: If the blocks or the scales entry is absent from `input_dict`.
        """
        proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj"

        param_data = {}
        for suffix in ("blocks", "scales"):
            key = f"{proj}_{suffix}"
            if key in input_dict:
                entry = input_dict[key]
                param_data[key] = entry[0] if isinstance(entry, list) else entry

        # Eagerly set tensors on the module and perform swizzle
        module, _ = get_module_from_name(model, full_layer_name)
        blocks = param_data[f"{proj}_blocks"]
        scales = param_data[f"{proj}_scales"]
        swizzle_mxfp4_convertops(blocks, scales, module, proj, blocks.device, triton_kernels_hub)

        # NOTE(review): `missing_keys` is annotated as a list but `.discard` is a set method
        missing_keys.discard(f"{full_layer_name}")
        module._is_hf_initialized = True

        # The module was updated in place, so an empty mapping is returned. This prevents the loader from trying
        # to materialize the original meta-parameter names again. set_param_for_module is not used because it
        # mainly expects a torch.nn.Parameter or a safetensors pointer.
        return {}

    @property
    def reverse_op(self) -> ConversionOps:
        """The reverse op, bound to the same quantizer."""
        return Mxfp4ReverseDeserialize(self.hf_quantizer)
class Mxfp4ReverseDeserialize(ConversionOps):
    """
    Conversion op that undoes `Mxfp4Deserialize`: it reads the swizzled tensors stored on an `Mxfp4GptOssExperts`
    module and rebuilds the `*_blocks` / `*_scales` (or bias) entries of a state dict.
    """

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self,
        input_dict: dict[str, torch.Tensor],
        model: torch.nn.Module | None = None,
        full_layer_name: str | None = None,
        missing_keys: list[str] | None = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """
        Build the state-dict entries for the parameter named `full_layer_name`.

        Args:
            input_dict: Unused; the tensors are read from the module itself.
            model: Model in which `full_layer_name` is resolved. Its config supplies `num_local_experts`
                (default 32) and `hidden_size` (default 2880).
            full_layer_name: Parameter name; selects `gate_up_proj` vs `down_proj` and the bias branch.
            missing_keys: Unused.

        Returns:
            - `{}` when the owning module is not an `Mxfp4GptOssExperts`;
            - a single bias entry when `"bias"` is in `full_layer_name`;
            - otherwise `"{name}_blocks"` and `"{name}_scales"`, where `name` is `full_layer_name` with its
              last `_`-separated suffix removed.
        """
        # Packed layout allocated by `Mxfp4GptOssExperts.__init__`: (..., dim // 32, 16) uint8.
        values_per_block = 32
        bytes_per_block = 16

        num_local_experts = getattr(model.config, "num_local_experts", 32)
        hidden_size = getattr(model.config, "hidden_size", 2880)

        proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj"
        name = full_layer_name.rsplit("_", 1)[0]
        module, _ = get_module_from_name(model, full_layer_name)

        state_dict = {}
        if not isinstance(module, Mxfp4GptOssExperts):
            return state_dict

        if "bias" in full_layer_name:
            name = full_layer_name.replace("_blocks", "")
            state_dict[name] = getattr(module, proj + "_bias")
            return state_dict

        weight = getattr(module, proj)
        blocks = weight.storage.layout.unswizzle_data(weight.storage.data).transpose(-1, -2)
        if proj == "gate_up_proj":
            # (experts, 2 * intermediate, hidden // 32, 16); previously hard-coded as 90 == 2880 // 32
            blocks = blocks.reshape(num_local_experts, -1, hidden_size // values_per_block, bytes_per_block)
        else:
            # (experts, hidden, intermediate // 32, 16); the block count is inferred instead of fixed to 90
            blocks = blocks.reshape(num_local_experts, hidden_size, -1, bytes_per_block)

        weight_scale = getattr(module, f"{proj}_precision_config").weight_scale
        scales = weight_scale.storage.layout.unswizzle_data(weight_scale.storage.data).transpose(-1, -2)

        state_dict[f"{name}_blocks"] = blocks
        state_dict[f"{name}_scales"] = scales
        return state_dict
- # Copied from GPT_OSS repo and vllm
def quantize_to_mxfp4(w, triton_kernels_hub):
    """
    Quantize `w` with the kernel hub's `downcast_to_mxfp_torch`.

    `w` is cast to bfloat16 first; the downcast is requested with a `torch.uint8` output dtype along `axis=1`.
    Returns the `(quantized_weight, weight_scale)` pair produced by the hub function.
    """
    downcast = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp_torch
    quantized, scale = downcast(w.to(torch.bfloat16), torch.uint8, axis=1)
    return quantized, scale
def swizzle_mxfp4(w, w_scale, triton_kernels_hub):
    """
    Changes the layout of the tensors depending on the hardware

    The weight is wrapped as an FP4 tensor and converted to the hub's default matmul MXFP4 weight layout
    (`mx_axis=1`); the scale is wrapped and converted to `StridedLayout`. Returns `(w, w_scale)` in their new
    layouts.
    """
    tensor_api = triton_kernels_hub.tensor
    fp4_dtype = tensor_api.FP4
    convert_layout = tensor_api.convert_layout
    wrap_torch_tensor = tensor_api.wrap_torch_tensor

    layout_api = triton_kernels_hub.tensor_details.layout
    strided_layout = triton_kernels_hub.tensor_details.layout.StridedLayout

    weight_layout, weight_layout_opts = layout_api.make_default_matmul_mxfp4_w_layout(mx_axis=1)

    wrapped_weight = wrap_torch_tensor(w, dtype=fp4_dtype)
    w = convert_layout(wrapped_weight, weight_layout, **weight_layout_opts)

    wrapped_scale = wrap_torch_tensor(w_scale)
    w_scale = convert_layout(wrapped_scale, strided_layout)
    return w, w_scale
- # Mostly copied from GPT_OSS repo
- # TODO: Add absolute link when the repo is public
def _convert_moe_packed_tensors(
    blocks,
    scales,
    *,
    dtype: torch.dtype = torch.bfloat16,
    rows_per_chunk: int = 32768 * 1024,  # TODO these values are not here by mistake ;)
) -> torch.Tensor:
    """
    Dequantize packed mxfp4 weights into a dense tensor of `dtype`.

    Every byte of `blocks` holds two 4-bit codes that index `FP4_VALUES`: the low nibble fills the even output
    columns and the high nibble the odd ones. Each row is then scaled by `2 ** (scale - 127)` using `torch.ldexp`.
    Rows are processed in chunks of `rows_per_chunk` to bound the size of the temporary index tensors.

    Args:
        blocks: Packed codes of shape `(*prefix, G, B)`.
        scales: Per-row exponents of shape `(*prefix, G)`, stored with an offset of 127.
        dtype: Output dtype.
        rows_per_chunk: Number of flattened `(prefix, G)` rows handled per loop iteration.

    Returns:
        A contiguous tensor obtained by viewing the result as `(*prefix, G * B * 2)` and swapping dims 1 and 2.
    """
    import math

    blocks = blocks.to(torch.uint8)
    scales = scales.to(torch.int32) - 127  # remove the stored exponent offset

    assert blocks.shape[:-1] == scales.shape, f"{blocks.shape[:-1]=} does not match {scales.shape=}"

    fp4_lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)

    *prefix_shape, num_groups, bytes_per_group = blocks.shape
    total_rows = math.prod(prefix_shape) * num_groups

    blocks = blocks.reshape(total_rows, bytes_per_group)
    scales = scales.reshape(total_rows, 1)

    out = torch.empty(total_rows, bytes_per_group * 2, dtype=dtype, device=blocks.device)

    for start in range(0, total_rows, rows_per_chunk):
        stop = min(start + rows_per_chunk, total_rows)

        packed = blocks[start:stop]
        exponents = scales[start:stop]
        target = out[start:stop]  # a view: writes below land directly in `out`

        # The index tensors are only needed for the lookup but are huge in GPU memory, so each one is
        # released as soon as it has been consumed
        low_codes = (packed & 0x0F).to(torch.int)
        target[:, 0::2] = fp4_lut[low_codes]
        del low_codes

        high_codes = (packed >> 4).to(torch.int)
        target[:, 1::2] = fp4_lut[high_codes]
        del high_codes

        # Apply the per-row power-of-two scale in place
        torch.ldexp(target, exponents, out=target)
        del packed, exponents, target

    out = out.reshape(*prefix_shape, num_groups, bytes_per_group * 2)
    out = out.view(*prefix_shape, num_groups * bytes_per_group * 2)
    return out.transpose(1, 2).contiguous()
def convert_moe_packed_tensors(
    blocks,
    scales,
    *,
    dtype: torch.dtype = torch.bfloat16,
    rows_per_chunk: int = 32768 * 1024,  # TODO these values are not here by mistake ;)
) -> torch.Tensor:
    """
    Dequantize packed mxfp4 weights, retrying on CPU if the first attempt runs out of memory.

    This is a thin wrapper around `_convert_moe_packed_tensors`; all arguments are forwarded unchanged. When the
    first attempt raises `torch.OutOfMemoryError`, `blocks` and `scales` are moved to the CPU and the conversion
    is run again there, so the returned tensor lives on the CPU in that case.
    """
    convert_kwargs = {"dtype": dtype, "rows_per_chunk": rows_per_chunk}

    # The intermediate ops require A LOT of memory, so in very constrained device_map="auto" settings the
    # conversion may OOM. torch statistics are not accurate enough to predict this, because of fragmentation
    # and of in-place operations on non-contiguous tensors (which may need extra temporary copies).
    try:
        return _convert_moe_packed_tensors(blocks, scales, **convert_kwargs)
    except torch.OutOfMemoryError:
        # With a very tight device_map we convert and return on cpu - the result is put back on the correct
        # device later by the accelerate dispatch (moving it right away may still OOM, but more memory is
        # available later).
        return _convert_moe_packed_tensors(blocks.to("cpu"), scales.to("cpu"), **convert_kwargs)
class Mxfp4GptOssExperts(nn.Module):
    """
    Experts module whose `gate_up_proj` / `down_proj` weights are stored as packed uint8 blocks and multiplied
    with the `matmul_ogs` kernel of the triton kernels hub.

    The parameters allocated here are placeholders: the loading code replaces `gate_up_proj` / `down_proj` with
    swizzled kernel tensors and fills the two `*_precision_config` attributes.
    """

    def __init__(self, config):
        """
        Args:
            config: Object providing `num_local_experts`, `intermediate_size`, `hidden_size` and, optionally,
                `swiglu_limit` (defaults to 7.0).
        """
        super().__init__()

        self.num_experts = config.num_local_experts
        self.intermediate_size = config.intermediate_size
        self.hidden_size = config.hidden_size

        # Packed weights: the last two dims are (number of 32-value blocks, 16 bytes per block)
        self.gate_up_proj = nn.Parameter(
            torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8),
            requires_grad=False,
        )
        self.gate_up_proj_bias = nn.Parameter(
            torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False
        )

        self.down_proj = nn.Parameter(
            torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8),
            requires_grad=False,
        )
        self.down_proj_bias = nn.Parameter(
            torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False
        )

        # Arguments of the fused swiglu activation (passed as "alpha" and "limit" in `forward`)
        self.alpha = 1.702
        self.limit = getattr(config, "swiglu_limit", 7.0)

        # Filled in when the quantized weights are loaded
        self.gate_up_proj_precision_config = None
        self.down_proj_precision_config = None

    def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
        """
        Run the two expert matmuls on `hidden_states`.

        The first `matmul_ogs` call gathers with `gather_idx`, applies `gate_up_proj` and its bias, and fuses the
        swiglu activation. The second applies `down_proj` and its bias, scatters with `scatter_idx`, and passes
        `routing_data.gate_scal` as `gammas`. Both calls run inside `on_device(hidden_states.device)`.

        Returns:
            The output of the second matmul.
        """
        matmul_api = triton_kernels_hub.matmul_ogs
        FnSpecs = matmul_api.FnSpecs
        FusedActivation = matmul_api.FusedActivation
        matmul_ogs = matmul_api.matmul_ogs
        swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn

        with on_device(hidden_states.device):
            act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2)

            intermediate_cache1 = matmul_ogs(
                hidden_states,
                self.gate_up_proj,
                self.gate_up_proj_bias.to(torch.float32),
                routing_data,
                gather_indx=gather_idx,
                precision_config=self.gate_up_proj_precision_config,
                gammas=None,
                fused_activation=act,
            )

            intermediate_cache3 = matmul_ogs(
                intermediate_cache1,
                self.down_proj,
                self.down_proj_bias.to(torch.float32),
                routing_data,
                scatter_indx=scatter_idx,
                precision_config=self.down_proj_precision_config,
                gammas=routing_data.gate_scal,
            )
        return intermediate_cache3
- # Adapted from GPT_OSS repo
- # TODO: Add absolute link when the repo is public
def routing_torch_dist(
    logits,
    n_expts_act,
):
    """
    Compute routing metadata restricted to the experts owned by the current rank.

    The experts (dim 1 of `logits`) are split evenly across `torch.distributed.get_world_size()` ranks; the rank
    is read from the `LOCAL_RANK` environment variable (default 0). For every token the top `n_expts_act`
    experts are selected and their logits softmax-normalized. Entries routed to experts outside the local range
    are marked with -1 in the gather/scatter indices.

    Returns:
        A `(RoutingData, GatherIndx, ScatterIndx)` tuple built with the classes of `triton_kernels_hub.routing`.
    """
    import os

    GatherIndx = triton_kernels_hub.routing.GatherIndx
    RoutingData = triton_kernels_hub.routing.RoutingData
    ScatterIndx = triton_kernels_hub.routing.ScatterIndx
    compute_expt_data_torch = triton_kernels_hub.routing.compute_expt_data_torch

    with on_device(logits.device):
        world_size = torch.distributed.get_world_size()
        rank = int(os.environ.get("LOCAL_RANK", "0"))
        masked = -1

        n_tokens, n_expts_tot = logits.shape[0], logits.shape[1]

        n_local_experts = n_expts_tot // world_size
        local_start = rank * n_local_experts
        local_end = (rank + 1) * n_local_experts

        n_gates_pad = n_tokens * n_expts_act

        # Top-k selection through a stable descending argsort
        top_indx = torch.argsort(-logits, dim=1, stable=True)[:, :n_expts_act]
        top_indx = top_indx.long()
        expt_scal = torch.take_along_dim(logits, top_indx, dim=1)
        expt_indx = top_indx.int()

        expt_scal = torch.softmax(expt_scal, dim=-1)
        # Order each token's experts by expert id and keep the scales aligned with them
        expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
        expt_scal = torch.gather(expt_scal, 1, sort_indices)

        # Flatten and mask for local experts
        expt_scal = expt_scal.reshape(-1)

        hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1)[local_start:local_end]

        expt_indx = expt_indx.view(-1).to(torch.int32)

        # Experts below the local range are overwritten with a large value
        # NOTE(review): 1000 looks like it must exceed every real expert id — confirm for larger expert counts
        large_value = 1000
        expt_indx = torch.where(expt_indx < local_start, large_value, expt_indx)
        topk_indx = torch.argsort(expt_indx, stable=True).to(torch.int32)
        gate_indx = torch.argsort(topk_indx).to(torch.int32)
        expt_indx = torch.where(expt_indx < local_end, expt_indx, masked)
        expt_indx = torch.where(local_start <= expt_indx, expt_indx, masked)
        gate_indx = torch.where(expt_indx == masked, masked, gate_indx)
        gate_scal = expt_scal[topk_indx]

        topk_indx = torch.where(gate_indx[topk_indx] == masked, masked, topk_indx)

        # Routing metadata for local expert computation
        gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
        scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())

        expt_data = compute_expt_data_torch(hist, n_local_experts, n_gates_pad)

    routing_data = RoutingData(gate_scal, hist, n_local_experts, n_expts_act, expt_data)
    return routing_data, gather_indx, scatter_indx
def mlp_forward(self, hidden_states):
    """
    Forward pass bound onto `GptOssMLP` modules: route the tokens, then run the experts.

    `routing_torch_dist` is used when torch.distributed is available and initialized and the module has an
    `_is_hooked` attribute; otherwise the routing function of the triton kernels hub is used.

    Returns:
        `(routed_out, router_logits)`, with `routed_out` reshaped to `(batch_size, -1, router.hidden_dim)`.
    """
    import torch.distributed as dist

    use_dist_routing = dist.is_available() and dist.is_initialized() and hasattr(self, "_is_hooked")
    routing = routing_torch_dist if use_dist_routing else triton_kernels_hub.routing.routing

    router = self.router
    batch_size = hidden_states.shape[0]
    # Flatten to one row per token before computing the router logits
    hidden_states = hidden_states.reshape(-1, router.hidden_dim)
    router_logits = nn.functional.linear(hidden_states, router.weight, router.bias)

    with on_device(router_logits.device):
        routing_data, gather_idx, scatter_idx = routing(router_logits, router.top_k)

    routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx=scatter_idx)
    return routed_out.reshape(batch_size, -1, router.hidden_dim), router_logits
def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):
    """
    Attach `param_value` to `module` and, once both halves of a projection are present, dequantize it.

    For each of `gate_up_proj` / `down_proj` that appears in `param_name`, the value (sharded first when a
    `device_mesh` is given) is set on `module` under the last dotted component of `param_name`. When the module
    then has both `"{proj}_blocks"` and `"{proj}_scales"`, they are converted with `convert_moe_packed_tensors`,
    stored as the `proj` parameter on `target_device`, and the two packed attributes are deleted.

    Keyword arguments read: `model`, `empty_param`, `casting_dtype`, `to_contiguous`, `rank`, `device_mesh`.
    """
    from ..integrations.tensor_parallel import shard_and_distribute_module

    device_mesh = kwargs.get("device_mesh")
    shard_args = (
        kwargs.get("model"),
        kwargs.get("empty_param"),
        kwargs.get("casting_dtype"),
        kwargs.get("to_contiguous"),
        kwargs.get("rank"),
    )
    model, empty_param, casting_dtype, to_contiguous, rank = shard_args

    for proj in ("gate_up_proj", "down_proj"):
        if proj not in param_name:
            continue

        if device_mesh is not None:
            param_value = shard_and_distribute_module(
                model,
                param_value,
                empty_param,
                dq_param_name,
                casting_dtype,
                to_contiguous,
                rank,
                device_mesh,
            )

        blocks_attr, scales_attr = f"{proj}_blocks", f"{proj}_scales"
        setattr(module, param_name.rsplit(".", 1)[1], param_value)

        # Only dequantize once both the blocks and the scales have been attached
        if not (hasattr(module, blocks_attr) and hasattr(module, scales_attr)):
            continue

        dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
        setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device)))
        delattr(module, blocks_attr)
        delattr(module, scales_attr)
def dequantize_convertops(blocks, scales):
    """
    Dequantize a blocks/scales pair with `convert_moe_packed_tensors` and wrap the result in a
    `torch.nn.Parameter`.
    """
    return torch.nn.Parameter(convert_moe_packed_tensors(blocks, scales))
def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, triton_kernels_hub, **kwargs):
    """
    This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`.

    `param_value` is first attached to `module` (directly, or through `shard_and_distribute_module` when a
    `device_mesh` is given). Once neither the blocks nor the scales of the projection are on the meta device,
    they are reshaped, moved to the target device, swizzled, stored as `module.<proj>` together with
    `module.<proj>_precision_config`, and the `*_blocks` / `*_scales` attributes are deleted.

    Keyword arguments read: `model`, `empty_param`, `casting_dtype`, `to_contiguous`, `rank`, `device_mesh`.
    """
    matmul_api = triton_kernels_hub.matmul_ogs
    PrecisionConfig = matmul_api.PrecisionConfig
    FlexCtx = matmul_api.FlexCtx
    InFlexData = matmul_api.InFlexData

    from ..integrations.tensor_parallel import shard_and_distribute_module

    model = kwargs.get("model")
    empty_param = kwargs.get("empty_param")
    casting_dtype = kwargs.get("casting_dtype")
    to_contiguous = kwargs.get("to_contiguous")
    rank = kwargs.get("rank")
    device_mesh = kwargs.get("device_mesh")

    # NOTE(review): `proj` stays unbound when `param_name` contains neither "blocks" nor "scales"
    leaf_name = param_name.split(".")[-1]
    if "blocks" in param_name:
        proj = leaf_name.split("_blocks")[0]
    if "scales" in param_name:
        proj = leaf_name.split("_scales")[0]

    if device_mesh is None:
        setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False))
    else:
        shard_and_distribute_module(
            model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh
        )

    blocks_attr, scales_attr = f"{proj}_blocks", f"{proj}_scales"
    blocks = getattr(module, blocks_attr)  # at this point values were loaded from ckpt
    scales = getattr(module, scales_attr)

    # Nothing to swizzle until both the blocks and the scales have left the meta device
    if blocks.device.type == "meta" or scales.device.type == "meta":
        return

    is_gate_up = proj == "gate_up_proj"
    local_experts = blocks.size(0)
    if is_gate_up:
        blocks = blocks.reshape(local_experts, module.intermediate_size * 2, -1)
    else:
        blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)

    # A "cpu" target is redirected to the current accelerator when torch reports one
    targets_cpu = getattr(target_device, "type", target_device) == "cpu"
    if targets_cpu and hasattr(torch, "accelerator") and torch.accelerator.current_accelerator() is not None:
        target_device = torch.accelerator.current_accelerator().type

    blocks = blocks.to(target_device).contiguous()
    scales = scales.to(target_device).contiguous()

    with on_device(target_device):
        triton_weight_tensor, weight_scale = swizzle_mxfp4(
            blocks.transpose(-2, -1), scales.transpose(-2, -1), triton_kernels_hub
        )

    # need to overwrite the shapes for the kernels
    if is_gate_up:
        kernel_shape = [local_experts, module.hidden_size, module.intermediate_size * 2]
    else:
        kernel_shape = [local_experts, module.intermediate_size, module.hidden_size]
    triton_weight_tensor.shape = torch.Size(kernel_shape)

    # triton_weight_tensor is what needs to be passed to the oai kernels: it carries the data, the shapes and
    # other metadata, a bit like a tensor subclass
    setattr(module, proj, triton_weight_tensor)
    precision_config = PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()))
    setattr(module, f"{proj}_precision_config", precision_config)

    # delete blocks and scales
    delattr(module, scales_attr)
    delattr(module, blocks_attr)
def swizzle_mxfp4_convertops(blocks, scales, module, proj, target_device, triton_kernels_hub):
    """
    This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`.

    `blocks` and `scales` are moved to the target device, the blocks are reshaped, both are swizzled, and the
    result is attached to `module` as `<proj>` and `<proj>_precision_config`. Nothing is returned.
    """
    matmul_api = triton_kernels_hub.matmul_ogs
    PrecisionConfig = matmul_api.PrecisionConfig
    FlexCtx = matmul_api.FlexCtx
    InFlexData = matmul_api.InFlexData

    is_gate_up = proj == "gate_up_proj"
    local_experts = blocks.size(0)

    # A "cpu" target is redirected to the current accelerator when torch reports one
    targets_cpu = getattr(target_device, "type", target_device) == "cpu"
    if targets_cpu and hasattr(torch, "accelerator") and torch.accelerator.current_accelerator() is not None:
        target_device = torch.accelerator.current_accelerator().type

    blocks = blocks.to(target_device).contiguous()
    scales = scales.to(target_device).contiguous()

    if is_gate_up:
        blocks = blocks.reshape(local_experts, module.intermediate_size * 2, -1)
    else:
        blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)

    with on_device(target_device):
        triton_weight_tensor, weight_scale = swizzle_mxfp4(
            blocks.transpose(-2, -1), scales.transpose(-2, -1), triton_kernels_hub
        )

    # need to overwrite the shapes for the kernels
    if is_gate_up:
        kernel_shape = [local_experts, module.hidden_size, module.intermediate_size * 2]
    else:
        kernel_shape = [local_experts, module.intermediate_size, module.hidden_size]
    triton_weight_tensor.shape = torch.Size(kernel_shape)

    # triton_weight_tensor is what needs to be passed to the oai kernels: it carries the data, the shapes and
    # other metadata. The Experts module registers gate_up_proj and down_proj as nn.Parameters, so that
    # registration is removed first to let the Triton tensor be attached under the same name.
    if proj in module._parameters:
        del module._parameters[proj]
    setattr(module, proj, triton_weight_tensor)

    precision_config = PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()))
    setattr(module, f"{proj}_precision_config", precision_config)
- def replace_with_mxfp4_linear(model, quantization_config=None, modules_to_not_convert: list[str] | None = None):
- """
- Public method that replaces the expert layers of the given model with mxfp4 quantized layers.
- Args:
- model (`torch.nn.Module`):
- The model to convert, can be any `torch.nn.Module` instance.
- quantization_config (`Mxfp4Config`, defaults to `None`):
- The quantization config object that contains the quantization parameters.
- modules_to_not_convert (`list`, *optional*, defaults to `None`):
- A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
- converted.
- """
- if quantization_config.dequantize:
- return model
- from .hub_kernels import get_kernel
- global triton_kernels_hub
- triton_kernels_hub = get_kernel("kernels-community/gpt-oss-triton-kernels")
- has_been_replaced = False
- for module_name, module in model.named_modules():
- if not should_convert_module(module_name, modules_to_not_convert):
- continue
- if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize:
- with torch.device("meta"):
- model.set_submodule(module_name, Mxfp4GptOssExperts(model.config))
- has_been_replaced = True
- if module.__class__.__name__ == "GptOssMLP" and not quantization_config.dequantize:
- from types import MethodType
- module.forward = MethodType(mlp_forward, module)
- if not has_been_replaced:
- logger.warning(
- "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model."
- " Please double check your model architecture, or submit an issue on github if you think this is"
- " a bug."
- )
- return model
|