| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525 |
- # Copyright 2024 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import annotations
- import math
- import operator
- import os
- import re
- from functools import reduce
- from ..distributed import DistributedConfig
- from ..utils import is_torch_greater_or_equal, logging
- from ..utils.generic import GeneralInterface
- from ..utils.import_utils import is_torch_available
- if is_torch_available():
- import torch
- import torch.distributed as dist
- from torch import nn
- # Cache this result as it's a C FFI call which can be pretty time-consuming
- _torch_distributed_available = torch.distributed.is_available()
- logger = logging.get_logger(__name__)
- def initialize_tensor_parallelism(
- tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None
- ):
- r"""
- Sets up the device mesh and initialized the backend for tensor parallelism.
- This function is called when the model is loaded and the TP plan is set to 'auto'.
- """
- if tp_size is not None and tp_plan is None:
- raise ValueError("tp_plan has to be set when tp_size is passed.")
- if tp_plan is not None and device_map is not None:
- raise ValueError("`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization.")
- if device_mesh is None:
- if not is_torch_greater_or_equal("2.5"):
- raise OSError("Tensor parallel is only supported for `torch>=2.5`.")
- # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
- device_type = torch._C._get_accelerator().type
- if device_type == "mps":
- raise RuntimeError("Tensor parallelism is not supported on MPS devices.")
- current_device = getattr(torch, device_type)
- if not torch.distributed.is_initialized():
- try:
- rank = int(os.environ["RANK"])
- local_rank = int(os.environ["LOCAL_RANK"])
- world_size = int(os.environ["WORLD_SIZE"])
- backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl", "neuron": "neuron"}
- backend = backend_map.get(device_type)
- torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
- current_device = getattr(torch, device_type)
- if device_type != "cpu":
- current_device.set_device(local_rank)
- except Exception as e:
- raise OSError(
- "We tried to initialize torch.distributed for you, but it failed. Make "
- "sure you init torch distributed in your script to use `tp_plan`."
- ) from e
- if device_type != "cpu":
- current_device.set_device(int(os.environ["LOCAL_RANK"]))
- index = current_device.current_device()
- tp_device = torch.device(device_type, index)
- device_map = tp_device
- else:
- tp_device = torch.device(device_type)
- device_map = device_type or {}
- tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
- device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
- else:
- if device_mesh.ndim > 1:
- if "tp" not in device_mesh.mesh_dim_names:
- raise ValueError(
- "When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
- "Please provide a valid `device_mesh`."
- )
- device_mesh = device_mesh["tp"]
- tp_size = device_mesh.size()
- device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
- return device_map, device_mesh, tp_size
def replace_layer_number_by_wildcard(name: str) -> str:
    """
    Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
    a dot (`.`) and the end of the string.

    This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
    numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.

    Consecutive numeric segments are all replaced, e.g. `"blocks.0.1.weight"` -> `"blocks.*.*.weight"`.
    """
    # The lookahead does not consume the following dot, so the next numeric segment can still match.
    return re.sub(r"\.\d+(?=\.|$)", ".*", name)
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
    """
    Look up the TP style that applies to `parameter_name` in `tp_plan`.

    Layer indices in the name are first turned into `*` wildcards, so plan keys may be generic
    (e.g. "*.weight" or "layers.*.mlp.up_proj") or fully specific (e.g. "layer_1.weight").

    When the name itself is not in the plan and `is_weight` is truthy, the entry of the owning module (the name
    without its last `.`-separated component, such as `.weight` or `.bias`) is used instead. That fallback is
    skipped for non-weights, e.g. parent-module lookups made for `post_init`.

    Returns:
        The matching TP style, or None when nothing in the plan applies.
    """
    wildcard_name = replace_layer_number_by_wildcard(parameter_name)
    if wildcard_name in tp_plan:
        return tp_plan[wildcard_name]
    if not is_weight:
        return None
    owner, separator, _ = wildcard_name.rpartition(".")
    if separator and owner in tp_plan:
        return tp_plan[owner]
    return None
- # =============================================================================
- # Tensor Sharding Utilities
- # =============================================================================
if is_torch_available():
    # Maps the dtype strings returned by a weight slice's `get_dtype()` to torch dtypes (used by `get_packed_weights`).
    str_to_dtype = {
        "BOOL": torch.bool,
        "U8": torch.uint8,
        "I8": torch.int8,
        "I16": torch.int16,
        "F16": torch.float16,
        "BF16": torch.bfloat16,
        "I32": torch.int32,
        "F32": torch.float32,
        "F64": torch.float64,
        "I64": torch.int64,
        "F8_E4M3": torch.float8_e4m3fn,
        # `get_packed_weights` already treats F8_E5M2 as a float8 dtype; map it so lookups don't raise KeyError.
        "F8_E5M2": torch.float8_e5m2,
    }
- def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
- """
- Convert block count or proportions to block sizes.
- This function accepts
- - The number of blocks (int), in which case the block size is
- total_size//blocks; or
- - A list of block sizes (list[int]).
- In the second case, if sum(blocks) < total_size, the ratios between
- the block sizes will be preserved. For instance, if blocks is
- [2, 1, 1] and total_size is 1024, the returned block sizes are
- [512, 256, 256].
- """
- if isinstance(blocks, list):
- total_blocks = sum(blocks)
- assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}"
- part_size = total_size // total_blocks
- return [part_size * block for block in blocks]
- else:
- assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
- single_size = total_size // blocks
- return [single_size] * blocks
def get_packed_weights(param, empty_param, device_mesh, rank, dim):
    """
    Extract the shard of a packed weight (e.g. `gate_up_proj`) owned by `rank`.

    When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
    So if you have: gate_proj       ( 16, 5120, 8190)
    and             up_proj         ( 16, 5120, 8190)
    packed as       gate_up_proj    ( 16, 5120, 2 * 8190)
    and you shard along the last dimension across TP_size (Tensor Parallelism size), each shard must receive the
    matching part of both the gate and the up projections.

    Let's take TP_size = 4 for an example:

    Packed tensor `gate_up_proj`
    ---------------------------------------------------------------
    [ G0  G1  G2  G3 | G4  G5  G6  G7 | ... | U0  U1  U2  U3 | U4  U5  U6  U7 | ... ]
     ↑─────────────↑   ↑─────────────↑        ↑─────────────↑   ↑─────────────↑
       Gate Slice 0      Gate Slice 1            Up Slice 0        Up Slice 1

    Explanation:
    - The first half of the tensor holds the gate_proj values.
    - The second half holds the up_proj values.
    - For TP=4, each half is divided into 4 slices (only two are shown for brevity).
    - Each shard receives one slice from the gate part and the corresponding slice from the up part.
      For instance:
        • Shard 0 gets: [ Gate Slice 0, Up Slice 0 ] = [ G0, G1, G2, G3, U0, U1, U2, U3 ]
        • Shard 1 gets: [ Gate Slice 1, Up Slice 1 ] = [ G4, G5, G6, G7, U4, U5, U6, U7 ]
        • … and so on.

    This ensures that each shard receives an equal portion of both gate and up projections.

    Args:
        param: The packed weight to read from. It must expose `get_dtype()` and `get_shape()` and support indexing
            (as a lazily-loaded weight slice does).
        empty_param: Tensor whose size along `dim` is the full packed size.
        device_mesh: Mesh whose `size()` is the number of shards.
        rank: Index of the shard to extract.
        dim: Dimension holding the packed halves. Negative values count from the end of `param`, so `dim=-2` of a
            2-D weight is its first dimension.

    Returns:
        The local shard as a tensor. Float8 inputs (`F8_E4M3` / `F8_E5M2`) are returned upcast to float16; other
        dtypes are converted with `str_to_dtype`.

    Raises:
        ValueError: If `dim` is not the first, second or last dimension of `param`.
    """
    slice_ = param
    ndim = len(slice_.get_shape())
    # Normalize against the rank of `param`: the index forms below are chosen by absolute axis.
    axis = dim + ndim if dim < 0 else dim
    if not 0 <= axis < ndim or axis not in (0, 1, ndim - 1):
        raise ValueError(
            f"Unsupported dim {dim} for a {ndim}-D parameter, only the first, second or last dim can be sharded"
        )

    total_size = empty_param.shape[dim]
    world_size = device_mesh.size()
    block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=2)

    # For each packed block (gate, then up), take the contiguous range of indices owned by `rank`.
    tensors_slices = []
    block_offset = 0
    for block_size in block_sizes:
        shard_block_size = block_size // world_size
        start = block_offset + rank * shard_block_size
        tensors_slices += range(start, start + shard_block_size)
        block_offset += block_size

    slice_dtype = slice_.get_dtype()
    # Handle float8 dtypes by converting to float16 before slicing.
    # Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
    casted = slice_dtype in ("F8_E4M3", "F8_E5M2")
    if casted:
        slice_ = slice_[...].to(torch.float16)

    if axis == 0:
        tensor = slice_[tensors_slices, ...]
    elif axis == 1:
        tensor = slice_[:, tensors_slices, ...]
    else:  # axis == ndim - 1, validated above
        tensor = slice_[..., tensors_slices]

    if casted:
        return tensor
    return tensor.to(str_to_dtype[slice_dtype])
