tensor_parallel.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import annotations
  15. import math
  16. import operator
  17. import os
  18. import re
  19. from functools import reduce
  20. from ..distributed import DistributedConfig
  21. from ..utils import is_torch_greater_or_equal, logging
  22. from ..utils.generic import GeneralInterface
  23. from ..utils.import_utils import is_torch_available
  24. if is_torch_available():
  25. import torch
  26. import torch.distributed as dist
  27. from torch import nn
  28. # Cache this result as it's a C FFI call which can be pretty time-consuming
  29. _torch_distributed_available = torch.distributed.is_available()
  30. logger = logging.get_logger(__name__)
  31. def initialize_tensor_parallelism(
  32. tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None
  33. ):
  34. r"""
  35. Sets up the device mesh and initialized the backend for tensor parallelism.
  36. This function is called when the model is loaded and the TP plan is set to 'auto'.
  37. """
  38. if tp_size is not None and tp_plan is None:
  39. raise ValueError("tp_plan has to be set when tp_size is passed.")
  40. if tp_plan is not None and device_map is not None:
  41. raise ValueError("`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization.")
  42. if device_mesh is None:
  43. if not is_torch_greater_or_equal("2.5"):
  44. raise OSError("Tensor parallel is only supported for `torch>=2.5`.")
  45. # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
  46. device_type = torch._C._get_accelerator().type
  47. if device_type == "mps":
  48. raise RuntimeError("Tensor parallelism is not supported on MPS devices.")
  49. current_device = getattr(torch, device_type)
  50. if not torch.distributed.is_initialized():
  51. try:
  52. rank = int(os.environ["RANK"])
  53. local_rank = int(os.environ["LOCAL_RANK"])
  54. world_size = int(os.environ["WORLD_SIZE"])
  55. backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl", "neuron": "neuron"}
  56. backend = backend_map.get(device_type)
  57. torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
  58. current_device = getattr(torch, device_type)
  59. if device_type != "cpu":
  60. current_device.set_device(local_rank)
  61. except Exception as e:
  62. raise OSError(
  63. "We tried to initialize torch.distributed for you, but it failed. Make "
  64. "sure you init torch distributed in your script to use `tp_plan`."
  65. ) from e
  66. if device_type != "cpu":
  67. current_device.set_device(int(os.environ["LOCAL_RANK"]))
  68. index = current_device.current_device()
  69. tp_device = torch.device(device_type, index)
  70. device_map = tp_device
  71. else:
  72. tp_device = torch.device(device_type)
  73. device_map = device_type or {}
  74. tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
  75. device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
  76. else:
  77. if device_mesh.ndim > 1:
  78. if "tp" not in device_mesh.mesh_dim_names:
  79. raise ValueError(
  80. "When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
  81. "Please provide a valid `device_mesh`."
  82. )
  83. device_mesh = device_mesh["tp"]
  84. tp_size = device_mesh.size()
  85. device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
  86. return device_map, device_mesh, tp_size
  87. def replace_layer_number_by_wildcard(name: str) -> str:
  88. """
  89. Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
  90. a dot (`.`) and the end of the string.
  91. This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
  92. numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.
  93. """
  94. return re.sub(r"\.\d+(\.|$)", lambda m: ".*" + m.group(1), name)
  95. def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
  96. """
  97. Get the TP style for a parameter from the TP plan.
  98. The TP plan is a dictionary that maps parameter names to TP styles.
  99. The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
  100. The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but
  101. not parent classes for `post_init` calls
  102. """
  103. generic_param_name = replace_layer_number_by_wildcard(parameter_name)
  104. if generic_param_name in tp_plan:
  105. return tp_plan[generic_param_name]
  106. elif is_weight and "." in generic_param_name and (module_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
  107. return tp_plan[module_name]
  108. return None
  109. # =============================================================================
  110. # Tensor Sharding Utilities
  111. # =============================================================================
  112. if is_torch_available():
  113. str_to_dtype = {
  114. "BOOL": torch.bool,
  115. "U8": torch.uint8,
  116. "I8": torch.int8,
  117. "I16": torch.int16,
  118. "F16": torch.float16,
  119. "BF16": torch.bfloat16,
  120. "I32": torch.int32,
  121. "F32": torch.float32,
  122. "F64": torch.float64,
  123. "I64": torch.int64,
  124. "F8_E4M3": torch.float8_e4m3fn,
  125. }
  126. def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
  127. """
  128. Convert block count or proportions to block sizes.
  129. This function accepts
  130. - The number of blocks (int), in which case the block size is
  131. total_size//blocks; or
  132. - A list of block sizes (list[int]).
  133. In the second case, if sum(blocks) < total_size, the ratios between
  134. the block sizes will be preserved. For instance, if blocks is
  135. [2, 1, 1] and total_size is 1024, the returned block sizes are
  136. [512, 256, 256].
  137. """
  138. if isinstance(blocks, list):
  139. total_blocks = sum(blocks)
  140. assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}"
  141. part_size = total_size // total_blocks
  142. return [part_size * block for block in blocks]
  143. else:
  144. assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
  145. single_size = total_size // blocks
  146. return [single_size] * blocks
  147. def get_packed_weights(param, empty_param, device_mesh, rank, dim):
  148. """
  149. When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
  150. So if you have: gate_proj ( 16, 5120, 8190)
  151. and up_proj ( 16, 5120, 8190)
  152. packed as gate_up_proj ( 16, 5120, 2 * 8190)
  153. And you shard along the last dimension, you need to interleave the gate and up values:
  154. Now, if we shard along the last dimension across TP_size (Tensor Parallelism size), we must interleave the values from gate and up projections correctly.
  155. Let's take TP_size = 4 for an example:
  156. Packed tensor `gate_up_proj`
  157. ---------------------------------------------------------------
  158. [ G0 G1 G2 G3 | G4 G5 G6 G7 | ... | U0 U1 U2 U3 | U4 U5 U6 U7 | ... ]
  159. ↑─────────────↑ ↑─────────────↑ ↑─────────────↑ ↑─────────────↑
  160. Gate Slice 0 Gate Slice 1 Up Slice 0 Up Slice 1
  161. Explanation:
  162. - The first half of the tensor (left of the center) holds the gate_proj values.
  163. - The second half (right of the center) holds the up_proj values.
  164. - For TP=4, we divide each half into 4 slices. In this example, we show two slices for brevity.
  165. - Each shard receives one slice from the gate part and the corresponding slice from the up part.
  166. For instance:
  167. • Shard 0 gets: [ Gate Slice 0, Up Slice 0 ] = [ G0, G1, G2, G3, U0, U1, U2, U3 ]
  168. • Shard 1 gets: [ Gate Slice 1, Up Slice 1 ] = [ G4, G5, G6, G7, U4, U5, U6, U7 ]
  169. • … and so on.
  170. This ensures that each shard receives an equal portion of both gate and up projections, maintaining consistency across tensor parallelism.
  171. """
  172. slice_ = param
  173. total_size = empty_param.shape[dim]
  174. world_size = device_mesh.size()
  175. block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=2)
  176. tensors_slices = []
  177. block_offset = 0
  178. for block_size in block_sizes:
  179. shard_block_size = block_size // world_size
  180. start = rank * shard_block_size
  181. stop = (rank + 1) * shard_block_size
  182. tensors_slices += range(block_offset + start, block_offset + stop)
  183. block_offset += block_size
  184. slice_dtype = slice_.get_dtype()
  185. # Handle F8_E4M3 dtype by converting to float16 before slicing
  186. # Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
  187. casted = False
  188. if slice_dtype == "F8_E4M3" or slice_dtype == "F8_E5M2":
  189. slice_ = slice_[...].to(torch.float16)
  190. casted = True
  191. if dim == 0:
  192. tensor = slice_[tensors_slices, ...]
  193. elif dim == 1 or dim == -2:
  194. tensor = slice_[:, tensors_slices, ...]
  195. elif dim == 2 or dim == -1:
  196. tensor = slice_[..., tensors_slices]
  197. else:
  198. raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
  199. if casted:
  200. return tensor
  201. else:
  202. return tensor.to(str_to_dtype[slice_dtype])
  203. def repack_weights(
  204. packed_parameter: torch.Tensor,
  205. sharded_dim: int, # The dimension index in the global tensor that was sharded
  206. world_size: int,
  207. num_blocks: int = 2,
  208. ) -> torch.Tensor:
  209. """
  210. Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
  211. For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
  212. DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
  213. along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
  214. This is an inverse operation to get_packed_weights.
  215. Args:
  216. reconstructed_tensor: The tensor reconstructed from DTensor (e.g., via .full_tensor().contiguous()).
  217. sharded_dim: The dimension index in the reconstructed_tensor that was originally sharded.
  218. world_size: The tensor parallel world size.
  219. num_packed_projs: The number of projections that were packed together (e.g., 2 for gate_up_proj).
  220. Returns:
  221. The reordered tensor in canonical packed format.
  222. """
  223. if num_blocks != 2:
  224. raise ValueError(
  225. "Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
  226. )
  227. actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
  228. total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
  229. original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
  230. shard_chunk_size = original_block_size_on_dim // world_size
  231. prefix_shape = packed_parameter.shape[:actual_sharded_dim]
  232. suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]
  233. tensor_view = packed_parameter.view(
  234. *prefix_shape,
  235. world_size,
  236. num_blocks,
  237. shard_chunk_size,
  238. *suffix_shape,
  239. )
  240. # Permute to bring num_packed_projs first, then world_size, then shard_chunk_size
  241. # This groups all chunks of G together, then all chunks of U together.
  242. # Target order of these middle dimensions: (num_packed_projs, world_size, shard_chunk_size)
  243. # Current order of view's middle dimensions: (world_size, num_packed_projs, shard_chunk_size)
  244. # Absolute indices of the dimensions to be permuted (world_size, num_packed_projs)
  245. axis_ws_abs = len(prefix_shape)
  246. axis_npp_abs = len(prefix_shape) + 1
  247. permute_order = list(range(tensor_view.ndim))
  248. permute_order[axis_ws_abs], permute_order[axis_npp_abs] = permute_order[axis_npp_abs], permute_order[axis_ws_abs]
  249. tensor_permuted = tensor_view.permute(*permute_order)
  250. # Reshape back to the original tensor's ndim, with the sharded dimension now correctly ordered as [G_all, U_all].
  251. # The final shape should be the same as reconstructed_tensor.
  252. final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)
  253. return final_ordered_tensor
  254. def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: int | None = None):
  255. """
  256. Generalized tensor sharding across a multi-dimensional device mesh.
  257. Extract only the fraction of the parameter owned by the given `rank` when the parameter would have gone sharding at provided `dim`.
  258. Extraction follows the pytorch `Shard` placement so that sharding and materializing back to full tensor follows `Shard` semantics.
  259. `Shard` follows torch.chunk style sharding of the tensor. We demonstrate some cases below on how sharding happens including some edge cases
  260. such as some ranks having an empty tensor as shard. Below implementation is robut to all these cases.
  261. Case (1)
  262. empty_param (16, 5120, 8190)
  263. dim 0
  264. device_mesh.size() 4
  265. rank 0 gets (4, 5120, 8190) (0 ... 4, 5120, 8190)
  266. rank 1 gets (4, 5120, 8190) (4 ... 8, 5120, 8190)
  267. rank 2 gets (4, 5120, 8190) (8 ... 12, 5120, 8190)
  268. rank 3 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
  269. Case (2)
  270. empty_param (16, 5120, 8190)
  271. dim 0
  272. device_mesh.size() 14
  273. rank 0 gets (2, 5120, 8190) (0 ... 2, 5120, 8190)
  274. rank 1 gets (2, 5120, 8190) (2 ... 4, 5120, 8190)
  275. rank 2 gets (2, 5120, 8190) (4 ... 6, 5120, 8190)
  276. rank 3 gets (2, 5120, 8190) (6 ... 8, 5120, 8190)
  277. rank 4 gets (2, 5120, 8190) (8 ... 10, 5120, 8190)
  278. rank 5 gets (2, 5120, 8190) (10 ... 12, 5120, 8190)
  279. rank 6 gets (2, 5120, 8190) (12 ... 14, 5120, 8190)
  280. rank 7 gets (2, 5120, 8190) (14 ... 16, 5120, 8190)
  281. rank 8 gets (0, 5120, 8190)
  282. rank 9 gets (0, 5120, 8190)
  283. rank 10 gets (0, 5120, 8190)
  284. rank 11 gets (0, 5120, 8190)
  285. rank 12 gets (0, 5120, 8190)
  286. rank 13 gets (0, 5120, 8190)
  287. Case (3)
  288. empty_param (16, 5120, 8190)
  289. dim 0
  290. device_mesh.size() 3
  291. rank 0 gets (6, 5120, 8190) (0 ... 6, 5120, 8190)
  292. rank 1 gets (6, 5120, 8190) (6 ... 12, 5120, 8190)
  293. rank 2 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
  294. In case (2), empty shards are returned with appropriate dimension to allow for operations to work smoothly.
  295. Args:
  296. param (torch.Tensor): The tensor to shard.
  297. empty_param (torch.Tensor): A tensor used for shape reference.
  298. device_mesh (torch.Tensor): Shape [d_0, ..., d_n] representing the mesh.
  299. rank (int): Global rank of the current process/device.
  300. dim (int): Dimension along which to shard the tensor.
  301. """
  302. param_dim = empty_param.ndim
  303. mesh_shape = device_mesh.shape
  304. world_size = reduce(operator.mul, mesh_shape)
  305. # Get param shape: works for both torch.Tensor and safetensors TensorInfo
  306. param_shape = list(param.shape) if isinstance(param, torch.Tensor) else param.get_shape()
  307. if dim < 0:
  308. dim = param_dim + dim
  309. if empty_param.dim() == 3 and dim == 1 and len(param_shape) == 2:
  310. dim = 0
  311. elif empty_param.dim() == 3 and dim == 2 and len(param_shape) == 2:
  312. dim = 1
  313. shard_size = math.ceil(param_shape[dim] / world_size)
  314. start = rank * shard_size
  315. end = min(start + shard_size, param_shape[dim])
  316. if dim >= param_dim:
  317. raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
  318. if rank >= world_size:
  319. raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
  320. # we have the full tensor not 1 part of it.
  321. # in that case, we just assume that the weight was properly saved
  322. # and thus because we TP if the layer is colwise it should not use this. Layer should be packed_colwise
  323. # to inform that it needs to read form a packed tensor. It will also take care of the module list thingy.
  324. # here we take care of potential chunking / layer split / layer chunking.
  325. # The only "hard" case is? if we collect q,k,v -> merge it into qkv. In that case
  326. # actually we still shard dim=0 does not change
  327. # so only case is if the dim of the empty param is 3 and the shard dim is 0 -> we put the
  328. # tensor on a certain device (with the input tensor_index)
  329. if tensor_idx is not None and empty_param.dim() == 3 and dim == 0 and len(param_shape) == 2:
  330. # special case we don't "shard" just send this entire tensor to the correct rank.
  331. if start <= tensor_idx < end:
  332. # this tensor does need to be materialized on this device:
  333. return param[:]
  334. else:
  335. return torch.empty([], dtype=torch.int64, device=rank)
  336. slice_indices = [slice(None)] * len(param_shape)
  337. if start < param_shape[dim]:
  338. slice_indices[dim] = slice(start, end)
  339. param = param[tuple(slice_indices)]
  340. if isinstance(param, list): # TODO handle the modulelist case!
  341. param = [p[:] for p in param]
  342. return param
  343. param_shape[dim] = 0
  344. return torch.empty(tuple(param_shape), dtype=torch.int64) # empty allocates memory....
  345. def _split_along_last_dim(x, world_size):
  346. """Split tensor along last dimension into world_size chunks."""
  347. return torch.chunk(x, world_size, dim=-1)
  348. # =============================================================================
  349. # Distributed Communication Primitives
  350. # =============================================================================
  351. #
  352. # Naming convention:
  353. # - Functions describe their FORWARD behavior
  354. # - Backward behavior is the "conjugate" operation for gradient flow
  355. #
  356. # Available operations:
  357. # ┌────────────────────┬─────────────────────┬─────────────────────┐
  358. # │ Function │ Forward │ Backward │
  359. # ├────────────────────┼─────────────────────┼─────────────────────┤
  360. # │ all_reduce │ all-reduce (sum) │ identity │
  361. # │ all_reduce_backward│ identity │ all-reduce (sum) │
  362. # │ all_gather │ all-gather │ split (local chunk) │
  363. # │ split │ split (local chunk) │ all-gather │
  364. # │ reduce_scatter │ reduce-scatter │ all-gather │
  365. # └────────────────────┴─────────────────────┴─────────────────────┘
  366. # ===================
  367. class _AllReduceBackward(torch.autograd.Function):
  368. """Identity forward, all-reduce backward. Used before colwise layers (f in Megatron)."""
  369. @staticmethod
  370. def forward(ctx, x, device_mesh):
  371. ctx.device_mesh = device_mesh
  372. return x
  373. @staticmethod
  374. def backward(ctx, grad_output):
  375. device_mesh = ctx.device_mesh
  376. if device_mesh.size() == 1:
  377. return grad_output, None
  378. grad_output = grad_output.contiguous()
  379. dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
  380. return grad_output, None
  381. class _AllReduceForward(torch.autograd.Function):
  382. """All-reduce forward, identity backward. Used after rowwise layers (g in Megatron)."""
  383. @staticmethod
  384. def forward(ctx, x, device_mesh):
  385. if device_mesh.size() == 1:
  386. return x
  387. dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
  388. return x
  389. @staticmethod
  390. def backward(ctx, grad_output):
  391. return grad_output, None
  392. class _AllGather(torch.autograd.Function):
  393. """All-gather forward, split backward. Gathers sharded outputs."""
  394. @staticmethod
  395. def forward(ctx, x, device_mesh):
  396. ctx.device_mesh = device_mesh
  397. world_size = device_mesh.size()
  398. if world_size == 1:
  399. return x
  400. last_dim = x.dim() - 1
  401. rank = device_mesh.get_local_rank()
  402. group = device_mesh.get_group()
  403. x = x.contiguous()
  404. tensor_list = [torch.empty_like(x) for _ in range(world_size)]
  405. tensor_list[rank] = x
  406. dist.all_gather(tensor_list, x, group=group)
  407. return torch.cat(tensor_list, dim=last_dim).contiguous()
  408. @staticmethod
  409. def backward(ctx, grad_output):
  410. device_mesh = ctx.device_mesh
  411. world_size = device_mesh.size()
  412. if world_size == 1:
  413. return grad_output, None
  414. rank = device_mesh.get_local_rank()
  415. chunks = _split_along_last_dim(grad_output, world_size)
  416. return chunks[rank].contiguous(), None
  417. class _Split(torch.autograd.Function):
  418. """Split forward, all-gather backward. Scatters replicated input."""
  419. @staticmethod
  420. def forward(ctx, x, device_mesh):
  421. ctx.device_mesh = device_mesh
  422. world_size = device_mesh.size()
  423. if world_size == 1:
  424. return x
  425. rank = device_mesh.get_local_rank()
  426. chunks = _split_along_last_dim(x, world_size)
  427. return chunks[rank].contiguous()
  428. @staticmethod
  429. def backward(ctx, grad_output):
  430. device_mesh = ctx.device_mesh
  431. world_size = device_mesh.size()
  432. if world_size == 1:
  433. return grad_output, None
  434. last_dim = grad_output.dim() - 1
  435. rank = device_mesh.get_local_rank()
  436. group = device_mesh.get_group()
  437. grad_output = grad_output.contiguous()
  438. tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
  439. tensor_list[rank] = grad_output
  440. dist.all_gather(tensor_list, grad_output, group=group)
  441. return torch.cat(tensor_list, dim=last_dim).contiguous(), None
  442. class _ReduceScatter(torch.autograd.Function):
  443. """Reduce-scatter forward, all-gather backward. For sequence parallel."""
  444. @staticmethod
  445. def forward(ctx, x, device_mesh):
  446. ctx.device_mesh = device_mesh
  447. world_size = device_mesh.size()
  448. if world_size == 1:
  449. return x
  450. last_dim = x.dim() - 1
  451. group = device_mesh.get_group()
  452. input_chunks = list(x.chunk(world_size, dim=last_dim))
  453. output_shape = list(x.shape)
  454. output_shape[last_dim] //= world_size
  455. output = torch.empty(output_shape, dtype=x.dtype, device=x.device)
  456. dist.reduce_scatter(output, input_chunks, op=dist.ReduceOp.SUM, group=group)
  457. return output
  458. @staticmethod
  459. def backward(ctx, grad_output):
  460. device_mesh = ctx.device_mesh
  461. world_size = device_mesh.size()
  462. if world_size == 1:
  463. return grad_output, None
  464. last_dim = grad_output.dim() - 1
  465. rank = device_mesh.get_local_rank()
  466. group = device_mesh.get_group()
  467. grad_output = grad_output.contiguous()
  468. tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
  469. tensor_list[rank] = grad_output
  470. dist.all_gather(tensor_list, grad_output, group=group)
  471. return torch.cat(tensor_list, dim=last_dim).contiguous(), None
  472. # =============================================================================
  473. # Convenience wrappers
  474. # =============================================================================
