mixer_seq_simple.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. # Copyright (c) 2023, Albert Gu, Tri Dao.
  2. import math
  3. from functools import partial
  4. import json
  5. import os
  6. import copy
  7. from collections import namedtuple
  8. import torch
  9. import torch.nn as nn
  10. from mamba_ssm.models.config_mamba import MambaConfig
  11. from mamba_ssm.modules.mamba_simple import Mamba
  12. from mamba_ssm.modules.mamba2 import Mamba2
  13. from mamba_ssm.modules.mha import MHA
  14. from mamba_ssm.modules.mlp import GatedMLP
  15. from mamba_ssm.modules.block import Block
  16. from mamba_ssm.utils.generation import GenerationMixin
  17. from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
  18. try:
  19. from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
  20. except ImportError:
  21. RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
  22. def create_block(
  23. d_model,
  24. d_intermediate,
  25. ssm_cfg=None,
  26. attn_layer_idx=None,
  27. attn_cfg=None,
  28. norm_epsilon=1e-5,
  29. rms_norm=False,
  30. residual_in_fp32=False,
  31. fused_add_norm=False,
  32. layer_idx=None,
  33. device=None,
  34. dtype=None,
  35. ):
  36. if ssm_cfg is None:
  37. ssm_cfg = {}
  38. if attn_layer_idx is None:
  39. attn_layer_idx = []
  40. if attn_cfg is None:
  41. attn_cfg = {}
  42. factory_kwargs = {"device": device, "dtype": dtype}
  43. if layer_idx not in attn_layer_idx:
  44. # Create a copy of the config to modify
  45. ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
  46. ssm_layer = ssm_cfg.pop("layer", "Mamba1")
  47. if ssm_layer not in ["Mamba1", "Mamba2"]:
  48. raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
  49. mixer_cls = partial(
  50. Mamba2 if ssm_layer == "Mamba2" else Mamba,
  51. layer_idx=layer_idx,
  52. **ssm_cfg,
  53. **factory_kwargs
  54. )
  55. else:
  56. mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
  57. norm_cls = partial(
  58. nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
  59. )
  60. if d_intermediate == 0:
  61. mlp_cls = nn.Identity
  62. else:
  63. mlp_cls = partial(
  64. GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
  65. )
  66. block = Block(
  67. d_model,
  68. mixer_cls,
  69. mlp_cls,
  70. norm_cls=norm_cls,
  71. fused_add_norm=fused_add_norm,
  72. residual_in_fp32=residual_in_fp32,
  73. )
  74. block.layer_idx = layer_idx
  75. return block
  76. # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
  77. def _init_weights(
  78. module,
  79. n_layer,
  80. initializer_range=0.02, # Now only used for embedding layer.
  81. rescale_prenorm_residual=True,
  82. n_residuals_per_layer=1, # Change to 2 if we have MLP
  83. ):
  84. if isinstance(module, nn.Linear):
  85. if module.bias is not None:
  86. if not getattr(module.bias, "_no_reinit", False):
  87. nn.init.zeros_(module.bias)
  88. elif isinstance(module, nn.Embedding):
  89. nn.init.normal_(module.weight, std=initializer_range)
  90. if rescale_prenorm_residual:
  91. # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
  92. # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
  93. # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
  94. # > -- GPT-2 :: https://openai.com/blog/better-language-models/
  95. #
  96. # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
  97. for name, p in module.named_parameters():
  98. if name in ["out_proj.weight", "fc2.weight"]:
  99. # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
  100. # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
  101. # We need to reinit p since this code could be called multiple times
  102. # Having just p *= scale would repeatedly scale it down
  103. nn.init.kaiming_uniform_(p, a=math.sqrt(5))
  104. with torch.no_grad():
  105. p /= math.sqrt(n_residuals_per_layer * n_layer)
  106. class MixerModel(nn.Module):
  107. def __init__(
  108. self,
  109. d_model: int,
  110. n_layer: int,
  111. d_intermediate: int,
  112. vocab_size: int,
  113. ssm_cfg=None,
  114. attn_layer_idx=None,
  115. attn_cfg=None,
  116. norm_epsilon: float = 1e-5,
  117. rms_norm: bool = False,
  118. initializer_cfg=None,
  119. fused_add_norm=False,
  120. residual_in_fp32=False,
  121. device=None,
  122. dtype=None,
  123. ) -> None:
  124. factory_kwargs = {"device": device, "dtype": dtype}
  125. super().__init__()
  126. self.residual_in_fp32 = residual_in_fp32
  127. self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
  128. # We change the order of residual and layer norm:
  129. # Instead of LN -> Attn / MLP -> Add, we do:
  130. # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
  131. # the main branch (output of MLP / Mixer). The model definition is unchanged.
  132. # This is for performance reason: we can fuse add + layer_norm.
  133. self.fused_add_norm = fused_add_norm
  134. if self.fused_add_norm:
  135. if layer_norm_fn is None or rms_norm_fn is None:
  136. raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
  137. self.layers = nn.ModuleList(
  138. [
  139. create_block(
  140. d_model,
  141. d_intermediate=d_intermediate,
  142. ssm_cfg=ssm_cfg,
  143. attn_layer_idx=attn_layer_idx,
  144. attn_cfg=attn_cfg,
  145. norm_epsilon=norm_epsilon,
  146. rms_norm=rms_norm,
  147. residual_in_fp32=residual_in_fp32,
  148. fused_add_norm=fused_add_norm,
  149. layer_idx=i,
  150. **factory_kwargs,
  151. )
  152. for i in range(n_layer)
  153. ]
  154. )
  155. self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
  156. d_model, eps=norm_epsilon, **factory_kwargs
  157. )
  158. self.apply(
  159. partial(
  160. _init_weights,
  161. n_layer=n_layer,
  162. **(initializer_cfg if initializer_cfg is not None else {}),
  163. n_residuals_per_layer=1 if d_intermediate == 0 else 2, # 2 if we have MLP
  164. )
  165. )
  166. def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
  167. return {
  168. i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
  169. for i, layer in enumerate(self.layers)
  170. }
  171. def forward(self, input_ids, inference_params=None, **mixer_kwargs):
  172. hidden_states = self.embedding(input_ids)
  173. residual = None
  174. for layer in self.layers:
  175. hidden_states, residual = layer(
  176. hidden_states, residual, inference_params=inference_params, **mixer_kwargs
  177. )
  178. if not self.fused_add_norm:
  179. residual = (hidden_states + residual) if residual is not None else hidden_states
  180. hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
  181. else:
  182. # Set prenorm=False here since we don't need the residual
  183. hidden_states = layer_norm_fn(
  184. hidden_states,
  185. self.norm_f.weight,
  186. self.norm_f.bias,
  187. eps=self.norm_f.eps,
  188. residual=residual,
  189. prenorm=False,
  190. residual_in_fp32=self.residual_in_fp32,
  191. is_rms_norm=isinstance(self.norm_f, RMSNorm)
  192. )
  193. return hidden_states
  194. class MambaLMHeadModel(nn.Module, GenerationMixin):
  195. def __init__(
  196. self,
  197. config: MambaConfig,
  198. initializer_cfg=None,
  199. device=None,
  200. dtype=None,
  201. ) -> None:
  202. self.config = config
  203. d_model = config.d_model
  204. n_layer = config.n_layer
  205. d_intermediate = config.d_intermediate
  206. vocab_size = config.vocab_size
  207. ssm_cfg = config.ssm_cfg
  208. attn_layer_idx = config.attn_layer_idx
  209. attn_cfg = config.attn_cfg
  210. rms_norm = config.rms_norm
  211. residual_in_fp32 = config.residual_in_fp32
  212. fused_add_norm = config.fused_add_norm
  213. pad_vocab_size_multiple = config.pad_vocab_size_multiple
  214. factory_kwargs = {"device": device, "dtype": dtype}
  215. super().__init__()
  216. if vocab_size % pad_vocab_size_multiple != 0:
  217. vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
  218. self.backbone = MixerModel(
  219. d_model=d_model,
  220. n_layer=n_layer,
  221. d_intermediate=d_intermediate,
  222. vocab_size=vocab_size,
  223. ssm_cfg=ssm_cfg,
  224. attn_layer_idx=attn_layer_idx,
  225. attn_cfg=attn_cfg,
  226. rms_norm=rms_norm,
  227. initializer_cfg=initializer_cfg,
  228. fused_add_norm=fused_add_norm,
  229. residual_in_fp32=residual_in_fp32,
  230. **factory_kwargs,
  231. )
  232. self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
  233. # Initialize weights and apply final processing
  234. self.apply(
  235. partial(
  236. _init_weights,
  237. n_layer=n_layer,
  238. **(initializer_cfg if initializer_cfg is not None else {}),
  239. )
  240. )
  241. self.tie_weights()
  242. def tie_weights(self):
  243. if self.config.tie_embeddings:
  244. self.lm_head.weight = self.backbone.embedding.weight
  245. def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
  246. return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
  247. def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
  248. """
  249. "position_ids" is just to be compatible with Transformer generation. We don't use it.
  250. num_last_tokens: if > 0, only return the logits for the last n tokens
  251. """
  252. hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
  253. if num_last_tokens > 0:
  254. hidden_states = hidden_states[:, -num_last_tokens:]
  255. lm_logits = self.lm_head(hidden_states)
  256. CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
  257. return CausalLMOutput(logits=lm_logits)
  258. @classmethod
  259. def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
  260. config_data = load_config_hf(pretrained_model_name)
  261. config = MambaConfig(**config_data)
  262. model = cls(config, device=device, dtype=dtype, **kwargs)
  263. model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
  264. return model
  265. def save_pretrained(self, save_directory):
  266. """
  267. Minimal implementation of save_pretrained for MambaLMHeadModel.
  268. Save the model and its configuration file to a directory.
  269. """
  270. # Ensure save_directory exists
  271. os.makedirs(save_directory, exist_ok=True)
  272. # Save the model's state_dict
  273. model_path = os.path.join(save_directory, 'pytorch_model.bin')
  274. torch.save(self.state_dict(), model_path)
  275. # Save the configuration of the model
  276. config_path = os.path.join(save_directory, 'config.json')
  277. with open(config_path, 'w') as f:
  278. json.dump(self.config.__dict__, f, indent=4)