def repack_weights(
    packed_parameter: torch.Tensor,
    sharded_dim: int,  # The dimension index in the global tensor that was sharded
    world_size: int,
    num_blocks: int = 2,
) -> torch.Tensor:
    """
    Reorder a tensor reconstructed from sharded packed weights into its canonical packed layout.

    When a packed weight (e.g. `gate_proj` and `up_proj` concatenated along one dimension) is sharded the way
    `get_packed_weights` does it, rank `r` holds `[G_r, U_r]`. Concatenating the shards along the sharded dimension
    (e.g. with `DTensor.full_tensor()`) therefore yields `[G0, U0, G1, U1, ...]`. This function reorders it back to
    `[G0, G1, ..., U0, U1, ...]`, the inverse of `get_packed_weights`.

    Args:
        packed_parameter: The reconstructed (gathered) tensor in rank-major layout.
        sharded_dim: Dimension of `packed_parameter` that was sharded. Negative values count from the end.
        world_size: The tensor parallel world size.
        num_blocks: Number of packed projections. Only 2 (gate + up) is supported.

    Returns:
        A tensor with the same shape as `packed_parameter`, reordered along `sharded_dim`.

    Raises:
        ValueError: If `num_blocks` is not 2.
    """
    if num_blocks != 2:
        raise ValueError(
            "Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
        )

    axis = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
    dim_size = packed_parameter.shape[axis]
    shard_chunk_size = (dim_size // num_blocks) // world_size

    # Split the sharded dim into (world_size, num_blocks, shard_chunk_size): the gathered layout is rank-major.
    tensor_view = packed_parameter.view(
        *packed_parameter.shape[:axis],
        world_size,
        num_blocks,
        shard_chunk_size,
        *packed_parameter.shape[axis + 1 :],
    )
    # Swap to (num_blocks, world_size, shard_chunk_size) so all gate chunks come first, then all up chunks.
    reordered = tensor_view.transpose(axis, axis + 1)
    return reordered.reshape_as(packed_parameter)
- def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: int | None = None):
- """
- Generalized tensor sharding across a multi-dimensional device mesh.
- Extract only the fraction of the parameter owned by the given `rank` when the parameter would have gone sharding at provided `dim`.
- Extraction follows the pytorch `Shard` placement so that sharding and materializing back to full tensor follows `Shard` semantics.
- `Shard` follows torch.chunk style sharding of the tensor. We demonstrate some cases below on how sharding happens including some edge cases
- such as some ranks having an empty tensor as shard. Below implementation is robut to all these cases.
- Case (1)
- empty_param (16, 5120, 8190)
- dim 0
- device_mesh.size() 4
- rank 0 gets (4, 5120, 8190) (0 ... 4, 5120, 8190)
- rank 1 gets (4, 5120, 8190) (4 ... 8, 5120, 8190)
- rank 2 gets (4, 5120, 8190) (8 ... 12, 5120, 8190)
- rank 3 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
- Case (2)
- empty_param (16, 5120, 8190)
- dim 0
- device_mesh.size() 14
- rank 0 gets (2, 5120, 8190) (0 ... 2, 5120, 8190)
- rank 1 gets (2, 5120, 8190) (2 ... 4, 5120, 8190)
- rank 2 gets (2, 5120, 8190) (4 ... 6, 5120, 8190)
- rank 3 gets (2, 5120, 8190) (6 ... 8, 5120, 8190)
- rank 4 gets (2, 5120, 8190) (8 ... 10, 5120, 8190)
- rank 5 gets (2, 5120, 8190) (10 ... 12, 5120, 8190)
- rank 6 gets (2, 5120, 8190) (12 ... 14, 5120, 8190)
- rank 7 gets (2, 5120, 8190) (14 ... 16, 5120, 8190)
- rank 8 gets (0, 5120, 8190)
- rank 9 gets (0, 5120, 8190)
- rank 10 gets (0, 5120, 8190)
- rank 11 gets (0, 5120, 8190)
- rank 12 gets (0, 5120, 8190)
- rank 13 gets (0, 5120, 8190)
- Case (3)
- empty_param (16, 5120, 8190)
- dim 0
- device_mesh.size() 3
- rank 0 gets (6, 5120, 8190) (0 ... 6, 5120, 8190)
- rank 1 gets (6, 5120, 8190) (6 ... 12, 5120, 8190)
- rank 2 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
- In case (2), empty shards are returned with appropriate dimension to allow for operations to work smoothly.
- Args:
- param (torch.Tensor): The tensor to shard.
- empty_param (torch.Tensor): A tensor used for shape reference.
- device_mesh (torch.Tensor): Shape [d_0, ..., d_n] representing the mesh.
- rank (int): Global rank of the current process/device.
- dim (int): Dimension along which to shard the tensor.
- """
- param_dim = empty_param.ndim
- mesh_shape = device_mesh.shape
- world_size = reduce(operator.mul, mesh_shape)
- # Get param shape: works for both torch.Tensor and safetensors TensorInfo
- param_shape = list(param.shape) if isinstance(param, torch.Tensor) else param.get_shape()
- if dim < 0:
- dim = param_dim + dim
- if empty_param.dim() == 3 and dim == 1 and len(param_shape) == 2:
- dim = 0
- elif empty_param.dim() == 3 and dim == 2 and len(param_shape) == 2:
- dim = 1
- shard_size = math.ceil(param_shape[dim] / world_size)
- start = rank * shard_size
- end = min(start + shard_size, param_shape[dim])
- if dim >= param_dim:
- raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
- if rank >= world_size:
- raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
- # we have the full tensor not 1 part of it.
- # in that case, we just assume that the weight was properly saved
- # and thus because we TP if the layer is colwise it should not use this. Layer should be packed_colwise
- # to inform that it needs to read form a packed tensor. It will also take care of the module list thingy.
- # here we take care of potential chunking / layer split / layer chunking.
- # The only "hard" case is? if we collect q,k,v -> merge it into qkv. In that case
- # actually we still shard dim=0 does not change
- # so only case is if the dim of the empty param is 3 and the shard dim is 0 -> we put the
- # tensor on a certain device (with the input tensor_index)
- if tensor_idx is not None and empty_param.dim() == 3 and dim == 0 and len(param_shape) == 2:
- # special case we don't "shard" just send this entire tensor to the correct rank.
- if start <= tensor_idx < end:
- # this tensor does need to be materialized on this device:
- return param[:]
- else:
- return torch.empty([], dtype=torch.int64, device=rank)
- slice_indices = [slice(None)] * len(param_shape)
- if start < param_shape[dim]:
- slice_indices[dim] = slice(start, end)
- param = param[tuple(slice_indices)]
- if isinstance(param, list): # TODO handle the modulelist case!
- param = [p[:] for p in param]
- return param
- param_shape[dim] = 0
- return torch.empty(tuple(param_shape), dtype=torch.int64) # empty allocates memory....
- def _split_along_last_dim(x, world_size):
- """Split tensor along last dimension into world_size chunks."""
- return torch.chunk(x, world_size, dim=-1)
- # =============================================================================
- # Distributed Communication Primitives
- # =============================================================================
- #
- # Naming convention:
- # - Functions describe their FORWARD behavior
- # - Backward behavior is the "conjugate" operation for gradient flow
- #
- # Available operations:
- # ┌────────────────────┬─────────────────────┬─────────────────────┐
- # │ Function │ Forward │ Backward │
- # ├────────────────────┼─────────────────────┼─────────────────────┤
- # │ all_reduce_forward │ all-reduce (sum) │ identity │
- # │ all_reduce_backward│ identity │ all-reduce (sum) │
- # │ all_gather │ all-gather │ split (local chunk) │
- # │ split │ split (local chunk) │ all-gather │
- # │ reduce_scatter │ reduce-scatter │ all-gather │
- # └────────────────────┴─────────────────────┴─────────────────────┘
- # ===================
- class _AllReduceBackward(torch.autograd.Function):
- """Identity forward, all-reduce backward. Used before colwise layers (f in Megatron)."""
- @staticmethod
- def forward(ctx, x, device_mesh):
- ctx.device_mesh = device_mesh
- return x
- @staticmethod
- def backward(ctx, grad_output):
- device_mesh = ctx.device_mesh
- if device_mesh.size() == 1:
- return grad_output, None
- grad_output = grad_output.contiguous()
- dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
- return grad_output, None
- class _AllReduceForward(torch.autograd.Function):
- """All-reduce forward, identity backward. Used after rowwise layers (g in Megatron)."""
- @staticmethod
- def forward(ctx, x, device_mesh):
- if device_mesh.size() == 1:
- return x
- dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
- return x
- @staticmethod
- def backward(ctx, grad_output):
- return grad_output, None
- class _AllGather(torch.autograd.Function):
- """All-gather forward, split backward. Gathers sharded outputs."""
- @staticmethod
- def forward(ctx, x, device_mesh):
- ctx.device_mesh = device_mesh
- world_size = device_mesh.size()
- if world_size == 1:
- return x
- last_dim = x.dim() - 1
- rank = device_mesh.get_local_rank()
- group = device_mesh.get_group()
- x = x.contiguous()
- tensor_list = [torch.empty_like(x) for _ in range(world_size)]
- tensor_list[rank] = x
- dist.all_gather(tensor_list, x, group=group)
- return torch.cat(tensor_list, dim=last_dim).contiguous()
- @staticmethod
- def backward(ctx, grad_output):
- device_mesh = ctx.device_mesh
- world_size = device_mesh.size()
- if world_size == 1:
- return grad_output, None
- rank = device_mesh.get_local_rank()
- chunks = _split_along_last_dim(grad_output, world_size)
- return chunks[rank].contiguous(), None
- class _Split(torch.autograd.Function):
- """Split forward, all-gather backward. Scatters replicated input."""
- @staticmethod
- def forward(ctx, x, device_mesh):
- ctx.device_mesh = device_mesh
- world_size = device_mesh.size()
- if world_size == 1:
- return x
- rank = device_mesh.get_local_rank()
- chunks = _split_along_last_dim(x, world_size)
- return chunks[rank].contiguous()
- @staticmethod
- def backward(ctx, grad_output):
- device_mesh = ctx.device_mesh
- world_size = device_mesh.size()
- if world_size == 1:
- return grad_output, None
- last_dim = grad_output.dim() - 1
- rank = device_mesh.get_local_rank()
- group = device_mesh.get_group()
- grad_output = grad_output.contiguous()
- tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
- tensor_list[rank] = grad_output
- dist.all_gather(tensor_list, grad_output, group=group)
- return torch.cat(tensor_list, dim=last_dim).contiguous(), None
- class _ReduceScatter(torch.autograd.Function):
- """Reduce-scatter forward, all-gather backward. For sequence parallel."""
- @staticmethod
- def forward(ctx, x, device_mesh):
- ctx.device_mesh = device_mesh
- world_size = device_mesh.size()
- if world_size == 1:
- return x
- last_dim = x.dim() - 1
- group = device_mesh.get_group()
- input_chunks = list(x.chunk(world_size, dim=last_dim))
- output_shape = list(x.shape)
- output_shape[last_dim] //= world_size
- output = torch.empty(output_shape, dtype=x.dtype, device=x.device)
- dist.reduce_scatter(output, input_chunks, op=dist.ReduceOp.SUM, group=group)
- return output
- @staticmethod
- def backward(ctx, grad_output):
- device_mesh = ctx.device_mesh
- world_size = device_mesh.size()
- if world_size == 1:
- return grad_output, None
- last_dim = grad_output.dim() - 1
- rank = device_mesh.get_local_rank()
- group = device_mesh.get_group()
- grad_output = grad_output.contiguous()
- tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
- tensor_list[rank] = grad_output
- dist.all_gather(tensor_list, grad_output, group=group)
- return torch.cat(tensor_list, dim=last_dim).contiguous(), None
- # =============================================================================
- # Convenience wrappers
- # =============================================================================
def all_reduce_backward(x, device_mesh):
    """Return `x` unchanged in forward and sum its gradient over `device_mesh` in backward.

    Use before colwise layers.
    """
    return _AllReduceBackward.apply(x, device_mesh)


