mxfp4.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ..utils import is_torch_available, logging
  15. if is_torch_available():
  16. import torch
  17. from torch import nn
  18. from contextlib import contextmanager
  19. from ..core_model_loading import ConversionOps, _IdentityOp
  20. from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
# Module-level logger obtained from the package's `utils.logging` helper.
logger = logging.get_logger(__name__)
  22. FP4_VALUES = [
  23. +0.0,
  24. +0.5,
  25. +1.0,
  26. +1.5,
  27. +2.0,
  28. +3.0,
  29. +4.0,
  30. +6.0,
  31. -0.0,
  32. -0.5,
  33. -1.0,
  34. -1.5,
  35. -2.0,
  36. -3.0,
  37. -4.0,
  38. -6.0,
  39. ]
  40. @contextmanager
  41. def on_device(dev):
  42. if is_torch_available():
  43. import torch
  44. if isinstance(dev, torch.Tensor):
  45. dev = dev.device
  46. elif isinstance(dev, str):
  47. dev = torch.device(dev)
  48. dev_type = getattr(dev, "type", None)
  49. if dev_type == "cuda":
  50. with torch.cuda.device(dev):
  51. yield
  52. return
  53. if dev_type == "xpu" and hasattr(torch, "xpu"):
  54. with torch.xpu.device(dev):
  55. yield
  56. return
  57. # other: CPU
  58. yield
  59. class Mxfp4Quantize(ConversionOps):
  60. def __init__(self, hf_quantizer):
  61. self.hf_quantizer = hf_quantizer
  62. def convert(
  63. self,
  64. input_dict: dict[str, torch.Tensor],
  65. model: torch.nn.Module | None = None,
  66. missing_keys: list[str] | None = None,
  67. full_layer_name: str | None = None,
  68. **kwargs,
  69. ) -> dict[str, torch.Tensor]:
  70. _, value = tuple(input_dict.items())[0]
  71. value = value[0] if isinstance(value, list) else value
  72. module, _ = get_module_from_name(model, full_layer_name)
  73. with torch.device(value.device):
  74. if isinstance(module, Mxfp4GptOssExperts):
  75. triton_weight_tensor, weight_scale = quantize_to_mxfp4(value.transpose(-1, -2), triton_kernels_hub)
  76. PrecisionConfig, FlexCtx, InFlexData = (
  77. triton_kernels_hub.matmul_ogs.PrecisionConfig,
  78. triton_kernels_hub.matmul_ogs.FlexCtx,
  79. triton_kernels_hub.matmul_ogs.InFlexData,
  80. )
  81. triton_weight_tensor, weight_scale = swizzle_mxfp4(
  82. triton_weight_tensor, weight_scale, triton_kernels_hub
  83. )
  84. proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj"
  85. if proj in module._parameters:
  86. # Remove the nn.Parameter registration so we can attach the Triton tensor
  87. del module._parameters[proj]
  88. setattr(module, proj, triton_weight_tensor)
  89. setattr(
  90. module,
  91. f"{proj}_precision_config",
  92. PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
  93. )
  94. missing_keys.discard(f"{full_layer_name}")
  95. module._is_hf_initialized = True
  96. return {}
  97. class Mxfp4Dequantize(ConversionOps):
  98. def __init__(self, hf_quantizer):
  99. self.hf_quantizer = hf_quantizer
  100. def convert(
  101. self,
  102. input_dict: dict[str, torch.Tensor],
  103. model: torch.nn.Module | None = None,
  104. full_layer_name: str | None = None,
  105. missing_keys: list[str] | None = None,
  106. **kwargs,
  107. ) -> dict[str, torch.Tensor]:
  108. param_data = {}
  109. proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj"
  110. if f"{proj}_blocks" in input_dict.keys():
  111. if isinstance(input_dict[f"{proj}_blocks"], list):
  112. param_data[f"{proj}_blocks"] = input_dict[f"{proj}_blocks"][0]
  113. else:
  114. param_data[f"{proj}_blocks"] = input_dict[f"{proj}_blocks"]
  115. if f"{proj}_scales" in input_dict.keys():
  116. if isinstance(input_dict[f"{proj}_scales"], list):
  117. param_data[f"{proj}_scales"] = input_dict[f"{proj}_scales"][0]
  118. else:
  119. param_data[f"{proj}_scales"] = input_dict[f"{proj}_scales"]
  120. # Here we are dequantizing the weights
  121. dequantized = dequantize_convertops(param_data[f"{proj}_blocks"], param_data[f"{proj}_scales"])
  122. return {full_layer_name: dequantized}
  123. @property
  124. def reverse_op(self) -> "ConversionOps":
  125. return _IdentityOp()
  126. class Mxfp4Deserialize(ConversionOps):
  127. def __init__(self, hf_quantizer):
  128. self.hf_quantizer = hf_quantizer
  129. def convert(
  130. self,
  131. input_dict: dict[str, torch.Tensor],
  132. model: torch.nn.Module | None = None,
  133. full_layer_name: str | None = None,
  134. missing_keys: list[str] | None = None,
  135. **kwargs,
  136. ) -> dict[str, torch.Tensor]:
  137. param_data = {}
  138. proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj"
  139. if f"{proj}_blocks" in input_dict.keys():
  140. if isinstance(input_dict[f"{proj}_blocks"], list):
  141. param_data[f"{proj}_blocks"] = input_dict[f"{proj}_blocks"][0]
  142. else:
  143. param_data[f"{proj}_blocks"] = input_dict[f"{proj}_blocks"]
  144. if f"{proj}_scales" in input_dict.keys():
  145. if isinstance(input_dict[f"{proj}_scales"], list):
  146. param_data[f"{proj}_scales"] = input_dict[f"{proj}_scales"][0]
  147. else:
  148. param_data[f"{proj}_scales"] = input_dict[f"{proj}_scales"]
  149. # Eagerly set tensors on the module and perform swizzle
  150. module, _ = get_module_from_name(model, full_layer_name)
  151. swizzle_mxfp4_convertops(
  152. param_data[f"{proj}_blocks"],
  153. param_data[f"{proj}_scales"],
  154. module,
  155. proj,
  156. param_data[f"{proj}_blocks"].device,
  157. triton_kernels_hub,
  158. )
  159. missing_keys.discard(f"{full_layer_name}")
  160. module._is_hf_initialized = True
  161. # We return an empty mapping since the module was updated in-place. This prevents
  162. # the loader from trying to materialize the original meta-parameter names again.
  163. # We don't use set_param_for_module since it expects mainly a torch.nn.Parameter or a safetensors pointer
  164. return {}
  165. @property
  166. def reverse_op(self) -> ConversionOps:
  167. return Mxfp4ReverseDeserialize(self.hf_quantizer)
  168. class Mxfp4ReverseDeserialize(ConversionOps):
  169. def __init__(self, hf_quantizer):
  170. self.hf_quantizer = hf_quantizer
  171. def convert(
  172. self,
  173. input_dict: dict[str, torch.Tensor],
  174. model: torch.nn.Module | None = None,
  175. full_layer_name: str | None = None,
  176. missing_keys: list[str] | None = None,
  177. **kwargs,
  178. ) -> dict[str, torch.Tensor]:
  179. num_local_experts = getattr(model.config, "num_local_experts", 32)
  180. hidden_size = getattr(model.config, "hidden_size", 2880)
  181. proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj"
  182. name = full_layer_name.rsplit("_", 1)[0]
  183. module, _ = get_module_from_name(model, full_layer_name)
  184. state_dict = {}
  185. if isinstance(module, Mxfp4GptOssExperts):
  186. if "bias" in full_layer_name:
  187. name = full_layer_name.replace("_blocks", "")
  188. state_dict[name] = getattr(module, proj + "_bias")
  189. return state_dict
  190. if "gate_up_proj" in full_layer_name:
  191. state_dict[f"{name}_blocks"] = (
  192. module.gate_up_proj.storage.layout.unswizzle_data(module.gate_up_proj.storage.data)
  193. .transpose(-1, -2)
  194. .reshape(num_local_experts, -1, 90, 16)
  195. )
  196. state_dict[f"{name}_scales"] = (
  197. module.gate_up_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
  198. module.gate_up_proj_precision_config.weight_scale.storage.data
  199. ).transpose(-1, -2)
  200. )
  201. else:
  202. state_dict[f"{name}_blocks"] = (
  203. module.down_proj.storage.layout.unswizzle_data(module.down_proj.storage.data)
  204. .transpose(-1, -2)
  205. .reshape(num_local_experts, hidden_size, 90, -1)
  206. )
  207. state_dict[f"{name}_scales"] = (
  208. module.down_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
  209. module.down_proj_precision_config.weight_scale.storage.data
  210. ).transpose(-1, -2)
  211. )
  212. return state_dict
  213. # Copied from GPT_OSS repo and vllm
  214. def quantize_to_mxfp4(w, triton_kernels_hub):
  215. downcast_to_mxfp_torch = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp_torch
  216. w, w_scale = downcast_to_mxfp_torch(w.to(torch.bfloat16), torch.uint8, axis=1)
  217. return w, w_scale
  218. def swizzle_mxfp4(w, w_scale, triton_kernels_hub):
  219. """
  220. Changes the layout of the tensors depending on the hardware
  221. """
  222. FP4, convert_layout, wrap_torch_tensor = (
  223. triton_kernels_hub.tensor.FP4,
  224. triton_kernels_hub.tensor.convert_layout,
  225. triton_kernels_hub.tensor.wrap_torch_tensor,
  226. )
  227. layout = triton_kernels_hub.tensor_details.layout
  228. StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout
  229. value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
  230. w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
  231. w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout)
  232. return w, w_scale