def all_reduce_backward(x, device_mesh):
    """
    Identity forward, all-reduce backward. Use before colwise layers.

    Thin wrapper around `_AllReduceBackward.apply`: `x` is returned unchanged in the forward pass and its
    gradient is summed over the process group of `device_mesh` in the backward pass.
    """
    return _AllReduceBackward.apply(x, device_mesh)
def all_reduce_forward(x, device_mesh):
    """
    All-reduce forward, identity backward. Use after rowwise layers.

    Thin wrapper around `_AllReduceForward.apply`: `x` is summed over the process group of `device_mesh` in the
    forward pass (in place) and the gradient is passed through unchanged in the backward pass.
    """
    return _AllReduceForward.apply(x, device_mesh)
def all_gather(x, device_mesh):
    """
    All-gather forward, split backward.

    Thin wrapper around `_AllGather.apply`: the per-rank tensors are concatenated along the last dimension in
    the forward pass, and the backward pass keeps the local rank's chunk of the gradient.
    """
    return _AllGather.apply(x, device_mesh)
def split(x, device_mesh):
    """
    Split forward, all-gather backward.

    Thin wrapper around `_Split.apply`: the forward pass keeps the local rank's chunk of `x` along the last
    dimension, and the backward pass all-gathers the gradients along that dimension.
    """
    return _Split.apply(x, device_mesh)
