tensor_parallel.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. # Copyright (c) 2024, Tri Dao.
  2. # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from torch import Tensor
  8. from torch.distributed import ProcessGroup
  9. from mamba_ssm.utils.torch import custom_bwd, custom_fwd
  10. from einops import rearrange
  11. from mamba_ssm.distributed.distributed_utils import (
  12. all_gather_raw,
  13. all_reduce,
  14. all_reduce_raw,
  15. reduce_scatter,
  16. reduce_scatter_raw,
  17. )
  18. class ParallelLinearFunc(torch.autograd.Function):
  19. @staticmethod
  20. @custom_fwd
  21. def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
  22. """
  23. If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
  24. with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
  25. """
  26. ctx.compute_weight_gradient = weight.requires_grad
  27. ctx.process_group = process_group
  28. ctx.sequence_parallel = sequence_parallel
  29. if torch.is_autocast_enabled():
  30. x = x.to(dtype=torch.get_autocast_gpu_dtype())
  31. x = x.contiguous()
  32. if process_group is not None and sequence_parallel:
  33. # We want to kick off the all_gather early, before weight dtype conversion
  34. total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
  35. else:
  36. total_x = x
  37. if torch.is_autocast_enabled():
  38. weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
  39. bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
  40. weight = weight.contiguous()
  41. if process_group is not None and sequence_parallel:
  42. handle_x.wait()
  43. batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
  44. batch_dim = batch_shape.numel()
  45. # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
  46. output = F.linear(total_x, weight, bias)
  47. if ctx.compute_weight_gradient:
  48. ctx.save_for_backward(x, weight)
  49. else:
  50. ctx.save_for_backward(weight)
  51. return output
  52. @staticmethod
  53. @custom_bwd
  54. def backward(ctx, grad_output):
  55. grad_output = grad_output.contiguous()
  56. process_group = ctx.process_group
  57. sequence_parallel = ctx.sequence_parallel
  58. if ctx.compute_weight_gradient:
  59. x, weight = ctx.saved_tensors
  60. if process_group is not None and sequence_parallel:
  61. total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
  62. else:
  63. total_x = x
  64. else:
  65. (weight,) = ctx.saved_tensors
  66. total_x = None
  67. batch_shape = grad_output.shape[:-1]
  68. batch_dim = batch_shape.numel()
  69. grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
  70. if ctx.needs_input_grad[0]:
  71. grad_input = F.linear(grad_output, weight.t())
  72. grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
  73. if process_group is not None:
  74. reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
  75. grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
  76. else:
  77. grad_input = None
  78. if ctx.needs_input_grad[1]:
  79. assert ctx.compute_weight_gradient
  80. if process_group is not None and sequence_parallel:
  81. handle_x.wait()
  82. grad_weight = torch.einsum(
  83. "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
  84. )
  85. else:
  86. grad_weight = None
  87. grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
  88. if process_group is not None and ctx.needs_input_grad[0]:
  89. handle_grad_input.wait()
  90. return grad_input, grad_weight, grad_bias, None, None
  91. def parallel_linear_func(
  92. x: Tensor,
  93. weight: Tensor,
  94. bias: Optional[Tensor] = None,
  95. process_group: Optional[ProcessGroup] = None,
  96. sequence_parallel: bool = True,
  97. ):
  98. return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
  99. class ColumnParallelLinear(nn.Linear):
  100. def __init__(
  101. self,
  102. in_features: int,
  103. out_features: int,
  104. process_group: ProcessGroup,
  105. bias: bool = True,
  106. sequence_parallel=True,
  107. multiple_of=1,
  108. device=None,
  109. dtype=None,
  110. ) -> None:
  111. world_size = torch.distributed.get_world_size(process_group)
  112. if out_features % multiple_of:
  113. raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
  114. multiple = out_features // multiple_of
  115. # We want to split @multiple across world_size, but it could be an uneven split
  116. div = multiple // world_size
  117. mod = multiple % world_size
  118. # The first @mod ranks get @div + 1 copies, the rest get @div copies
  119. local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
  120. super().__init__(
  121. in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
  122. )
  123. self.process_group = process_group
  124. self.sequence_parallel = sequence_parallel
  125. def forward(self, x):
  126. # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
  127. # we do an all_gather of x before doing the matmul.
  128. # If not, then the input is already gathered.
  129. return parallel_linear_func(
  130. x,
  131. self.weight,
  132. self.bias,
  133. process_group=self.process_group,
  134. sequence_parallel=self.sequence_parallel,
  135. )
  136. class RowParallelLinear(nn.Linear):
  137. def __init__(
  138. self,
  139. in_features: int,
  140. out_features: int,
  141. process_group: ProcessGroup,
  142. bias: bool = True,
  143. sequence_parallel=True,
  144. multiple_of=1,
  145. device=None,
  146. dtype=None,
  147. ) -> None:
  148. world_size = torch.distributed.get_world_size(process_group)
  149. rank = torch.distributed.get_rank(process_group)
  150. if in_features % multiple_of:
  151. raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
  152. multiple = in_features // multiple_of
  153. # We want to split @multiple across world_size, but it could be an uneven split
  154. div = multiple // world_size
  155. mod = multiple % world_size
  156. # The first @mod ranks get @div + 1 copies, the rest get @div copies
  157. local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
  158. # Only rank 0 will have bias
  159. super().__init__(
  160. local_multiple * multiple_of,
  161. out_features,
  162. bias=bias and rank == 0,
  163. device=device,
  164. dtype=dtype,
  165. )
  166. self.process_group = process_group
  167. self.sequence_parallel = sequence_parallel
  168. def forward(self, x):
  169. """
  170. We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
  171. a reduce_scatter of the result.
  172. """
  173. out = parallel_linear_func(x, self.weight, self.bias)
  174. reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
  175. return reduce_fn(out, self.process_group)
  176. class VocabParallelEmbedding(nn.Embedding):
  177. def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
  178. self.process_group = process_group
  179. if process_group is not None:
  180. world_size = torch.distributed.get_world_size(process_group)
  181. if num_embeddings % world_size != 0:
  182. raise ValueError(
  183. f"num_embeddings ({num_embeddings}) must be divisible by "
  184. f"world_size ({world_size})"
  185. )
  186. if world_size > 1 and padding_idx is not None:
  187. raise RuntimeError("ParallelEmbedding does not support padding_idx")
  188. else:
  189. world_size = 1
  190. super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
  191. def forward(self, input: Tensor) -> Tensor:
  192. if self.process_group is None:
  193. return super().forward(input)
  194. else:
  195. rank = torch.distributed.get_rank(self.process_group)
  196. vocab_size = self.num_embeddings
  197. vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
  198. # Create a mask of valid vocab ids (1 means it needs to be masked).
  199. input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
  200. input = input - vocab_start_index
  201. input[input_ids_mask] = 0
  202. embeddings = super().forward(input)
  203. embeddings[input_ids_mask] = 0.0
  204. return embeddings
  205. class ColumnParallelEmbedding(nn.Embedding):
  206. def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
  207. self.process_group = process_group
  208. if process_group is not None:
  209. world_size = torch.distributed.get_world_size(process_group)
  210. if embedding_dim % world_size != 0:
  211. raise ValueError(
  212. f"embedding_dim ({embedding_dim}) must be divisible by "
  213. f"world_size ({world_size})"
  214. )
  215. else:
  216. world_size = 1
  217. super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
  218. class ParallelEmbeddings(nn.Module):
  219. def __init__(
  220. self,
  221. embed_dim,
  222. vocab_size,
  223. max_position_embeddings,
  224. process_group,
  225. padding_idx=None,
  226. sequence_parallel=True,
  227. device=None,
  228. dtype=None,
  229. ):
  230. """
  231. If max_position_embeddings <= 0, there's no position embeddings
  232. """
  233. factory_kwargs = {"device": device, "dtype": dtype}
  234. super().__init__()
  235. self.process_group = process_group
  236. self.sequence_parallel = sequence_parallel
  237. self.word_embeddings = VocabParallelEmbedding(
  238. vocab_size,
  239. embed_dim,
  240. padding_idx=padding_idx,
  241. process_group=process_group,
  242. **factory_kwargs,
  243. )
  244. self.max_position_embeddings = max_position_embeddings
  245. if self.max_position_embeddings > 0:
  246. self.position_embeddings = ColumnParallelEmbedding(
  247. max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
  248. )
  249. def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
  250. """
  251. input_ids: (batch, seqlen)
  252. position_ids: (batch, seqlen)
  253. """
  254. batch_size, seqlen = input_ids.shape
  255. world_size = torch.distributed.get_world_size(self.process_group)
  256. embeddings = self.word_embeddings(input_ids)
  257. if self.max_position_embeddings > 0:
  258. if position_ids is None:
  259. position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
  260. position_embeddings = self.position_embeddings(position_ids)
  261. if world_size <= 1:
  262. embeddings = embeddings + position_embeddings
  263. else:
  264. partition_dim = self.position_embeddings.embedding_dim
  265. rank = torch.distributed.get_rank(self.process_group)
  266. embeddings[
  267. ..., rank * partition_dim : (rank + 1) * partition_dim
  268. ] += position_embeddings
  269. if combine_batch_seqlen_dim:
  270. embeddings = rearrange(embeddings, "b s d -> (b s) d")
  271. reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
  272. return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)