def all_reduce_forward(x, device_mesh):
    """Sum `x` over `device_mesh` in forward and pass the gradient through unchanged in backward.

    Use after rowwise layers.
    """
    return _AllReduceForward.apply(x, device_mesh)


def all_gather(x, device_mesh):
    """Concatenate every rank's `x` along the last dim in forward; keep this rank's gradient chunk in backward."""
    return _AllGather.apply(x, device_mesh)


def split(x, device_mesh):
    """Keep this rank's last-dim chunk of `x` in forward; all-gather the gradient in backward."""
    return _Split.apply(x, device_mesh)


def reduce_scatter(x, device_mesh):
    """Sum `x` over `device_mesh` and keep this rank's last-dim chunk in forward; all-gather the gradient in backward."""
    return _ReduceScatter.apply(x, device_mesh)
def distribute_module(
    module: nn.Module,
    device_mesh=None,
    input_fn=None,
    output_fn=None,
) -> nn.Module:
    """
    Attach TP input/output transforms to `module` as forward hooks and return the same module.

    Adapted from torch's `distribute_module`, but without the communications (parameter partitioning) and without
    buffer registration, which were not efficient here.

    Args:
        module: Module to decorate in place.
        device_mesh: Forwarded as the last argument of both transforms.
        input_fn: Optional `input_fn(module, inputs, device_mesh)`, registered as a forward pre-hook (see
            `nn.Module.register_forward_pre_hook` for how its return value is used).
        output_fn: Optional `output_fn(module, outputs, device_mesh)`, registered as a forward hook (see
            `nn.Module.register_forward_hook` for how its return value is used).

    Returns:
        `module` itself.
    """
    if input_fn is not None:

        def _input_hook(mod, inputs):
            return input_fn(mod, inputs, device_mesh)

        module.register_forward_pre_hook(_input_hook)

    if output_fn is not None:

        def _output_hook(mod, inputs, outputs):
            return output_fn(mod, outputs, device_mesh)

        module.register_forward_hook(_output_hook)

    return module