def reduce_scatter(x, device_mesh):
    """
    Reduce-scatter forward, all-gather backward.

    Thin wrapper around `_ReduceScatter.apply`: the forward pass sums `x` across ranks and keeps the local
    chunk along the last dimension, and the backward pass all-gathers the gradients along that dimension.
    """
    return _ReduceScatter.apply(x, device_mesh)
  490. def distribute_module(
  491. module: nn.Module,
  492. device_mesh=None,
  493. input_fn=None,
  494. output_fn=None,
  495. ) -> nn.Module:
  496. """
  497. Copy pasted from torch's function but we remove the communications (partitioning)
  498. as well as buffer registering that is similarly not efficient.
  499. """
  500. if input_fn is not None:
  501. module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
  502. if output_fn is not None:
  503. module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
  504. return module
  505. class TensorParallelLayer:
  506. """General tensor parallel layer for transformers"""
  507. device_mesh = None
  508. rank = None
  509. empty_param = None
  510. def __init__(self, device_mesh=None, rank=None, empty_param=None):
  511. self.rank = rank
  512. self.device_mesh = device_mesh
  513. self.empty_param = empty_param
  514. def _prepare_input_fn(self, mod, inputs, device_mesh):
  515. raise NotImplementedError
  516. def _prepare_output_fn(self, mod, outputs, device_mesh):
  517. raise NotImplementedError
  518. def shard_tensor(
  519. self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
  520. ) -> torch.Tensor:
  521. raise NotImplementedError
  522. def prepare_module_tp(self, module: nn.Module, device_mesh, **kwargs) -> nn.Module:
  523. distribute_module(
  524. module,
  525. device_mesh,
  526. self._prepare_input_fn,
  527. self._prepare_output_fn,
  528. )
  529. def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
  530. """
  531. Compute the expected shape after TP sharding for a given full shape.
  532. Args:
  533. full_shape: The full (unsharded) parameter shape
  534. Returns:
  535. The expected sharded shape for this rank
  536. """
  537. # Default: no sharding, return full shape
  538. return tuple(full_shape)
  539. def update_module_attributes(self, module: nn.Module):
  540. """
  541. Update module attributes (e.g. in_features, out_features) to reflect sharded dimensions.
  542. Args:
  543. module: The module to update
  544. Returns:
  545. None, update the module in-place
  546. """
  547. pass
  548. class ColwiseParallel(TensorParallelLayer):
  549. """
  550. Column-wise parallel: weight is sharded on dim -2 (output features).
  551. Forward: input replicated -> output sharded on last dim.
  552. If gather_output=True, output is all-gathered to produce full tensor.
  553. """
  554. def __init__(self, gather_output: bool = False, **kwargs):
  555. super().__init__(**kwargs)
  556. self.gather_output = gather_output
  557. def _prepare_input_fn(self, mod, inputs, device_mesh):
  558. input_tensor = inputs[0] if inputs else inputs
  559. return all_reduce_backward(input_tensor, device_mesh)
  560. def _prepare_output_fn(self, mod, outputs, device_mesh):
  561. if self.gather_output:
  562. return all_gather(outputs, device_mesh)
  563. return outputs
  564. def shard_tensor(
  565. self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
  566. ) -> torch.Tensor:
  567. # If only 1 dim, shard this one (usually it's a `bias`)
  568. dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
  569. if dim == 1:
  570. parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
  571. else:
  572. parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2)
  573. return parameter.to(device=device, dtype=dtype)
  574. def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
  575. world_size = self.device_mesh.size()
  576. shape = list(full_shape)
  577. # Colwise shards dim -2, but 1D tensors (bias) shard on dim -1
  578. dim = -1 if len(shape) == 1 else -2
  579. dim = len(shape) + dim if dim < 0 else dim
  580. shard_size = math.ceil(shape[dim] / world_size)
  581. start = self.rank * shard_size
  582. end = min(start + shard_size, shape[dim])
  583. shape[dim] = end - start
  584. return tuple(shape)
  585. def update_module_attributes(self, module: nn.Module):
  586. # If we gather the output, the output dimension of the module is not sharded, so no need to update out_features.
  587. # Otherwise, we need to update out_features to reflect the sharded dimension.
  588. if not self.gather_output and hasattr(module, "out_features"):
  589. module.out_features = self.get_expected_sharded_shape((module.out_features,))[0]
  590. class ReplicatedWithGradAllReduce(TensorParallelLayer):
  591. """
  592. Replicated parameter with gradient all-reduce.
  593. For parameters like q_norm/k_norm that sit between colwise and rowwise
  594. layers. The parameter is replicated (not sharded), but its gradient
  595. accumulates from local heads only in TP mode. This class registers a
  596. backward hook to all-reduce the parameter gradient.
  597. """
  598. def _prepare_input_fn(self, mod, inputs, device_mesh):
  599. return inputs
  600. def _prepare_output_fn(self, mod, outputs, device_mesh):
  601. return outputs
  602. def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
  603. return param[...].to(device=device, dtype=dtype)
  604. def prepare_module_tp(self, module, device_mesh, **kwargs):
  605. # Use a module-level backward hook (not param.register_hook) because parameters are replaced during weight loading after this method runs.
  606. # Module hooks survive parameter replacement.
  607. def _backward_hook(mod, grad_input, grad_output, mesh=device_mesh):
  608. for param in mod.parameters():
  609. if param.grad is not None:
  610. all_reduce_forward(param.grad, mesh)
  611. module.register_full_backward_hook(_backward_hook)
  612. class MlaKvAProjParallel(TensorParallelLayer):
  613. """
  614. For MLA attention used in DeepSeek-V2 style models (deepseek_v2, longcat_flash, glm_moe_dsa, glm4_moe_lite):
  615. kv_a_proj_with_mqa output is [kv_lora_rank + qk_rope_head_dim] (can have different naming but important thing
  616. to understand is that it is split)
  617. Example below (from modeling_longcat_flash.py):
  618. kv_a_proj_with_mqa
  619. |
  620. split
  621. / \
  622. k_pass k_rot <-- "bypasses kv_b_proj"
  623. | | (goes straight to attention,
  624. kv_a_layernorm | never touches kv_b_proj)
  625. | |
  626. kv_b_proj |
  627. (colwise) |
  628. | |
  629. k_pass k_rot
  630. \\ /
  631. cat
  632. |
  633. key_states
  634. k_pass is passed to kv_b_proj (colwise) which has built-in all_reduce_backward so we don't have a partial gradient for it.
  635. However, k_rot goes straight to attention, never touches kv_b_proj. So we need to average gradient across all ranks otherwise we only get gradient for one rank (partial gradient).
  636. """
  637. def _prepare_output_fn(self, mod, output, device_mesh):
  638. if not hasattr(mod.config, "qk_rope_head_dim"):
  639. raise AttributeError(
  640. f"Config for {type(mod).__name__} does not have `qk_rope_head_dim`. "
  641. "MlaKvAProjParallel requires `qk_rope_head_dim` to be defined in the model config. "
  642. "Please add it to the model's config or update the TP plan mapping."
  643. )
  644. rope_dim = mod.config.qk_rope_head_dim
  645. pass_output, rope_output = output.split([output.shape[-1] - rope_dim, rope_dim], dim=-1)
  646. rope_output = all_reduce_backward(rope_output, device_mesh)
  647. return torch.cat([pass_output, rope_output], dim=-1)
  648. def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
  649. return param[...].to(device=device, dtype=dtype)
  650. def prepare_module_tp(self, module, device_mesh, config=None, **kwargs):
  651. module.config = config
  652. distribute_module(module, device_mesh, output_fn=self._prepare_output_fn)
  653. class RowwiseParallel(TensorParallelLayer):
  654. """
  655. Row-wise parallel: weight is sharded on dim -1 (input features).
  656. Forward: input (optionally split) -> output partial -> all-reduce to replicate.
  657. Args:
  658. split_input: If True, splits replicated input before matmul. Use when input
  659. comes from a non-parallelizable operation (chunk/slice).
  660. Default False (expects pre-sharded input from colwise layer).
  661. """
  662. def __init__(self, split_input: bool = False, **kwargs):
  663. super().__init__(**kwargs)
  664. self.split_input = split_input
  665. def _prepare_input_fn(self, mod, inputs, device_mesh):
  666. if hasattr(mod, "bias") and mod.bias is not None:
  667. mod._bias = mod.bias
  668. mod.bias = None
  669. input_tensor = inputs[0] if inputs else inputs
  670. if self.split_input:
  671. # Input is replicated, split it to match sharded weight
  672. return split(input_tensor, device_mesh)
  673. return input_tensor
  674. def _prepare_output_fn(self, mod, outputs, device_mesh):
  675. outputs = all_reduce_forward(outputs, device_mesh)
  676. if hasattr(mod, "_bias") and mod._bias is not None:
  677. outputs = outputs + mod._bias
  678. return outputs
  679. def shard_tensor(
  680. self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
  681. ) -> torch.Tensor:
  682. # If only 1 dim, it should not be sharded (usually it's a `bias`)
  683. dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
  684. if dim == 1:
  685. parameter = param[...]
  686. else:
  687. parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
  688. return parameter.to(device=device, dtype=dtype)
  689. def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
  690. # 1D tensors (bias) are NOT sharded in rowwise
  691. if len(full_shape) == 1:
  692. return tuple(full_shape)
  693. world_size = self.device_mesh.size()
  694. shape = list(full_shape)
  695. dim = -1
  696. dim = len(shape) + dim if dim < 0 else dim
  697. shard_size = math.ceil(shape[dim] / world_size)
  698. start = self.rank * shard_size
  699. end = min(start + shard_size, shape[dim])
  700. shape[dim] = end - start
  701. return tuple(shape)
  702. def update_module_attributes(self, module: nn.Module):
  703. if hasattr(module, "in_features"):
  704. # To fall in the 2D case in get_expected_sharded_shape,
  705. # otherwise it will be treated as 1D and not sharded
  706. shape = (1, module.in_features)
  707. module.in_features = self.get_expected_sharded_shape(shape)[1]
  708. class PackedColwiseParallel(ColwiseParallel):
  709. """Packed column-wise parallel for fused weights like gate_up_proj."""
  710. def shard_tensor(
  711. self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
  712. ) -> torch.Tensor:
  713. # If only 1 dim, shard this one (usually it's a `bias`)
  714. dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
  715. if dim == 1:
  716. parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
  717. else:
  718. expected_shape = self.get_expected_sharded_shape(self.empty_param.shape)
  719. if dim < len(expected_shape):
  720. # Input is unpacked (e.g., gate_proj that will be concatenated to gate_up_proj)
  721. # Use regular tensor shard - concatenation will happen after
  722. parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2)
  723. else:
  724. # Input is already packed, use packed sharding
  725. parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -2)
  726. return parameter.to(device=device, dtype=dtype)
  727. class PackedRowwiseParallel(RowwiseParallel):
  728. """Packed row-wise parallel for fused weights like gate_up_proj."""
  729. def shard_tensor(
  730. self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
  731. ) -> torch.Tensor:
  732. # If only 1 dim, it should not be sharded (usually it's a `bias`)
  733. dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
  734. if dim == 1:
  735. parameter = param[...]
  736. else:
  737. # Check if input tensor is unpacked (shape mismatch with expected packed size)
  738. # This happens when using MergeModulelist + Concatenate for fused weights like gate_up_proj
  739. param_shape = param.shape if isinstance(param, torch.Tensor) else param.get_shape()
  740. expected_packed_dim = self.empty_param.shape[-1] if self.empty_param.dim() >= 1 else 0
  741. actual_dim = param_shape[-1] if len(param_shape) >= 1 else 0
  742. if actual_dim < expected_packed_dim:
  743. # Input is unpacked, use regular tensor shard
  744. parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
  745. else:
  746. # Input is already packed, use packed sharding
  747. parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -1)
  748. return parameter.to(device=device, dtype=dtype)
  749. class EmbeddingParallel(TensorParallelLayer):
  750. """EmbeddingParallel: shards embedding table, handles masked lookups for vocab parallelism."""
  751. def __init__(self, *, embedding_dim_sharding: int = 0, **kwargs):
  752. super().__init__(**kwargs)
  753. self.embedding_dim_sharding = embedding_dim_sharding
  754. def _prepare_input_fn(self, mod, inputs, device_mesh):
  755. input_tensor = inputs[0] if inputs else inputs
  756. # For vocab-parallel (dim 0), we need to handle masking and offsetting
  757. if self.embedding_dim_sharding == 0:
  758. rank = device_mesh.get_local_rank()
  759. # Get vocab range for this rank
  760. # Use weight.shape[0] to get the actual local (sharded) size, not num_embeddings
  761. # which may not be updated after sharding
  762. per_partition_size = mod.weight.shape[0]
  763. vocab_start_index = rank * per_partition_size
  764. vocab_end_index = vocab_start_index + per_partition_size
  765. # Build mask for out-of-vocabulary tokens
  766. input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
  767. mod._input_mask = input_mask
  768. # Offset input to local indices and mask invalid ones
  769. masked_input = input_tensor.clone() - vocab_start_index
  770. masked_input[input_mask] = 0 # Set to valid local index
  771. return masked_input
  772. return input_tensor
  773. def _prepare_output_fn(self, mod, outputs, device_mesh):
  774. # For vocab-parallel (dim 0), zero out embeddings for out-of-range tokens before all-reduce
  775. if self.embedding_dim_sharding == 0 and hasattr(mod, "_input_mask"):
  776. input_mask = mod._input_mask
  777. # Use multiplication instead of in-place assignment to preserve gradients
  778. mask_expanded = input_mask.unsqueeze(-1).expand_as(outputs)
  779. outputs = outputs * (~mask_expanded).to(outputs.dtype)
  780. del mod._input_mask
  781. return all_reduce_forward(outputs, device_mesh)
  782. def shard_tensor(
  783. self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
  784. ) -> torch.Tensor:
  785. # If only 1 dim, shard this one (usually it's a `bias`)
  786. dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
  787. if dim == 1:
  788. parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
  789. else:
  790. parameter = get_tensor_shard(
  791. param,
  792. self.empty_param,
  793. self.device_mesh,
  794. self.rank,
  795. self.embedding_dim_sharding,
  796. )
  797. return parameter.to(device=device, dtype=dtype)
  798. def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
  799. world_size = self.device_mesh.size()
  800. shape = list(full_shape)
  801. # EmbeddingParallel shards on self.embedding_dim_sharding (default 0)
  802. # 1D tensors (bias) shard on dim -1
  803. dim = -1 if len(shape) == 1 else self.embedding_dim_sharding
  804. dim = len(shape) + dim if dim < 0 else dim
  805. shard_size = math.ceil(shape[dim] / world_size)
  806. start = self.rank * shard_size
  807. end = min(start + shard_size, shape[dim])
  808. shape[dim] = end - start
  809. return tuple(shape)
  810. def update_module_attributes(self, module: nn.Module):
  811. if hasattr(module, "num_embeddings") and self.embedding_dim_sharding == 0:
  812. module.num_embeddings = self.get_expected_sharded_shape((module.num_embeddings,))[0]
  813. if hasattr(module, "embedding_dim") and self.embedding_dim_sharding == 1:
  814. module.embedding_dim = self.get_expected_sharded_shape((module.embedding_dim,))[0]
  815. class SequenceParallel(TensorParallelLayer):
  816. """
  817. Sequence Parallel: input/output sharded on sequence dimension.
  818. Weights are replicated.
  819. """
  820. def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
  821. super().__init__(**kwargs)
  822. self.sequence_dim = sequence_dim
  823. def _prepare_input_fn(self, mod, inputs, device_mesh):
  824. input_tensor = inputs[0] if inputs else inputs
  825. # For sequence parallel, input is sharded on sequence dim
  826. # All-gather for the layer, then reduce-scatter after
  827. return all_gather(input_tensor, device_mesh)
  828. def _prepare_output_fn(self, mod, outputs, device_mesh):
  829. return reduce_scatter(outputs, device_mesh)
  830. def shard_tensor(
  831. self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
  832. ) -> torch.Tensor:
  833. return param[...].to(device=device, dtype=dtype)
  834. class GroupedGemmParallel(TensorParallelLayer):
  835. """
  836. Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
  837. """
  838. def __init__(self, **kwargs):
  839. super().__init__(**kwargs)
  840. def shard_tensor(
  841. self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
  842. ) -> torch.Tensor:
  843. global_num_experts = self.empty_param.shape[0]
  844. if global_num_experts % self.device_mesh.size() != 0:
  845. raise ValueError(
  846. f"Global number of experts must be divisible by number of devices: {global_num_experts} % {self.device_mesh.size()} != 0"
  847. )
  848. local_num_experts = global_num_experts // self.device_mesh.size()
  849. shard_size = local_num_experts
  850. if isinstance(device, torch.device):
  851. device = device.index if device.index is not None else 0
  852. start = device * shard_size
  853. end = (device + 1) * shard_size
  854. # special case we don't "shard" just send this entire tensor to the correct rank.
  855. shape = param.get_shape() if not isinstance(param, torch.Tensor) else param.shape
  856. if tensor_idx is not None and start <= tensor_idx < end:
  857. # this tensor does need to be materialized on this device:
  858. return param[:].to(device=device)
  859. elif tensor_idx is None: # a bias or a weight, but already merged
  860. return param[start:end].to(device=device, dtype=dtype)
  861. elif len(shape) >= 1 and tensor_idx is not None:
  862. return None
  863. else: # bias case
  864. return param[:].to(device=device, dtype=dtype)
  865. def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
  866. # GroupedGemm shards on dim 0 (experts dimension)
  867. world_size = self.device_mesh.size()
  868. shape = list(full_shape)
  869. local_num_experts = shape[0] // world_size
  870. shape[0] = local_num_experts
  871. return tuple(shape)
  872. def update_module_attributes(self, module: nn.Module):
  873. if hasattr(module, "num_experts"):
  874. module.num_experts = self.get_expected_sharded_shape((module.num_experts,))[0]
  875. class RouterParallel(TensorParallelLayer):
  876. """
  877. Allows to reshape the router scores to support running expert parallel.
  878. """
  879. def __init__(self, **kwargs):
  880. super().__init__(**kwargs)
  881. def _prepare_input_fn(self, mod, inputs, device_mesh):
  882. return inputs[0] if inputs else inputs
  883. def _prepare_output_fn(self, mod, outputs, device_mesh):
  884. """
  885. Imagine if you had 4 tokens, top_k = 4, and 128experts.
  886. With EP = 8. The num_local_expert should be 128/8 = 16
  887. Imagine router_indices being:
  888. [ 52, 42, 119, 67],
  889. [102, 89, 61, 40],
  890. [ 82, 103, 4, 34],
  891. [ 93, 23, 109, 11],
  892. then you can map which rank should be getting which values
  893. [3, 2, 7, 4],
  894. [6, 5, 3, 2],
  895. [5, 6, 0, 2],
  896. [5, 1, 6, 0],
  897. Thus for say rank 0, you fill with 16 (num_local_expert) the index tensor
  898. [ 16, 16, 16, 16],
  899. [ 16, 16, 16, 16],
  900. [ 16, 16, 4, 16],
  901. [ 16, 16, 16, 11],
  902. This works well. For another rank you need to make sure you round to num_local_expert
  903. because the next operation will one hot encode the router index vector.
  904. This allows us to know directly which local expert is hit.
  905. Similarly the scores are indexed with something created form
  906. router_indices.
  907. The kinda naive training loop that we use for device_map "auto" uses a similar logic.
  908. Here we are just making each rank believe that he is alone, and he computes his part of the hiddenstates.
  909. Mask invalid indices with num_local_expert for one-hot encoding, so the computes will skip the masking index.
  910. """
  911. ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size()
  912. if mod.num_experts % ep_size != 0:
  913. raise ValueError(
  914. f"The number of experts must be divisible by number of ep_size: {mod.num_experts} % {ep_size} != 0"
  915. )
  916. num_local_experts = mod.num_experts // ep_size
  917. router_logits, router_scores, router_indices = outputs
  918. router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_scores)
  919. router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
  920. router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, -1)
  921. # As -1 % 1 is 0, we can only use mask fill when num_local_experts is 1
  922. if num_local_experts > 1:
  923. router_indices = torch.fmod(router_indices, num_local_experts)
  924. else:
  925. router_indices = router_indices.masked_fill(router_indices > 0, 0).masked_fill(router_indices < 0, -1)
  926. router_indices = router_indices.masked_fill(router_indices == -1, num_local_experts)
  927. return router_logits, router_scores, router_indices
  928. def shard_tensor(
  929. self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
  930. ) -> torch.Tensor:
  931. return param[...].to(device=device, dtype=dtype)
  932. class MoeTensorParalellExperts(TensorParallelLayer):
  933. """
  934. Note: For tensor parallel, the MoEExpertsParallel TP layer handles gradient sync:
  935. - all_reduce_backward on hidden_states (for colwise gate_up_proj gradient)
  936. - all_reduce_backward on top_k_weights (for router gradient)
  937. - all_reduce_forward on output (for partial expert outputs)
  938. """
  939. def __init__(self, **kwargs):
  940. super().__init__(**kwargs)
  941. def _prepare_input_fn(self, mod, inputs, device_mesh):
  942. # inputs = (hidden_states, top_k_index, top_k_weights)
  943. hidden_states = inputs[0]
  944. top_k_index = inputs[1]
  945. top_k_weights = inputs[2]
  946. # all_reduce_backward on hidden_states for correct colwise (gate_up_proj) gradient
  947. hidden_states = all_reduce_backward(hidden_states, device_mesh)
  948. # all_reduce_backward on routing weights for correct router gradient
  949. # This is needed because ∂L/∂routing_weights = ∂L/∂output * partial_expert_output
  950. # and partial_expert_output is different on each GPU before all-reduce
  951. top_k_weights = all_reduce_backward(top_k_weights, device_mesh)
  952. return (hidden_states, top_k_index, top_k_weights)
  953. def _prepare_output_fn(self, mod, outputs, device_mesh):
  954. # all_reduce_forward to sum partial expert outputs across GPUs
  955. return all_reduce_forward(outputs, device_mesh)
  956. def shard_tensor(
  957. self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
  958. ) -> torch.Tensor:
  959. # This class doesn't shard tensors - sharding is handled by packed_colwise/rowwise
  960. # on the individual weight tensors (gate_up_proj/down_proj)
  961. return param[...].to(device=device, dtype=dtype)
  962. class MoeIdentityExpertParallel(TensorParallelLayer):
  963. """
  964. TP class for zero/identity experts in MoE layers.
  965. Under TP, the parent MoeTensorParalellExperts does all_reduce_forward (sum)
  966. on the expert module output. Identity experts produce the same output on
  967. every rank, so the sum gives world_size * output. This class divides the
  968. input by world_size to compensate.
  969. """
  970. def _prepare_input_fn(self, mod, inputs, device_mesh):
  971. input_tensor = inputs[0] if inputs else inputs
  972. # TODO(fmom): when 2D-device mesh, need to select a //-ism axis to divide the input tensor by.
  973. return input_tensor / device_mesh.size()
  974. def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
  975. return param[...].to(device=device, dtype=dtype)
  976. def prepare_module_tp(self, module, device_mesh, **kwargs):
  977. distribute_module(module, device_mesh, input_fn=self._prepare_input_fn)
  978. class ParallelInterface(GeneralInterface):
  979. # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
  980. # a new instance is created (in order to locally override a given entry)
  981. _global_mapping = (
  982. {
  983. "embedding_rowwise": EmbeddingParallel(embedding_dim_sharding=0),
  984. "embedding_colwise": EmbeddingParallel(embedding_dim_sharding=1),
  985. "colwise_gather_output": ColwiseParallel(gather_output=True),
  986. "colwise": ColwiseParallel(),
  987. "rowwise": RowwiseParallel(),
  988. "rowwise_split_input": RowwiseParallel(split_input=True),
  989. "packed_colwise": PackedColwiseParallel(),
  990. "packed_rowwise": PackedRowwiseParallel(),
  991. "sequence_parallel": SequenceParallel(),
  992. "grouped_gemm": GroupedGemmParallel(),
  993. "ep_router": RouterParallel(),
  994. "moe_tp_experts": MoeTensorParalellExperts(),
  995. "moe_identity_expert": MoeIdentityExpertParallel(),
  996. "replicated_with_grad_allreduce": ReplicatedWithGradAllReduce(),
  997. "mla_kv_a_proj": MlaKvAProjParallel(),
  998. }
  999. if is_torch_available() and _torch_distributed_available
  1000. else {}
  1001. )
  1002. # Map plan names to sharding dimensions for weights
  1003. # For weights: colwise shards dim -2, rowwise shards dim -1
  1004. # For embedding: rowwise shards dim 0 (vocab), colwise shards dim -2 (hidden)
  1005. plan_to_weight_dim: dict[str, int | None] = {
  1006. "colwise": -2,
  1007. "colwise_gather_output": -2,
  1008. "packed_colwise": -2,
  1009. "rowwise": -1,
  1010. "rowwise_split_input": -1,
  1011. "packed_rowwise": -1,
  1012. "embedding_rowwise": 0,
  1013. "embedding_colwise": 1,
  1014. "sequence_parallel": None,
  1015. "replicated_with_grad_allreduce": None,
  1016. "mla_kv_a_proj": None,
  1017. }
  1018. # Bias sharding: colwise shards bias, rowwise doesn't (bias is replicated and all-reduced)
  1019. plan_to_bias_dim: dict[str, int | None] = {
  1020. "colwise": -1,
  1021. "colwise_gather_output": -1,
  1022. "packed_colwise": -1,
  1023. "rowwise": None,
  1024. "rowwise_split_input": None,
  1025. "packed_rowwise": None,
  1026. "embedding_rowwise": None,
  1027. "embedding_colwise": None,
  1028. "sequence_parallel": None,
  1029. "replicated_with_grad_allreduce": None,
  1030. "mla_kv_a_proj": None,
  1031. }
  1032. @classmethod
  1033. def register_plan_to_weight_dim(cls, key: str, value: int | None):
  1034. cls.plan_to_weight_dim[key] = value
  1035. @classmethod
  1036. def register_plan_to_bias_dim(cls, key: str, value: int | None):
  1037. cls.plan_to_bias_dim[key] = value