# Mostly copied from GPT_OSS repo
# TODO: Add absolute link when the repo is public
def _convert_moe_packed_tensors(
    blocks,
    scales,
    *,
    dtype: torch.dtype = torch.bfloat16,
    rows_per_chunk: int = 32768 * 1024,  # TODO these values are not here by mistake ;)
) -> torch.Tensor:
    """
    Dequantize packed MXFP4 expert weights into a dense tensor compatible with the GPT_OSS forward pass.

    Each byte of `blocks` holds two 4-bit codes. The low nibble goes to the even output position and the high
    nibble to the odd one, and each code indexes `FP4_VALUES`. Every row of `blocks` (its last dim) shares one
    power-of-two scale taken from `scales`.

    Args:
        blocks: packed codes of shape `(*prefix, G, B)`; cast to uint8.
        scales: per-row exponents whose shape must equal `blocks.shape[:-1]`, stored with a +127 bias.
        dtype: dtype of the lookup table and of the output.
        rows_per_chunk: number of `(G, B)` rows processed per iteration; bounds the size of the temporary
            index tensors.

    Returns:
        The dense tensor after unpacking to `(*prefix, G * B * 2)` and swapping dims 1 and 2. With a
        two-dim prefix `(E, R)` this is `(E, G * B * 2, R)`. Assumes `prefix` has at least two dims — TODO
        confirm for any new caller.
    """
    import math

    blocks = blocks.to(torch.uint8)
    # Remove the exponent bias of 127 so each scale becomes a signed power-of-two exponent for `torch.ldexp`.
    scales = scales.to(torch.int32) - 127

    assert blocks.shape[:-1] == scales.shape, f"{blocks.shape[:-1]=} does not match {scales.shape=}"

    # One entry per 4-bit code; indexing this table with the nibbles decodes them.
    lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)

    *prefix_shape, G, B = blocks.shape
    rows_total = math.prod(prefix_shape) * G

    # Flatten to one row per scale so each chunk pairs a slice of packed rows with its exponents.
    blocks = blocks.reshape(rows_total, B)
    scales = scales.reshape(rows_total, 1)

    out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)

    for r0 in range(0, rows_total, rows_per_chunk):
        r1 = min(r0 + rows_per_chunk, rows_total)

        blk = blocks[r0:r1]
        exp = scales[r0:r1]
        # `sub` is a view into `out`, so the writes below fill `out` directly.
        sub = out[r0:r1]

        # This vector is only used to index into `lut`, but is huge in GPU memory, so we delete it immediately
        idx_lo = (blk & 0x0F).to(torch.int)
        sub[:, 0::2] = lut[idx_lo]
        del idx_lo

        # This vector is only used to index into `lut`, but is huge in GPU memory, so we delete it immediately
        idx_hi = (blk >> 4).to(torch.int)
        sub[:, 1::2] = lut[idx_hi]
        del idx_hi

        # Apply the shared scale in place: sub = sub * 2**exp
        torch.ldexp(sub, exp, out=sub)
        del blk, exp, sub

    out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
    return out.transpose(1, 2).contiguous()
  274. def convert_moe_packed_tensors(
  275. blocks,
  276. scales,
  277. *,
  278. dtype: torch.dtype = torch.bfloat16,
  279. rows_per_chunk: int = 32768 * 1024, # TODO these values are not here by mistake ;)
  280. ) -> torch.Tensor:
  281. """
  282. Convert the mxfp4 weights again, dequantizing and makes them compatible with the forward
  283. pass of GPT_OSS.
  284. """
  285. # Since the intermediate ops requite A LOT of memory, in very constrained device_map="auto" settings
  286. # it may OOM, hence this wrapper and move back to cpu if needed
  287. # torch statistics are not accurate enough to estimate if we will have enough memory due to fragmentation and
  288. # in-place operation on non-contiguous tensors (may sometimes require more temporary copies)
  289. try:
  290. return _convert_moe_packed_tensors(blocks, scales, dtype=dtype, rows_per_chunk=rows_per_chunk)
  291. # In the case of OOM due to very tight device_map, we convert and return on cpu - it will then be put back on correct
  292. # devide with the accelerate dispatch (doing it right away may still lead to OOM, but more memory is available later)
  293. except torch.OutOfMemoryError:
  294. blocks = blocks.to("cpu")
  295. scales = scales.to("cpu")
  296. return _convert_moe_packed_tensors(blocks, scales, dtype=dtype, rows_per_chunk=rows_per_chunk)
  297. class Mxfp4GptOssExperts(nn.Module):
  298. def __init__(self, config):
  299. super().__init__()
  300. self.num_experts = config.num_local_experts
  301. self.intermediate_size = config.intermediate_size
  302. self.hidden_size = config.hidden_size
  303. self.gate_up_proj = nn.Parameter(
  304. torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8),
  305. requires_grad=False,
  306. )
  307. self.gate_up_proj_bias = nn.Parameter(
  308. torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False
  309. )
  310. self.down_proj = nn.Parameter(
  311. torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8),
  312. requires_grad=False,
  313. )
  314. self.down_proj_bias = nn.Parameter(
  315. torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False
  316. )
  317. self.alpha = 1.702
  318. self.limit = getattr(config, "swiglu_limit", 7.0)
  319. self.gate_up_proj_precision_config = None
  320. self.down_proj_precision_config = None
  321. self.limit = getattr(config, "swiglu_limit", 7.0)
  322. def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
  323. FnSpecs, FusedActivation, matmul_ogs = (
  324. triton_kernels_hub.matmul_ogs.FnSpecs,
  325. triton_kernels_hub.matmul_ogs.FusedActivation,
  326. triton_kernels_hub.matmul_ogs.matmul_ogs,
  327. )
  328. swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn
  329. with on_device(hidden_states.device):
  330. act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2)
  331. intermediate_cache1 = matmul_ogs(
  332. hidden_states,
  333. self.gate_up_proj,
  334. self.gate_up_proj_bias.to(torch.float32),
  335. routing_data,
  336. gather_indx=gather_idx,
  337. precision_config=self.gate_up_proj_precision_config,
  338. gammas=None,
  339. fused_activation=act,
  340. )
  341. intermediate_cache3 = matmul_ogs(
  342. intermediate_cache1,
  343. self.down_proj,
  344. self.down_proj_bias.to(torch.float32),
  345. routing_data,
  346. scatter_indx=scatter_idx,
  347. precision_config=self.down_proj_precision_config,
  348. gammas=routing_data.gate_scal,
  349. )
  350. return intermediate_cache3
  351. # Adapted from GPT_OSS repo
  352. # TODO: Add absolute link when the repo is public
  353. def routing_torch_dist(
  354. logits,
  355. n_expts_act,
  356. ):
  357. import os
  358. GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch = (
  359. triton_kernels_hub.routing.GatherIndx,
  360. triton_kernels_hub.routing.RoutingData,
  361. triton_kernels_hub.routing.ScatterIndx,
  362. triton_kernels_hub.routing.compute_expt_data_torch,
  363. )
  364. with on_device(logits.device):
  365. world_size = torch.distributed.get_world_size()
  366. rank = int(os.environ.get("LOCAL_RANK", "0"))
  367. replace_value = -1
  368. n_tokens = logits.shape[0]
  369. n_expts_tot = logits.shape[1]
  370. n_local_experts = n_expts_tot // world_size
  371. local_expert_start = rank * n_local_experts
  372. local_expert_end = (rank + 1) * n_local_experts
  373. n_gates_pad = n_tokens * n_expts_act
  374. def topk(vals, k):
  375. tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k]
  376. tk_indx = tk_indx.long()
  377. tk_val = torch.take_along_dim(vals, tk_indx, dim=1)
  378. return tk_val, tk_indx.int()
  379. expt_scal, expt_indx = topk(logits, n_expts_act)
  380. expt_scal = torch.softmax(expt_scal, dim=-1)
  381. expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
  382. expt_scal = torch.gather(expt_scal, 1, sort_indices)
  383. # Flatten and mask for local experts
  384. expt_scal = expt_scal.reshape(-1)
  385. hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1)[local_expert_start:local_expert_end]
  386. expt_indx = expt_indx.view(-1).to(torch.int32)
  387. # we use a large value to replace the indices that are not in the local expert range
  388. var = 1000
  389. expt_indx = torch.where(expt_indx < local_expert_start, var, expt_indx)
  390. topk_indx = torch.argsort(expt_indx, stable=True).to(torch.int32)
  391. gate_indx = torch.argsort(topk_indx).to(torch.int32)
  392. expt_indx = torch.where(expt_indx < local_expert_end, expt_indx, replace_value)
  393. expt_indx = torch.where(local_expert_start <= expt_indx, expt_indx, replace_value)
  394. gate_indx = torch.where(expt_indx == replace_value, replace_value, gate_indx)
  395. gate_scal = expt_scal[topk_indx]
  396. topk_indx = torch.where(gate_indx[topk_indx] == replace_value, replace_value, topk_indx)
  397. # # Routing metadata for local expert computation
  398. gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
  399. scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
  400. expt_data = compute_expt_data_torch(hist, n_local_experts, n_gates_pad)
  401. hit_experts = n_expts_act
  402. return RoutingData(gate_scal, hist, n_local_experts, hit_experts, expt_data), gather_indx, scatter_indx
  403. def mlp_forward(self, hidden_states):
  404. import torch.distributed as dist
  405. if dist.is_available() and dist.is_initialized() and hasattr(self, "_is_hooked"):
  406. routing = routing_torch_dist
  407. else:
  408. routing = triton_kernels_hub.routing.routing
  409. batch_size = hidden_states.shape[0]
  410. hidden_states = hidden_states.reshape(-1, self.router.hidden_dim)
  411. router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias)
  412. with on_device(router_logits.device):
  413. routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
  414. routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx=scatter_idx)
  415. routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim)
  416. return routed_out, router_logits
  417. def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):
  418. from ..integrations.tensor_parallel import shard_and_distribute_module
  419. model = kwargs.get("model")
  420. empty_param = kwargs.get("empty_param")
  421. casting_dtype = kwargs.get("casting_dtype")
  422. to_contiguous = kwargs.get("to_contiguous")
  423. rank = kwargs.get("rank")
  424. device_mesh = kwargs.get("device_mesh")
  425. for proj in ["gate_up_proj", "down_proj"]:
  426. if proj in param_name:
  427. if device_mesh is not None:
  428. param_value = shard_and_distribute_module(
  429. model,
  430. param_value,
  431. empty_param,
  432. dq_param_name,
  433. casting_dtype,
  434. to_contiguous,
  435. rank,
  436. device_mesh,
  437. )
  438. blocks_attr = f"{proj}_blocks"
  439. scales_attr = f"{proj}_scales"
  440. setattr(module, param_name.rsplit(".", 1)[1], param_value)
  441. if hasattr(module, blocks_attr) and hasattr(module, scales_attr):
  442. dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
  443. setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device)))
  444. delattr(module, blocks_attr)
  445. delattr(module, scales_attr)
  446. def dequantize_convertops(blocks, scales):
  447. dequantized = convert_moe_packed_tensors(blocks, scales)
  448. return torch.nn.Parameter(dequantized)
  449. def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, triton_kernels_hub, **kwargs):
  450. """
  451. This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`.
  452. """
  453. PrecisionConfig, FlexCtx, InFlexData = (
  454. triton_kernels_hub.matmul_ogs.PrecisionConfig,
  455. triton_kernels_hub.matmul_ogs.FlexCtx,
  456. triton_kernels_hub.matmul_ogs.InFlexData,
  457. )
  458. from ..integrations.tensor_parallel import shard_and_distribute_module
  459. model = kwargs.get("model")
  460. empty_param = kwargs.get("empty_param")
  461. casting_dtype = kwargs.get("casting_dtype")
  462. to_contiguous = kwargs.get("to_contiguous")
  463. rank = kwargs.get("rank")
  464. device_mesh = kwargs.get("device_mesh")
  465. if "blocks" in param_name:
  466. proj = param_name.split(".")[-1].split("_blocks")[0]
  467. if "scales" in param_name:
  468. proj = param_name.split(".")[-1].split("_scales")[0]
  469. if device_mesh is not None:
  470. shard_and_distribute_module(
  471. model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh
  472. )
  473. else:
  474. setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False))
  475. blocks_attr = f"{proj}_blocks"
  476. scales_attr = f"{proj}_scales"
  477. blocks = getattr(module, blocks_attr) # at this point values were loaded from ckpt
  478. scales = getattr(module, scales_attr)
  479. # Check if both blocks and scales both not on meta device
  480. if blocks.device.type != "meta" and scales.device.type != "meta":
  481. local_experts = blocks.size(0)
  482. if proj == "gate_up_proj":
  483. blocks = blocks.reshape(local_experts, module.intermediate_size * 2, -1)
  484. else:
  485. blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
  486. if (
  487. getattr(target_device, "type", target_device) == "cpu"
  488. and hasattr(torch, "accelerator")
  489. and torch.accelerator.current_accelerator() is not None
  490. ):
  491. target_device = torch.accelerator.current_accelerator().type
  492. blocks = blocks.to(target_device).contiguous()
  493. scales = scales.to(target_device).contiguous()
  494. with on_device(target_device):
  495. triton_weight_tensor, weight_scale = swizzle_mxfp4(
  496. blocks.transpose(-2, -1), scales.transpose(-2, -1), triton_kernels_hub
  497. )
  498. # need to overwrite the shapes for the kernels
  499. if proj == "gate_up_proj":
  500. triton_weight_tensor.shape = torch.Size([local_experts, module.hidden_size, module.intermediate_size * 2])
  501. else:
  502. triton_weight_tensor.shape = torch.Size([local_experts, module.intermediate_size, module.hidden_size])
  503. # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It is like a subtensor
  504. setattr(module, proj, triton_weight_tensor)
  505. setattr(
  506. module,
  507. f"{proj}_precision_config",
  508. PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
  509. )
  510. # delete blocks and scales
  511. delattr(module, scales_attr)
  512. delattr(module, blocks_attr)
  513. del blocks
  514. def swizzle_mxfp4_convertops(blocks, scales, module, proj, target_device, triton_kernels_hub):
  515. """
  516. This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`.
  517. """
  518. PrecisionConfig, FlexCtx, InFlexData = (
  519. triton_kernels_hub.matmul_ogs.PrecisionConfig,
  520. triton_kernels_hub.matmul_ogs.FlexCtx,
  521. triton_kernels_hub.matmul_ogs.InFlexData,
  522. )
  523. local_experts = blocks.size(0)
  524. if (
  525. getattr(target_device, "type", target_device) == "cpu"
  526. and hasattr(torch, "accelerator")
  527. and torch.accelerator.current_accelerator() is not None
  528. ):
  529. target_device = torch.accelerator.current_accelerator().type
  530. blocks = blocks.to(target_device).contiguous()
  531. scales = scales.to(target_device).contiguous()
  532. if proj == "gate_up_proj":
  533. blocks = blocks.reshape(local_experts, module.intermediate_size * 2, -1)
  534. else:
  535. blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
  536. with on_device(target_device):
  537. triton_weight_tensor, weight_scale = swizzle_mxfp4(
  538. blocks.transpose(-2, -1), scales.transpose(-2, -1), triton_kernels_hub
  539. )
  540. # need to overwrite the shapes for the kernels
  541. if proj == "gate_up_proj":
  542. triton_weight_tensor.shape = torch.Size([local_experts, module.hidden_size, module.intermediate_size * 2])
  543. else:
  544. triton_weight_tensor.shape = torch.Size([local_experts, module.intermediate_size, module.hidden_size])
  545. # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It's like a subtensor
  546. # Since the Experts module registers gate_up_proj and down_proj as nn.Parameters, we need to remove them so we can attach the Triton tensor
  547. if proj in module._parameters:
  548. # Remove the nn.Parameter registration so we can attach the Triton tensor
  549. del module._parameters[proj]
  550. setattr(module, proj, triton_weight_tensor)
  551. setattr(
  552. module,
  553. f"{proj}_precision_config",
  554. PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
  555. )
  556. def replace_with_mxfp4_linear(model, quantization_config=None, modules_to_not_convert: list[str] | None = None):
  557. """
  558. Public method that replaces the expert layers of the given model with mxfp4 quantized layers.
  559. Args:
  560. model (`torch.nn.Module`):
  561. The model to convert, can be any `torch.nn.Module` instance.
  562. quantization_config (`Mxfp4Config`, defaults to `None`):
  563. The quantization config object that contains the quantization parameters.
  564. modules_to_not_convert (`list`, *optional*, defaults to `None`):
  565. A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
  566. converted.
  567. """
  568. if quantization_config.dequantize:
  569. return model
  570. from .hub_kernels import get_kernel
  571. global triton_kernels_hub
  572. triton_kernels_hub = get_kernel("kernels-community/gpt-oss-triton-kernels")
  573. has_been_replaced = False
  574. for module_name, module in model.named_modules():
  575. if not should_convert_module(module_name, modules_to_not_convert):
  576. continue
  577. if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize:
  578. with torch.device("meta"):
  579. model.set_submodule(module_name, Mxfp4GptOssExperts(model.config))
  580. has_been_replaced = True
  581. if module.__class__.__name__ == "GptOssMLP" and not quantization_config.dequantize:
  582. from types import MethodType
  583. module.forward = MethodType(mlp_forward, module)
  584. if not has_been_replaced:
  585. logger.warning(
  586. "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model."
  587. " Please double check your model architecture, or submit an issue on github if you think this is"
  588. " a bug."
  589. )
  590. return model