- class TensorParallelLayer:
- """General tensor parallel layer for transformers"""
- device_mesh = None
- rank = None
- empty_param = None
- def __init__(self, device_mesh=None, rank=None, empty_param=None):
- self.rank = rank
- self.device_mesh = device_mesh
- self.empty_param = empty_param
- def _prepare_input_fn(self, mod, inputs, device_mesh):
- raise NotImplementedError
- def _prepare_output_fn(self, mod, outputs, device_mesh):
- raise NotImplementedError
- def shard_tensor(
- self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
- ) -> torch.Tensor:
- raise NotImplementedError
- def prepare_module_tp(self, module: nn.Module, device_mesh, **kwargs) -> nn.Module:
- distribute_module(
- module,
- device_mesh,
- self._prepare_input_fn,
- self._prepare_output_fn,
- )
- def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
- """
- Compute the expected shape after TP sharding for a given full shape.
- Args:
- full_shape: The full (unsharded) parameter shape
- Returns:
- The expected sharded shape for this rank
- """
- # Default: no sharding, return full shape
- return tuple(full_shape)
- def update_module_attributes(self, module: nn.Module):
- """
- Update module attributes (e.g. in_features, out_features) to reflect sharded dimensions.
- Args:
- module: The module to update
- Returns:
- None, update the module in-place
- """
- pass
class ColwiseParallel(TensorParallelLayer):
    """
    Column-wise parallel: weight is sharded on dim -2 (output features).
    Forward: input replicated -> output sharded on last dim.
    If gather_output=True, output is all-gathered to produce full tensor.
    """

    def __init__(self, gather_output: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.gather_output = gather_output

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        # Identity in forward; sums the input gradient across ranks in backward.
        input_tensor = inputs[0] if inputs else inputs
        return all_reduce_backward(input_tensor, device_mesh)

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        if self.gather_output:
            return all_gather(outputs, device_mesh)
        return outputs

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        # If only 1 dim, shard this one (usually it's a `bias`); otherwise shard dim -2 (output features).
        ndim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
        shard_dim = -1 if ndim == 1 else -2
        parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, shard_dim)
        return parameter.to(device=device, dtype=dtype)

    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
        """
        Shape of this rank's shard, using the same ceil-division chunking as `get_tensor_shard`.

        Ranks whose chunk would start past the end of the sharded dimension get a size of 0 there.
        """
        world_size = self.device_mesh.size()
        shape = list(full_shape)
        # Colwise shards dim -2, but 1D tensors (bias) shard on dim -1
        dim = -1 if len(shape) == 1 else -2
        dim = len(shape) + dim if dim < 0 else dim
        shard_size = math.ceil(shape[dim] / world_size)
        start = self.rank * shard_size
        end = min(start + shard_size, shape[dim])
        # Clamp: without it, trailing ranks (e.g. 16 rows over 14 ranks, rank 10) would get a negative size.
        shape[dim] = max(0, end - start)
        return tuple(shape)

    def update_module_attributes(self, module: nn.Module):
        # If we gather the output, the output dimension of the module is not sharded, so no need to update out_features.
        # Otherwise, we need to update out_features to reflect the sharded dimension.
        if not self.gather_output and hasattr(module, "out_features"):
            module.out_features = self.get_expected_sharded_shape((module.out_features,))[0]
class ReplicatedWithGradAllReduce(TensorParallelLayer):
    """
    Replicated parameter whose gradient is all-reduced.

    Meant for parameters such as q_norm/k_norm that sit between colwise and rowwise layers: the parameter itself is
    not sharded, but in TP mode its gradient only accumulates contributions from the local heads, so a backward
    hook sums the parameter gradients across the mesh.
    """

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        # Inputs pass through untouched.
        return inputs

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        # Outputs pass through untouched.
        return outputs

    def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
        # No sharding: every rank materializes the full parameter.
        full = param[...]
        return full.to(device=device, dtype=dtype)

    def prepare_module_tp(self, module, device_mesh, **kwargs):
        # A module-level hook is used instead of `param.register_hook` because parameters are replaced during
        # weight loading after this method runs, while module hooks survive that replacement.
        # NOTE(review): confirm that `param.grad` is already populated when a full backward hook fires, and that
        # repeated backward passes (gradient accumulation) do not re-reduce an already-reduced `.grad`.
        def _reduce_param_grads(mod, grad_input, grad_output, mesh=device_mesh):
            for grad in (p.grad for p in mod.parameters()):
                if grad is not None:
                    # Return value ignored: relies on `all_reduce_forward` reducing the tensor in place.
                    all_reduce_forward(grad, mesh)

        module.register_full_backward_hook(_reduce_param_grads)
class MlaKvAProjParallel(TensorParallelLayer):
    """
    For MLA attention used in DeepSeek-V2 style models (deepseek_v2, longcat_flash, glm_moe_dsa,
    glm4_moe_lite): the `kv_a_proj_with_mqa` output has `kv_lora_rank + qk_rope_head_dim` features.
    Naming can differ between models; what matters is that this output is split in two.

    Example below (from modeling_longcat_flash.py):

                 kv_a_proj_with_mqa
                         |
                       split
                     /       \\
                k_pass        k_rot   <-- "bypasses kv_b_proj"
                   |            |         (goes straight to attention,
            kv_a_layernorm      |          never touches kv_b_proj)
                   |            |
               kv_b_proj        |
               (colwise)        |
                   |            |
                k_pass        k_rot
                     \\       /
                        cat
                         |
                     key_states

    `k_pass` is passed to `kv_b_proj` (colwise), which has a built-in `all_reduce_backward`, so its
    gradient is not partial. `k_rot`, however, goes straight to attention and never touches
    `kv_b_proj`, so each rank would only get a partial gradient for it. This layer wraps the `k_rot`
    slice of the output with `all_reduce_backward` to combine those gradients across ranks.
    """

    def _prepare_output_fn(self, mod, output, device_mesh):
        # Look the config up defensively: a module that never went through `prepare_module_tp`
        # (or received `config=None`) then gets the explanatory error below instead of a bare
        # AttributeError on `mod.config`.
        config = getattr(mod, "config", None)
        if not hasattr(config, "qk_rope_head_dim"):
            raise AttributeError(
                f"Config for {type(mod).__name__} does not have `qk_rope_head_dim`. "
                "MlaKvAProjParallel requires `qk_rope_head_dim` to be defined in the model config. "
                "Please add it to the model's config or update the TP plan mapping."
            )
        rope_dim = config.qk_rope_head_dim
        # The last `rope_dim` features are the rotary part (k_rot); the rest is k_pass.
        pass_output, rope_output = output.split([output.shape[-1] - rope_dim, rope_dim], dim=-1)
        rope_output = all_reduce_backward(rope_output, device_mesh)
        return torch.cat([pass_output, rope_output], dim=-1)

    def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
        # The projection weight is replicated, not sharded.
        return param[...].to(device=device, dtype=dtype)

    def prepare_module_tp(self, module, device_mesh, config=None, **kwargs):
        # Keep the model config on the module so the output hook can read `qk_rope_head_dim`.
        module.config = config
        distribute_module(module, device_mesh, output_fn=self._prepare_output_fn)