# Module-level registry instance; the functions below index it by plan name to obtain the
# matching parallel layer and read its `plan_to_weight_dim` / `plan_to_bias_dim` tables.
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
  1039. # =============================================================================
  1040. # High-Level API Functions
  1041. # =============================================================================
  1042. def gather_full_tensor(
  1043. local_tensor: torch.Tensor, shard_dim: int, device_mesh: dist.device_mesh.DeviceMesh
  1044. ) -> torch.Tensor:
  1045. """
  1046. All-gather a sharded tensor along the specified dimension to reconstruct the full tensor.
  1047. Args:
  1048. local_tensor: The local shard of the tensor on this rank
  1049. shard_dim: The dimension along which the tensor was sharded
  1050. device_mesh: The device mesh for distributed communication
  1051. Returns:
  1052. The full reconstructed tensor (same on all ranks)
  1053. """
  1054. world_size = device_mesh.size()
  1055. # In case of TP+DP configuration, the TP group should be used for gathering, not the full DP group
  1056. process_group = device_mesh.get_group("tp") if "tp" in (device_mesh.mesh_dim_names or {}) else None
  1057. # Normalize negative dimension
  1058. if shard_dim < 0:
  1059. shard_dim = local_tensor.ndim + shard_dim
  1060. # Gather all shards
  1061. gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)]
  1062. dist.all_gather(gathered_tensors, local_tensor.contiguous(), group=process_group)
  1063. # Concatenate along the shard dimension
  1064. return torch.cat(gathered_tensors, dim=shard_dim)
  1065. def gather_state_dict_for_save(
  1066. state_dict: dict[str, torch.Tensor],
  1067. tp_plan: dict[str, str],
  1068. device_mesh,
  1069. tp_size: int,
  1070. ) -> dict[str, torch.Tensor]:
  1071. """
  1072. Gather sharded tensors to reconstruct full tensors for saving.
  1073. This function all-gathers each sharded tensor along its shard dimension
  1074. to reconstruct the full unsharded tensor for checkpoint saving.
  1075. Args:
  1076. state_dict: The model state dict with local sharded tensors
  1077. tp_plan: The tensor parallel plan mapping layer patterns to shard styles
  1078. device_mesh: The device mesh for distributed communication
  1079. tp_size: The tensor parallel world size
  1080. Returns:
  1081. State dict with full (gathered) tensors
  1082. """
  1083. # Use the global mappings from ParallelInterface (can be extended by users)
  1084. plan_to_weight_dim = ALL_PARALLEL_STYLES.plan_to_weight_dim
  1085. plan_to_bias_dim = ALL_PARALLEL_STYLES.plan_to_bias_dim
  1086. result = {}
  1087. for key, tensor in state_dict.items():
  1088. # Find the matching TP plan for this parameter
  1089. param_name = key.rsplit(".", 1)[0] if "." in key else key
  1090. param_type = key.rsplit(".", 1)[1] if "." in key else None
  1091. generic_param_name = re.sub(r"\d+", "*", param_name)
  1092. # Also check the full key for nn.Parameter (e.g., MoE experts without .weight suffix)
  1093. generic_full_key = re.sub(r"\d+", "*", key)
  1094. # Check if this parameter has a TP plan
  1095. current_plan = None
  1096. if generic_full_key in tp_plan:
  1097. # Full key match (e.g., "model.layers.*.mlp.experts.gate_up_proj" for MoE experts)
  1098. current_plan = tp_plan[generic_full_key]
  1099. elif generic_param_name in tp_plan:
  1100. current_plan = tp_plan[generic_param_name]
  1101. elif "." in generic_param_name:
  1102. parent_param_name = generic_param_name.rsplit(".", 1)[0]
  1103. if parent_param_name in tp_plan:
  1104. current_plan = tp_plan[parent_param_name]
  1105. if current_plan is None or current_plan not in plan_to_weight_dim:
  1106. # Not sharded, keep as-is
  1107. result[key] = tensor
  1108. continue
  1109. # Determine sharding dimension based on param type
  1110. if param_type == "bias":
  1111. shard_dim = plan_to_bias_dim.get(current_plan)
  1112. else:
  1113. shard_dim = plan_to_weight_dim.get(current_plan)
  1114. if shard_dim is None:
  1115. # Replicated, keep as-is
  1116. result[key] = tensor
  1117. continue
  1118. # Gather full tensor and handle packed weights repacking
  1119. full_tensor = gather_full_tensor(tensor, shard_dim, device_mesh)
  1120. if current_plan in ("packed_colwise", "packed_rowwise"):
  1121. full_tensor = repack_weights(full_tensor, shard_dim, tp_size, 2)
  1122. result[key] = full_tensor.contiguous()
  1123. return result
  1124. def add_tensor_parallel_hooks_to_module(
  1125. model, module, tp_plan, layer_name, current_module_plan, device_mesh, parameter_name=None
  1126. ):
  1127. r"""
  1128. This function is called in `PretrainedModel.post_init()`. It is responsible of adding hooks
  1129. to the modules of the `model`, based on the `PretrainedModel._tp_plan`.
  1130. This is the place where we add the `pre_forward` and `post_forwards` hooks. These are defined
  1131. for each `TensorParallelLayer` as `_prepare_input_fn` and `_prepare_output_fn`.
  1132. """
  1133. if current_module_plan is not None:
  1134. tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
  1135. try:
  1136. tp_layer.prepare_module_tp(module, device_mesh, config=model.config)
  1137. except NotImplementedError as e:
  1138. print(
  1139. f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
  1140. )
  1141. module._hf_tp_plan = current_module_plan
  1142. module._hf_device_mesh = device_mesh
  1143. module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}"
  1144. def shard_and_distribute_module(
  1145. model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
  1146. ):
  1147. r"""
  1148. This function is called in `from_pretrained` when loading a model's checkpoints.
  1149. It receives the pointer to the parameter (or the parameter itself) and takes care of "sharding".
  1150. All process run this function, so they just load the partition of the tensor that they require.
  1151. Main uses cases:
  1152. - column / rowise parallelism, you just shard all the weights of the layer (weight and bias)
  1153. - packed layers: you slice the weights, then shard like above
  1154. - custom operation:
  1155. - you want to add an all-gather at the end of a local layer.
  1156. - you want to have a layer that is isolated from the rest of the world (because torch.DTensor does not work well with `.view` for instance)
  1157. """
  1158. param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
  1159. tp_plan = model.tp_plan or {}
  1160. module_to_tp = model.get_submodule(param_name)
  1161. rank = int(rank)
  1162. current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
  1163. if dist.get_rank() == 0:
  1164. if current_shard_plan is None:
  1165. logger.info(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.")
  1166. else:
  1167. logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}")
  1168. if current_shard_plan is not None:
  1169. try:
  1170. tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
  1171. tp_layer.empty_param = empty_param
  1172. tp_layer.device_mesh = device_mesh
  1173. tp_layer.rank = rank
  1174. param = tp_layer.shard_tensor(param, tensor_idx=None, dtype=param_casting_dtype, device=rank)
  1175. if is_contiguous:
  1176. param = param.contiguous()
  1177. except NotImplementedError as e:
  1178. print(
  1179. f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
  1180. )
  1181. else:
  1182. param = param[:].to(param_casting_dtype)
  1183. # SUPER IMPORTANT we have to use setattr
  1184. # otherwise loading is crazy slow
  1185. if not isinstance(param, torch.nn.Parameter):
  1186. param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point())
  1187. setattr(module_to_tp, param_type, param)
  1188. tp_layer.update_module_attributes(module_to_tp)
  1189. return param
  1190. def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
  1191. """
  1192. Verify the TP plan of the model, log a warning if the layers that were not sharded and the rules that were not applied.
  1193. """
  1194. if tp_plan is None:
  1195. return
  1196. generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
  1197. unsharded_layers = set(generic_keys)
  1198. unused_rules = tp_plan.copy()
  1199. for key in generic_keys:
  1200. param_name = key.rsplit(".", 1)[0] if "." in key else key
  1201. generic_param_name = re.sub(r"\d+", "*", param_name)
  1202. if generic_param_name in tp_plan:
  1203. unused_rules.pop(generic_param_name, None)
  1204. unsharded_layers.discard(key)
  1205. elif "." in generic_param_name and (parent_param_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
  1206. unused_rules.pop(parent_param_name, None)
  1207. unsharded_layers.discard(key)
  1208. if len(unused_rules) > 0:
  1209. logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
  1210. if len(unsharded_layers) > 0:
  1211. logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")
  1212. def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
  1213. """Distribute a model according to the TP plan."""
  1214. model._tp_size = tp_size
  1215. model._device_mesh = device_mesh
  1216. if distributed_config is not None:
  1217. if isinstance(distributed_config, dict):
  1218. distributed_config = DistributedConfig.from_dict(distributed_config)
  1219. model.config.distributed_config = distributed_config
  1220. # Set the new requested tp_plan on the model
  1221. if isinstance(tp_plan, dict):
  1222. model.tp_plan = tp_plan
  1223. model_plan = model.tp_plan
  1224. if model_plan is not None and _torch_distributed_available:
  1225. for v in model_plan.values():
  1226. if v not in ALL_PARALLEL_STYLES:
  1227. raise ValueError(f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}")
  1228. for name, module in model.named_modules():
  1229. if not getattr(module, "_is_hooked", False):
  1230. plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model_plan, is_weight=False)
  1231. add_tensor_parallel_hooks_to_module(
  1232. model=model,
  1233. module=module,
  1234. tp_plan=model_plan,
  1235. layer_name="",
  1236. current_module_plan=plan,
  1237. device_mesh=device_mesh,
  1238. )
  1239. module._is_hooked = True
  1240. return model