class RowwiseParallel(TensorParallelLayer):
    """
    Row-wise parallel: weight is sharded on dim -1 (input features).

    Forward: input (optionally split) -> output partial -> all-reduce to replicate.

    Args:
        split_input: If True, splits replicated input before matmul. Use when input
            comes from a non-parallelizable operation (chunk/slice).
            Default False (expects pre-sharded input from colwise layer).
    """

    def __init__(self, split_input: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.split_input = split_input

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        # The bias must be added exactly once, after the partial outputs have been summed in
        # `_prepare_output_fn`. Detach it from the module so the local matmul does not add it on
        # every rank.
        if hasattr(mod, "bias") and mod.bias is not None:
            mod._bias = mod.bias
            mod.bias = None
        input_tensor = inputs[0] if inputs else inputs
        if self.split_input:
            # Input is replicated: split it to match the sharded weight.
            return split(input_tensor, device_mesh)
        return input_tensor

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        # Sum the partial results, then add the (replicated) bias once.
        outputs = all_reduce_forward(outputs, device_mesh)
        if hasattr(mod, "_bias") and mod._bias is not None:
            outputs = outputs + mod._bias
        return outputs

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        # If only 1 dim, it should not be sharded (usually it's a `bias`)
        dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
        if dim == 1:
            parameter = param[...]
        else:
            parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
        return parameter.to(device=device, dtype=dtype)

    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
        """Return the local shard shape for an unsharded `full_shape` (1-D tensors stay whole)."""
        # 1D tensors (bias) are NOT sharded in rowwise
        if len(full_shape) == 1:
            return tuple(full_shape)
        world_size = self.device_mesh.size()
        shape = list(full_shape)
        dim = len(shape) - 1
        shard_size = math.ceil(shape[dim] / world_size)
        start = self.rank * shard_size
        end = min(start + shard_size, shape[dim])
        # With ceil-based sharding a rank may start past the end of the dimension; it then holds
        # an empty shard, so clamp to 0 instead of a negative size.
        shape[dim] = max(0, end - start)
        return tuple(shape)

    def update_module_attributes(self, module: nn.Module):
        if hasattr(module, "in_features"):
            # Pass a 2D shape so get_expected_sharded_shape takes the sharded path
            # (a 1D shape would be treated as an unsharded bias).
            shape = (1, module.in_features)
            module.in_features = self.get_expected_sharded_shape(shape)[1]
class PackedColwiseParallel(ColwiseParallel):
    """Packed column-wise parallel for fused weights like gate_up_proj."""

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        """Return this rank's shard of `param`, cast/moved to `dtype` / `device`.

        1-D tensors are split along their only dim. Multi-dimensional tensors with fewer dims than
        the expected shard shape are treated as unpacked pieces and sharded normally on dim -2.
        Anything else is treated as already packed and sharded with `get_packed_weights` on dim -2.
        """
        ndim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
        if ndim == 1:
            # Usually a bias: plain split along its only dimension.
            sharded = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
            return sharded.to(device=device, dtype=dtype)
        target_ndim = len(self.get_expected_sharded_shape(self.empty_param.shape))
        if ndim < target_ndim:
            # Unpacked piece (e.g. gate_proj that will later be concatenated into gate_up_proj):
            # regular shard now, the concatenation happens afterwards.
            sharded = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2)
        else:
            # Input is already packed: use packed sharding.
            sharded = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -2)
        return sharded.to(device=device, dtype=dtype)
class PackedRowwiseParallel(RowwiseParallel):
    """Packed row-wise parallel for fused weights like gate_up_proj."""

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        """Return this rank's shard of `param`, cast/moved to `dtype` / `device`.

        1-D tensors (usually a bias) are replicated. A tensor whose last dim is narrower than that of
        `self.empty_param` is an unpacked piece and gets a regular shard on dim -1. Otherwise the
        tensor is already packed and is sharded with `get_packed_weights` on dim -1.
        """
        ndim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
        if ndim == 1:
            return param[...].to(device=device, dtype=dtype)
        incoming_shape = param.shape if isinstance(param, torch.Tensor) else param.get_shape()
        packed_width = self.empty_param.shape[-1] if self.empty_param.dim() >= 1 else 0
        incoming_width = incoming_shape[-1] if len(incoming_shape) >= 1 else 0
        # A narrower input than the packed parameter happens with MergeModulelist + Concatenate
        # for fused weights such as gate_up_proj.
        shard_fn = get_tensor_shard if incoming_width < packed_width else get_packed_weights
        sharded = shard_fn(param, self.empty_param, self.device_mesh, self.rank, -1)
        return sharded.to(device=device, dtype=dtype)
class EmbeddingParallel(TensorParallelLayer):
    """
    EmbeddingParallel: shards embedding table, handles masked lookups for vocab parallelism.

    Args:
        embedding_dim_sharding: Weight dimension to shard.
            - `0` (vocab parallel): each rank holds a contiguous range of token ids. Ids outside that
              range are looked up at local row 0 and their embeddings zeroed, then the partial
              results are summed across ranks with an all-reduce.
            - `1` (hidden parallel): each rank holds a slice of every embedding vector. The per-rank
              output slices are concatenated across ranks with an all-gather.
    """

    def __init__(self, *, embedding_dim_sharding: int = 0, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim_sharding = embedding_dim_sharding

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        input_tensor = inputs[0] if inputs else inputs
        # Only vocab-parallel (dim 0) needs masking and offsetting of the token ids.
        if self.embedding_dim_sharding != 0:
            return input_tensor
        rank = device_mesh.get_local_rank()
        # Rows actually held by this rank. `num_embeddings` is not used because it may not be
        # updated after sharding.
        local_size = mod.weight.shape[0]
        # Rows are split into chunks of ceil(vocab / world_size), so the last non-empty shard can be
        # shorter than the others, and `rank * local_size` would then point at the wrong first row.
        # Use the nominal chunk size recorded by `update_module_attributes` when available;
        # otherwise fall back to the local size.
        shard_size = getattr(mod, "_tp_vocab_shard_size", None)
        if shard_size is None:
            shard_size = local_size
        vocab_start_index = rank * shard_size
        vocab_end_index = vocab_start_index + local_size
        # Token ids owned by another rank.
        input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
        mod._input_mask = input_mask
        # Shift to local indices; foreign ids point at a valid local row and are zeroed after the lookup.
        masked_input = input_tensor.clone() - vocab_start_index
        masked_input[input_mask] = 0
        return masked_input

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        if self.embedding_dim_sharding != 0:
            # Hidden-dim sharding: every rank produced a different slice of each embedding vector.
            # Those slices must be concatenated, not summed element-wise. This uses the same
            # `all_gather` helper ColwiseParallel uses to reassemble feature-sharded outputs.
            return all_gather(outputs, device_mesh)
        # Vocab-parallel: zero out embeddings for out-of-range tokens before the all-reduce.
        if hasattr(mod, "_input_mask"):
            mask_expanded = mod._input_mask.unsqueeze(-1).expand_as(outputs)
            # Use multiplication instead of in-place assignment to preserve gradients.
            outputs = outputs * (~mask_expanded).to(outputs.dtype)
            del mod._input_mask
        return all_reduce_forward(outputs, device_mesh)

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        # If only 1 dim, shard this one (usually it's a `bias`)
        dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
        if dim == 1:
            parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
        else:
            parameter = get_tensor_shard(
                param,
                self.empty_param,
                self.device_mesh,
                self.rank,
                self.embedding_dim_sharding,
            )
        return parameter.to(device=device, dtype=dtype)

    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
        """Return the local shard shape for an unsharded `full_shape`."""
        world_size = self.device_mesh.size()
        shape = list(full_shape)
        # Shards on self.embedding_dim_sharding (default 0); 1D tensors (bias) shard on dim -1.
        dim = -1 if len(shape) == 1 else self.embedding_dim_sharding
        dim = len(shape) + dim if dim < 0 else dim
        shard_size = math.ceil(shape[dim] / world_size)
        start = self.rank * shard_size
        end = min(start + shard_size, shape[dim])
        # A rank starting past the end of the dimension holds an empty shard: clamp to 0.
        shape[dim] = max(0, end - start)
        return tuple(shape)

    def update_module_attributes(self, module: nn.Module):
        if hasattr(module, "num_embeddings") and self.embedding_dim_sharding == 0:
            # Record the nominal (ceil) shard size while `num_embeddings` still holds the full
            # vocabulary size, so the forward pass can compute correct vocabulary offsets even when
            # the last shard is shorter.
            if getattr(module, "_tp_vocab_shard_size", None) is None:
                module._tp_vocab_shard_size = math.ceil(module.num_embeddings / self.device_mesh.size())
            module.num_embeddings = self.get_expected_sharded_shape((module.num_embeddings,))[0]
        if hasattr(module, "embedding_dim") and self.embedding_dim_sharding == 1:
            module.embedding_dim = self.get_expected_sharded_shape((module.embedding_dim,))[0]
class SequenceParallel(TensorParallelLayer):
    """
    Sequence Parallel: input/output sharded on sequence dimension.
    Weights are replicated.

    The wrapped module's input is all-gathered before it runs, and its output is
    reduce-scattered afterwards.
    """

    def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): `sequence_dim` is stored but never passed to `all_gather` / `reduce_scatter`,
        # and `use_local_output` / `use_dtensor` are accepted but ignored — confirm the collectives
        # act on the sequence dimension.
        self.sequence_dim = sequence_dim

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        local_input = inputs[0] if inputs else inputs
        # Reassemble the full sequence before the (replicated) layer runs.
        return all_gather(local_input, device_mesh)

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        # NOTE(review): if `reduce_scatter` sums across ranks, identical outputs from replicated
        # weights would be scaled by the world size — confirm its reduction semantics.
        return reduce_scatter(outputs, device_mesh)

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        # Weights are replicated: every rank keeps the full tensor.
        replicated = param[...]
        return replicated.to(device=device, dtype=dtype)
class GroupedGemmParallel(TensorParallelLayer):
    """
    Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        """
        Return the part of `param` belonging to this device's contiguous block of experts.

        Experts (dim 0 of `self.empty_param`) are split evenly over the mesh. The block index is
        taken from `device`: an int, or a `torch.device` whose index defaults to 0.

        - `tensor_idx` given and inside this block: the whole tensor (one expert) is returned.
        - `tensor_idx` is None (a bias or an already-merged weight): rows `[start, end)` are returned.
        - `tensor_idx` given but outside this block: `None` (not materialized on this device).
        - otherwise (a 0-d tensor with a `tensor_idx`): the whole tensor is returned.

        Raises:
            ValueError: if the global number of experts is not divisible by the mesh size.
        """
        global_num_experts = self.empty_param.shape[0]
        world_size = self.device_mesh.size()
        if global_num_experts % world_size != 0:
            raise ValueError(
                f"Global number of experts must be divisible by number of devices: {global_num_experts} % {world_size} != 0"
            )
        local_num_experts = global_num_experts // world_size
        # NOTE(review): when `device` is a torch.device, its (local) index selects the expert block,
        # not `self.rank`; on multi-node runs these can differ — confirm this is intended.
        if isinstance(device, torch.device):
            device = device.index if device.index is not None else 0
        start = device * local_num_experts
        end = start + local_num_experts
        shape = param.get_shape() if not isinstance(param, torch.Tensor) else param.shape
        if tensor_idx is not None and start <= tensor_idx < end:
            # This expert lives on this device: materialize it whole. Cast like the other branches so
            # per-expert tensors end up with the same dtype as already-merged ones.
            return param[:].to(device=device, dtype=dtype)
        elif tensor_idx is None:
            # A bias or a weight already merged over experts: keep only this device's rows.
            return param[start:end].to(device=device, dtype=dtype)
        elif len(shape) >= 1:
            # Expert owned by another device.
            return None
        else:
            return param[:].to(device=device, dtype=dtype)

    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
        # GroupedGemm shards on dim 0 (experts dimension)
        world_size = self.device_mesh.size()
        shape = list(full_shape)
        local_num_experts = shape[0] // world_size
        shape[0] = local_num_experts
        return tuple(shape)

    def update_module_attributes(self, module: nn.Module):
        if hasattr(module, "num_experts"):
            module.num_experts = self.get_expected_sharded_shape((module.num_experts,))[0]
class RouterParallel(TensorParallelLayer):
    """
    Allows to reshape the router scores to support running expert parallel.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        if inputs:
            return inputs[0]
        return inputs

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        """
        Restrict the router outputs to the experts hosted on this rank.

        Example: 4 tokens, top_k = 4, 128 experts and EP = 8, so num_local_experts = 128 / 8 = 16.
        With router_indices:

            [ 52,  42, 119,  67],
            [102,  89,  61,  40],
            [ 82, 103,   4,  34],
            [ 93,  23, 109,  11],

        the rank owning each selected expert (index // 16) is:

            [3, 2, 7, 4],
            [6, 5, 3, 2],
            [5, 6, 0, 2],
            [5, 1, 6, 0],

        so on rank 0 every entry not owned by rank 0 is replaced with 16 (num_local_experts):

            [ 16, 16, 16, 16],
            [ 16, 16, 16, 16],
            [ 16, 16,  4, 16],
            [ 16, 16, 16, 11],

        On the other ranks the owned indices must also be reduced modulo num_local_experts, because
        the next operation one-hot encodes the router indices and must see directly which local
        expert is hit. The scores are likewise restricted to this rank's slice of experts.

        The naive training loop used for device_map "auto" follows a similar logic: each rank
        behaves as if it were alone and computes its part of the hidden states. Filling foreign
        entries with num_local_experts lets the one-hot encoding put them in a slot the computation
        can skip.
        """
        ep_rank = device_mesh.get_local_rank()
        ep_size = device_mesh.size()
        if mod.num_experts % ep_size != 0:
            raise ValueError(
                f"The number of experts must be divisible by number of ep_size: {mod.num_experts} % {ep_size} != 0"
            )
        num_local_experts = mod.num_experts // ep_size
        router_logits, router_scores, router_indices = outputs

        # Scatter the top-k scores into a dense tensor shaped like the logits, then keep only this
        # rank's columns.
        dense_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_scores)
        first_expert = ep_rank * num_local_experts
        local_scores = dense_scores[:, first_expert : first_expert + num_local_experts]

        # Mark experts owned by other ranks with -1.
        owned_elsewhere = (router_indices // num_local_experts) != ep_rank
        local_indices = router_indices.masked_fill(owned_elsewhere, -1)
        if num_local_experts > 1:
            # fmod keeps the sign, so the -1 markers survive.
            local_indices = torch.fmod(local_indices, num_local_experts)
        else:
            # fmod(-1, 1) would be 0, so map owned indices to 0 explicitly and keep -1 for the rest.
            local_indices = local_indices.masked_fill(local_indices > 0, 0).masked_fill(local_indices < 0, -1)
        local_indices = local_indices.masked_fill(local_indices == -1, num_local_experts)
        return router_logits, local_scores, local_indices

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        # The router weight is replicated on every rank.
        replicated = param[...]
        return replicated.to(device=device, dtype=dtype)
class MoeTensorParalellExperts(TensorParallelLayer):
    """
    Tensor-parallel wrapper for an MoE experts module; handles the gradient/output synchronization:

    - `all_reduce_backward` on hidden_states (for the colwise gate_up_proj gradient)
    - `all_reduce_backward` on top_k_weights (for the router gradient)
    - `all_reduce_forward` on the output (sums the partial expert outputs)

    The expert weights themselves are sharded by the packed_colwise / packed_rowwise plans.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        # Expected positional inputs: (hidden_states, top_k_index, top_k_weights).
        hidden_states = inputs[0]
        top_k_index = inputs[1]
        top_k_weights = inputs[2]
        # Correct gradient for the colwise (gate_up_proj) weights.
        synced_hidden = all_reduce_backward(hidden_states, device_mesh)
        # Correct router gradient: dL/d(routing_weights) = dL/d(output) * partial_expert_output, and
        # the partial expert output differs on each rank before the all-reduce.
        synced_weights = all_reduce_backward(top_k_weights, device_mesh)
        return (synced_hidden, top_k_index, synced_weights)

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        # Sum the partial expert outputs across ranks.
        return all_reduce_forward(outputs, device_mesh)

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        # No sharding here: packed_colwise / packed_rowwise shard the individual weight tensors
        # (gate_up_proj / down_proj).
        replicated = param[...]
        return replicated.to(device=device, dtype=dtype)
class MoeIdentityExpertParallel(TensorParallelLayer):
    """
    TP class for zero/identity experts in MoE layers.

    Under TP, the parent MoeTensorParalellExperts sums the expert-module output across ranks
    (`all_reduce_forward`). Identity experts produce the same output on every rank, so that sum
    would be `world_size * output`. Dividing the input by the mesh size compensates.
    """

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        hidden = inputs[0] if inputs else inputs
        # TODO(fmom): when 2D-device mesh, need to select a //-ism axis to divide the input tensor by.
        mesh_size = device_mesh.size()
        return hidden / mesh_size

    def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
        # No sharding: every rank keeps the full tensor.
        replicated = param[...]
        return replicated.to(device=device, dtype=dtype)

    def prepare_module_tp(self, module, device_mesh, **kwargs):
        # Only an input hook is installed; the output is left as produced.
        distribute_module(module, device_mesh, input_fn=self._prepare_input_fn)
class ParallelInterface(GeneralInterface):
    """
    Registry of tensor-parallel styles, keyed by the plan names used in `tp_plan`.

    It also records, per plan name, which dimension of a weight / bias is sharded, so sharded
    tensors can be gathered back into full tensors (see `gather_state_dict_for_save`).
    """

    # Class-level object, so that a call to `register` is reflected in all other files, even when
    # a new instance is created (in order to locally override a given entry).
    if is_torch_available() and _torch_distributed_available:
        _global_mapping = {
            "embedding_rowwise": EmbeddingParallel(embedding_dim_sharding=0),
            "embedding_colwise": EmbeddingParallel(embedding_dim_sharding=1),
            "colwise_gather_output": ColwiseParallel(gather_output=True),
            "colwise": ColwiseParallel(),
            "rowwise": RowwiseParallel(),
            "rowwise_split_input": RowwiseParallel(split_input=True),
            "packed_colwise": PackedColwiseParallel(),
            "packed_rowwise": PackedRowwiseParallel(),
            "sequence_parallel": SequenceParallel(),
            "grouped_gemm": GroupedGemmParallel(),
            "ep_router": RouterParallel(),
            "moe_tp_experts": MoeTensorParalellExperts(),
            "moe_identity_expert": MoeIdentityExpertParallel(),
            "replicated_with_grad_allreduce": ReplicatedWithGradAllReduce(),
            "mla_kv_a_proj": MlaKvAProjParallel(),
        }
    else:
        _global_mapping = {}

    # Sharding dimension of weights per plan name (None = replicated).
    # Linear layers: colwise shards dim -2 (output features), rowwise shards dim -1 (input features).
    # Embeddings: embedding_rowwise shards dim 0 (vocab), embedding_colwise shards dim 1 (hidden).
    plan_to_weight_dim: dict[str, int | None] = {
        "colwise": -2,
        "colwise_gather_output": -2,
        "packed_colwise": -2,
        "rowwise": -1,
        "rowwise_split_input": -1,
        "packed_rowwise": -1,
        "embedding_rowwise": 0,
        "embedding_colwise": 1,
        "sequence_parallel": None,
        "replicated_with_grad_allreduce": None,
        "mla_kv_a_proj": None,
    }

    # Sharding dimension of biases per plan name: colwise shards the bias; rowwise keeps it
    # replicated (it is added once after the all-reduce).
    plan_to_bias_dim: dict[str, int | None] = {
        "colwise": -1,
        "colwise_gather_output": -1,
        "packed_colwise": -1,
        "rowwise": None,
        "rowwise_split_input": None,
        "packed_rowwise": None,
        "embedding_rowwise": None,
        "embedding_colwise": None,
        "sequence_parallel": None,
        "replicated_with_grad_allreduce": None,
        "mla_kv_a_proj": None,
    }

    @classmethod
    def register_plan_to_weight_dim(cls, key: str, value: int | None):
        """Register (or override) the weight sharding dim for plan name `key` (None = replicated)."""
        cls.plan_to_weight_dim[key] = value

    @classmethod
    def register_plan_to_bias_dim(cls, key: str, value: int | None):
        """Register (or override) the bias sharding dim for plan name `key` (None = replicated)."""
        cls.plan_to_bias_dim[key] = value


# Module-level registry of parallel styles, looked up by plan name (e.g. ALL_PARALLEL_STYLES["colwise"]).
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
- # =============================================================================
- # High-Level API Functions
- # =============================================================================
def gather_full_tensor(
    local_tensor: torch.Tensor, shard_dim: int, device_mesh: dist.device_mesh.DeviceMesh
) -> torch.Tensor:
    """
    All-gather a sharded tensor along the specified dimension to reconstruct the full tensor.

    Args:
        local_tensor: The local shard of the tensor on this rank
        shard_dim: The dimension along which the tensor was sharded (negative values allowed)
        device_mesh: The device mesh for distributed communication

    Returns:
        The full reconstructed tensor (same on all ranks)

    Note:
        `torch.distributed.all_gather` requires every rank to contribute a tensor of the same
        shape, so this only reconstructs tensors whose shards are equally sized.
    """
    # In case of TP+DP configuration, the TP group should be used for gathering, not the full DP group
    process_group = device_mesh.get_group("tp") if "tp" in (device_mesh.mesh_dim_names or ()) else None
    # Size the receive list from the group that actually runs the collective. `device_mesh.size()`
    # counts every rank of a multi-dimensional mesh (e.g. TP x DP), which does not match the TP group
    # size, and all_gather needs exactly one output slot per group member.
    world_size = dist.get_world_size(group=process_group)
    # Normalize negative dimension
    if shard_dim < 0:
        shard_dim = local_tensor.ndim + shard_dim
    gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)]
    dist.all_gather(gathered_tensors, local_tensor.contiguous(), group=process_group)
    # Concatenate along the shard dimension
    return torch.cat(gathered_tensors, dim=shard_dim)
def gather_state_dict_for_save(
    state_dict: dict[str, torch.Tensor],
    tp_plan: dict[str, str],
    device_mesh,
    tp_size: int,
) -> dict[str, torch.Tensor]:
    """
    Gather sharded tensors to reconstruct full tensors for saving.

    Every tensor whose TP plan shards it is all-gathered along its shard dimension (and repacked
    for packed plans). Tensors without a plan, with an unknown plan, or whose plan replicates them
    are returned unchanged.

    Args:
        state_dict: The model state dict with local sharded tensors
        tp_plan: The tensor parallel plan mapping layer patterns to shard styles
        device_mesh: The device mesh for distributed communication
        tp_size: The tensor parallel world size

    Returns:
        State dict with full (gathered) tensors
    """
    # Global mappings from ParallelInterface (can be extended by users).
    weight_dims = ALL_PARALLEL_STYLES.plan_to_weight_dim
    bias_dims = ALL_PARALLEL_STYLES.plan_to_bias_dim

    def _resolve_plan(key, param_name):
        # Most specific match first: the full key (covers bare nn.Parameter such as MoE
        # "...experts.gate_up_proj"), then the owning module, then its parent.
        # NOTE(review): r"\d+" rewrites every digit run, including inside names (fc1 -> fc*);
        # presumably this mirrors how plans are resolved at load time — confirm.
        generic_full_key = re.sub(r"\d+", "*", key)
        if generic_full_key in tp_plan:
            return tp_plan[generic_full_key]
        generic_param_name = re.sub(r"\d+", "*", param_name)
        if generic_param_name in tp_plan:
            return tp_plan[generic_param_name]
        if "." in generic_param_name:
            parent = generic_param_name.rsplit(".", 1)[0]
            if parent in tp_plan:
                return tp_plan[parent]
        return None

    gathered = {}
    for key, tensor in state_dict.items():
        if "." in key:
            param_name, param_type = key.rsplit(".", 1)
        else:
            param_name, param_type = key, None

        plan = _resolve_plan(key, param_name)
        if plan is None or plan not in weight_dims:
            # Not sharded: keep as-is.
            gathered[key] = tensor
            continue

        shard_dim = (bias_dims if param_type == "bias" else weight_dims).get(plan)
        if shard_dim is None:
            # Replicated: keep as-is.
            gathered[key] = tensor
            continue

        full_tensor = gather_full_tensor(tensor, shard_dim, device_mesh)
        if plan in ("packed_colwise", "packed_rowwise"):
            full_tensor = repack_weights(full_tensor, shard_dim, tp_size, 2)
        gathered[key] = full_tensor.contiguous()
    return gathered
def add_tensor_parallel_hooks_to_module(
    model, module, tp_plan, layer_name, current_module_plan, device_mesh, parameter_name=None
):
    r"""
    Attach the tensor-parallel hooks of `current_module_plan` to `module`.

    Called for every module by `distribute_model`. The pre-forward and post-forward hooks are the
    `_prepare_input_fn` / `_prepare_output_fn` of the corresponding `TensorParallelLayer`, installed
    through its `prepare_module_tp`. Does nothing when `current_module_plan` is None.

    Args:
        model: The model owning `module`; its `config` is forwarded to `prepare_module_tp`.
        module: The module to hook.
        tp_plan: The full TP plan (not used here).
        layer_name: Name used in diagnostics.
        current_module_plan: Plan name for this module (a key of `ALL_PARALLEL_STYLES`), or None.
        device_mesh: The device mesh for distributed communication.
        parameter_name: Unused.
    """
    if current_module_plan is None:
        return
    tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
    try:
        tp_layer.prepare_module_tp(module, device_mesh, config=model.config)
    except NotImplementedError as e:
        logger.warning(
            f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
        )
    module._hf_tp_plan = current_module_plan
    module._hf_device_mesh = device_mesh
    # Capture the current bound `__repr__` first. A lambda calling `module.__repr__()` would resolve
    # to itself (the instance attribute) and recurse forever. Note that `repr(module)` looks
    # `__repr__` up on the class, so this override only affects explicit `module.__repr__()` calls.
    original_repr = module.__repr__
    module.__repr__ = lambda: f"{original_repr()}\nTP Plan: {current_module_plan}"
def shard_and_distribute_module(
    model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
):
    r"""
    This function is called in `from_pretrained` when loading a model's checkpoints.
    It receives the pointer to the parameter (or the parameter itself) and takes care of "sharding".
    All processes run this function, so each one only loads the partition of the tensor it requires.

    Main use cases:
    - column / row-wise parallelism: you just shard all the weights of the layer (weight and bias)
    - packed layers: you slice the weights, then shard like above
    - custom operation:
        - you want to add an all-gather at the end of a local layer.
        - you want to have a layer that is isolated from the rest of the world (because torch.DTensor
          does not work well with `.view` for instance)

    Returns:
        The (possibly sharded) tensor as an `nn.Parameter`, also set on its owning module.
    """
    if "." in parameter_name:
        param_name, param_type = parameter_name.rsplit(".", 1)
    else:
        # Parameter registered directly on the root model: the owning module is the model itself
        # (`get_submodule("")` returns the model).
        param_name, param_type = "", parameter_name
    tp_plan = model.tp_plan or {}
    module_to_tp = model.get_submodule(param_name)
    rank = int(rank)
    current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)

    if dist.get_rank() == 0:
        if current_shard_plan is None:
            logger.info(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.")
        else:
            logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}")

    # Only set when the tensor was actually sharded; replicated tensors (no plan) and failed shards
    # must not shrink the module's size attributes.
    tp_layer = None
    if current_shard_plan is not None:
        try:
            tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
            tp_layer.empty_param = empty_param
            tp_layer.device_mesh = device_mesh
            tp_layer.rank = rank
            param = tp_layer.shard_tensor(param, tensor_idx=None, dtype=param_casting_dtype, device=rank)
            if is_contiguous:
                param = param.contiguous()
        except NotImplementedError as e:
            logger.warning(
                f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
            )
            tp_layer = None
    else:
        param = param[:].to(param_casting_dtype)

    # SUPER IMPORTANT we have to use setattr
    # otherwise loading is crazy slow
    if not isinstance(param, torch.nn.Parameter):
        param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point())
    setattr(module_to_tp, param_type, param)
    if tp_layer is not None:
        tp_layer.update_module_attributes(module_to_tp)
    return param
- def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
- """
- Verify the TP plan of the model, log a warning if the layers that were not sharded and the rules that were not applied.
- """
- if tp_plan is None:
- return
- generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
- unsharded_layers = set(generic_keys)
- unused_rules = tp_plan.copy()
- for key in generic_keys:
- param_name = key.rsplit(".", 1)[0] if "." in key else key
- generic_param_name = re.sub(r"\d+", "*", param_name)
- if generic_param_name in tp_plan:
- unused_rules.pop(generic_param_name, None)
- unsharded_layers.discard(key)
- elif "." in generic_param_name and (parent_param_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
- unused_rules.pop(parent_param_name, None)
- unsharded_layers.discard(key)
- if len(unused_rules) > 0:
- logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
- if len(unsharded_layers) > 0:
- logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")
def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
    """
    Distribute a model according to the TP plan.

    Steps:
    - Record `tp_size` and `device_mesh` on the model.
    - Store `distributed_config` on `model.config`; a dict is converted with
      `DistributedConfig.from_dict`.
    - Replace `model.tp_plan` when a dict plan is given.
    - Validate every style in the plan.
    - Attach the TP hooks to each module not already hooked.

    Raises:
        ValueError: if the plan uses a style that is not registered in `ALL_PARALLEL_STYLES`.

    Returns:
        The same `model`, modified in place.
    """
    model._tp_size = tp_size
    model._device_mesh = device_mesh
    if distributed_config is not None:
        if isinstance(distributed_config, dict):
            distributed_config = DistributedConfig.from_dict(distributed_config)
        model.config.distributed_config = distributed_config
    # Set the new requested tp_plan on the model
    if isinstance(tp_plan, dict):
        model.tp_plan = tp_plan
    model_plan = model.tp_plan
    if model_plan is not None and _torch_distributed_available:
        for v in model_plan.values():
            if v not in ALL_PARALLEL_STYLES:
                raise ValueError(f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}")
        for name, module in model.named_modules():
            if not getattr(module, "_is_hooked", False):
                plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model_plan, is_weight=False)
                add_tensor_parallel_hooks_to_module(
                    model=model,
                    module=module,
                    tp_plan=model_plan,
                    # Pass the real module name so diagnostics identify the failing layer.
                    layer_name=name,
                    current_module_plan=plan,
                    device_mesh=device_mesh,
                )
            module._is_hooked = True
    return model
|