modeling_dbrx.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/dbrx/modular_dbrx.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_dbrx.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Any, Optional
  22. import torch
  23. from torch import nn
  24. from ... import initialization as init
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_func_from_hub
  29. from ...masking_utils import create_causal_mask
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  32. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  36. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  37. from ...utils.output_capturing import capture_outputs
  38. from .configuration_dbrx import DbrxConfig
  class DbrxRotaryEmbedding(nn.Module):
      """Builds the rotary position embedding tables (cos, sin) for DBRX.

      The inverse frequencies are computed once in `__init__` and stored in two
      non-persistent buffers: `inv_freq` and a clone of it, `original_inv_freq`.
      `forward` turns position ids into a `(cos, sin)` pair cast to the dtype of `x`.
      """

      inv_freq: torch.Tensor  # fix linting for `register_buffer`

      def __init__(self, config: DbrxConfig, device=None):
          super().__init__()
          max_positions = config.max_position_embeddings
          self.max_seq_len_cached = max_positions
          self.original_max_seq_len = max_positions
          self.config = config

          # The "default" rope type uses the static method below; every other type is
          # looked up in ROPE_INIT_FUNCTIONS.
          self.rope_type = self.config.rope_parameters["rope_type"]
          if self.rope_type == "default":
              rope_init_fn: Callable = self.compute_default_rope_parameters
          else:
              rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

          frequencies, self.attention_scaling = rope_init_fn(self.config, device)
          self.register_buffer("inv_freq", frequencies, persistent=False)
          self.register_buffer("original_inv_freq", frequencies.clone(), persistent=False)

      @staticmethod
      def compute_default_rope_parameters(
          config: DbrxConfig | None = None,
          device: Optional["torch.device"] = None,
          seq_len: int | None = None,
      ) -> tuple["torch.Tensor", float]:
          """
          Computes the inverse frequencies according to the original RoPE implementation

          Args:
              config ([`~transformers.PreTrainedConfig`]):
                  The model configuration.
              device (`torch.device`):
                  The device to use for initialization of the inverse frequencies.
              seq_len (`int`, *optional*):
                  The current sequence length. Unused for this type of RoPE.

          Returns:
              Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
              post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
          """
          base = config.rope_parameters["rope_theta"]
          # Fall back to hidden_size // num_attention_heads when `head_dim` is missing or falsy.
          dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
          attention_factor = 1.0  # Unused in this type of RoPE

          # Exponents 0/dim, 2/dim, ... for every second channel index.
          exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
          inv_freq = 1.0 / (base**exponents)
          return inv_freq, attention_factor

      @torch.no_grad()
      @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
      def forward(self, x, position_ids):
          """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

          `x` is only used for its device and dtype.
          """
          batch = position_ids.shape[0]
          inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
          position_ids_expanded = position_ids[:, None, :].float()

          # Use "cpu" as the autocast device type when the device type is not a string or is "mps".
          usable_device_type = isinstance(x.device.type, str) and x.device.type != "mps"
          device_type = x.device.type if usable_device_type else "cpu"

          with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
              freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
              # Duplicate the frequencies along the last axis before taking cos/sin.
              emb = torch.cat((freqs, freqs), dim=-1)
              cos = emb.cos() * self.attention_scaling
              sin = emb.sin() * self.attention_scaling

          return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  def rotate_half(x):
      """Swap the two halves of the last dimension of `x`, negating the second half.

      The last dimension is split at `x.shape[-1] // 2`; the result is
      `concat(-second_half, first_half)` along that dimension.
      """
      midpoint = x.shape[-1] // 2
      first_half, second_half = x[..., :midpoint], x[..., midpoint:]
      return torch.cat((-second_half, first_half), dim=-1)
  @use_kernel_func_from_hub("rotary_pos_emb")
  def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
      """Rotate the query and key tensors with the given rotary cos/sin tables.

      Args:
          q (`torch.Tensor`): The query tensor.
          k (`torch.Tensor`): The key tensor.
          cos (`torch.Tensor`): The cosine part of the rotary embedding.
          sin (`torch.Tensor`): The sine part of the rotary embedding.
          unsqueeze_dim (`int`, *optional*, defaults to 1):
              Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
              For example, with `cos`/`sin` of shape [batch_size, seq_len, head_dim] and `q`/`k` of shape
              [batch_size, heads, seq_len, head_dim], use `unsqueeze_dim=1`; with `q`/`k` of shape
              [batch_size, seq_len, heads, head_dim], use `unsqueeze_dim=2`.

      Returns:
          `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
      """
      cos_b = cos.unsqueeze(unsqueeze_dim)
      sin_b = sin.unsqueeze(unsqueeze_dim)

      def _rotate(states):
          # x * cos + rotate_half(x) * sin, applied identically to queries and keys.
          return (states * cos_b) + (rotate_half(states) * sin_b)

      return _rotate(q), _rotate(k)
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
      """
      Repeat every key/value head `n_rep` times along the head dimension.

      Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep): the input goes from
      (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim).
      When `n_rep == 1` the input tensor itself is returned.
      """
      batch, kv_heads, seq_len, head_dim = hidden_states.shape
      if n_rep == 1:
          return hidden_states
      # Insert a repeat axis after the head axis, broadcast it to n_rep, then fold it into the head axis.
      expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
      return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
  def eager_attention_forward(
      module: nn.Module,
      query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
      attention_mask: torch.Tensor | None,
      scaling: float,
      dropout: float = 0.0,
      **kwargs: Unpack[TransformersKwargs],
  ):
      """Plain matmul/softmax attention.

      Keys and values are first repeated `module.num_key_value_groups` times via `repeat_kv`.
      The scores are scaled by `scaling`, the optional `attention_mask` is added, softmax is
      computed in float32 and cast back to the query dtype, and dropout is applied according
      to `module.training`.

      Returns:
          `(attn_output, attn_weights)`, where `attn_output` has dimensions 1 and 2 swapped
          relative to the raw matmul result and is made contiguous.
      """
      groups = module.num_key_value_groups
      expanded_keys = repeat_kv(key, groups)
      expanded_values = repeat_kv(value, groups)

      scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
      if attention_mask is not None:
          scores = scores + attention_mask

      # Softmax in float32, then back to the query dtype.
      attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
      attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

      attn_output = torch.matmul(attn_weights, expanded_values).transpose(1, 2).contiguous()
      return attn_output, attn_weights
  class DbrxAttention(nn.Module):
      """Modular DBRX attention component that can be reused across different model architectures.

      A single fused projection `Wqkv` produces queries, keys and values. Keys and values use
      `attn_config.kv_n_heads` heads; queries use `config.n_heads` heads. The fused projection
      output is optionally clipped to `[-clip_qkv, clip_qkv]` before being split.
      """

      def __init__(
          self,
          config,
          layer_idx: int | None = None,
          **kwargs,
      ):
          super().__init__()
          self.config = config
          self.hidden_size = config.d_model
          self.num_heads = config.n_heads
          self.head_dim = self.hidden_size // self.num_heads
          self.max_position_embeddings = config.max_seq_len
          self.layer_idx = layer_idx

          attn_config = config.attn_config
          self.attention_dropout = attn_config.attn_pdrop
          self.clip_qkv = attn_config.clip_qkv
          self.num_key_value_heads = attn_config.kv_n_heads
          self.num_key_value_groups = self.num_heads // self.num_key_value_heads
          self.scaling = self.head_dim**-0.5
          self.rope_theta = attn_config.rope_theta
          self.is_causal = True

          # Fused projection: hidden_size outputs for Q plus kv_heads * head_dim each for K and V.
          self.Wqkv = nn.Linear(
              self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False
          )
          self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor | None = None,
          position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
          past_key_values: Cache | None = None,
          **kwargs,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """Run attention over `hidden_states`.

          Args:
              hidden_states: Input activations; the last dimension is projected by `Wqkv`.
              attention_mask: Optional mask forwarded to the attention backend.
              position_embeddings: `(cos, sin)` pair passed to `apply_rotary_pos_emb`.
              past_key_values: Optional cache; when given, it is updated with this layer's
                  keys/values and the returned (cached) keys/values are used for attention.
              **kwargs: Forwarded unchanged to the attention backend.

          Returns:
              `(attn_output, attn_weights)` where `attn_output` has been passed through `out_proj`.
          """
          input_shape = hidden_states.shape[:-1]
          hidden_shape = (*input_shape, -1, self.head_dim)

          qkv_states = self.Wqkv(hidden_states)
          # `Tensor.clamp` raises when both bounds are None, so only clip when a bound is configured.
          if self.clip_qkv is not None:
              qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)

          kv_width = self.num_key_value_heads * self.head_dim
          query_states, key_states, value_states = qkv_states.split(
              [self.hidden_size, kv_width, kv_width],
              dim=2,
          )

          # Split the channel axis into (heads, head_dim) and move heads in front of the sequence axis.
          query_states = query_states.view(hidden_shape).transpose(1, 2)
          key_states = key_states.view(hidden_shape).transpose(1, 2)
          value_states = value_states.view(hidden_shape).transpose(1, 2)

          cos, sin = position_embeddings
          query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

          if past_key_values is not None:
              key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

          # Pick the configured attention backend, falling back to the eager implementation.
          attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
              self.config._attn_implementation, eager_attention_forward
          )

          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.attention_dropout,
              scaling=self.scaling,
              **kwargs,
          )

          # Merge (heads, head_dim) back into a single channel axis before the output projection.
          attn_output = attn_output.reshape(*input_shape, -1).contiguous()
          attn_output = self.out_proj(attn_output)
          return attn_output, attn_weights
  class DbrxExpertGLU(nn.Module):
      """Stacked expert weights plus the gated-linear-unit computation for one expert.

      `w1`, `v1` and `w2` each hold the weights of all experts in a single parameter of shape
      `(moe_num_experts * ffn_hidden_size, hidden_size)`. `forward` does not read them; the
      caller passes the slice belonging to a single expert.
      """

      def __init__(self, config):
          super().__init__()
          self.hidden_size = config.hidden_size
          self.ffn_hidden_size = config.ffn_hidden_size
          self.moe_num_experts = config.moe_num_experts

          stacked_shape = (self.moe_num_experts * self.ffn_hidden_size, self.hidden_size)
          self.w1 = nn.Parameter(torch.empty(*stacked_shape))
          self.v1 = nn.Parameter(torch.empty(*stacked_shape))
          self.w2 = nn.Parameter(torch.empty(*stacked_shape))

          # Activation is looked up by name; "silu" is used when the config gives no name.
          self.activation_fn = ACT2FN[config.ffn_act_fn.get("name", "silu")]

      def forward(
          self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor
      ) -> torch.Tensor:
          """Compute `(act(x @ expert_w1) * (x @ expert_v1)) @ expert_w2.T` for one expert."""
          gate = x.matmul(expert_w1)
          up = x.matmul(expert_v1)
          gated = self.activation_fn(gate) * up
          return gated.matmul(expert_w2.t())
  class DbrxExperts(nn.Module):
      """Dispatches tokens to their selected experts and sums the weighted expert outputs."""

      def __init__(self, config):
          super().__init__()
          self.mlp = DbrxExpertGLU(config)
          self.hidden_size = config.hidden_size
          self.ffn_hidden_size = config.ffn_hidden_size
          self.num_experts = config.moe_num_experts

      def forward(
          self,
          hidden_states: torch.Tensor,
          top_k_index: torch.Tensor,
          top_k_weights: torch.Tensor,
      ) -> torch.Tensor:
          """Apply the selected experts to every token and accumulate the weighted results.

          Args:
              hidden_states: Input activations; flattened to rows whose width is `self.ffn_hidden_size`.
                  NOTE(review): the row width is taken from `ffn_hidden_size` of the sub-config — confirm
                  this matches the width of the incoming activations.
              top_k_index: Expert indices selected per token (one row per flattened token).
              top_k_weights: Routing weights aligned with `top_k_index`.

          Returns:
              Tensor reshaped to `(batch_size, -1, self.ffn_hidden_size)`.
          """
          batch_size = hidden_states.shape[0]
          flat_states = hidden_states.reshape(-1, self.ffn_hidden_size)
          accumulated = torch.zeros_like(flat_states)

          with torch.no_grad():
              # After the permute, the mask is indexed as [expert, top-k slot, token].
              expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
              expert_mask = expert_mask.permute(2, 1, 0)
              # Only visit experts that received at least one token.
              active_experts = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

          per_expert_shape = (-1, self.ffn_hidden_size, self.hidden_size)
          for hit in active_experts:
              expert_idx = hit[0]
              with torch.no_grad():
                  slot_idx, token_idx = torch.where(expert_mask[expert_idx])

              # Slice this expert's weights out of the stacked parameters.
              w1, v1, w2 = (
                  stacked.view(per_expert_shape)[expert_idx]
                  for stacked in (self.mlp.w1, self.mlp.v1, self.mlp.w2)
              )

              expert_out = self.mlp(flat_states[token_idx], w1, v1, w2)
              weighted = expert_out.view(-1, self.ffn_hidden_size) * top_k_weights[token_idx, slot_idx, None]
              accumulated.index_add_(0, token_idx, weighted)

          return accumulated.view(batch_size, -1, self.ffn_hidden_size)
  class DbrxRouter(nn.Module):
      """Linear router producing one logit per expert for every token.

      During training, when `moe_jitter_eps` is not None, the router input is multiplied by
      uniform noise drawn from `[1 - moe_jitter_eps, 1 + moe_jitter_eps]`.
      """

      def __init__(self, config):
          super().__init__()
          self.hidden_size = config.ffn_hidden_size
          self.moe_jitter_eps = config.moe_jitter_eps
          self.layer = nn.Linear(self.hidden_size, config.moe_num_experts, bias=False)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Return router logits of shape `(num_tokens, moe_num_experts)`.

          Args:
              hidden_states: Input activations; all leading dimensions are flattened into a
                  single token dimension. The caller's tensor is never modified.

          Returns:
              The output of the linear router layer on the flattened (and, in training,
              jittered) input.
          """
          if self.training and self.moe_jitter_eps is not None:
              jitter = torch.empty_like(hidden_states).uniform_(
                  1.0 - self.moe_jitter_eps, 1.0 + self.moe_jitter_eps
              )
              # Out-of-place multiply: an in-place `*=` would overwrite the caller's tensor
              # (which is also fed to the experts) and fails on leaf tensors that require grad.
              hidden_states = hidden_states * jitter
          # `reshape` also accepts non-contiguous inputs, unlike `view`.
          flat_states = hidden_states.reshape(-1, hidden_states.shape[-1])
          return self.layer(flat_states)
  class DbrxFFN(nn.Module):
      """Modular DBRX MLP/FFN component with MoE support.

      A `DbrxRouter` scores every token against every expert, the top-k experts per token are
      selected, and `DbrxExperts` combines the selected experts' outputs.
      """

      def __init__(self, config, **kwargs):
          super().__init__()
          self.router = DbrxRouter(config.ffn_config)
          self.experts = DbrxExperts(config.ffn_config)

          self.moe_normalize_expert_weights = config.ffn_config.moe_normalize_expert_weights
          self.top_k = config.ffn_config.moe_top_k

      def route_tokens_to_experts(self, router_logits):
          """Select the top-k experts per token and compute their routing weights.

          Args:
              router_logits: Tensor whose last dimension indexes experts.

          Returns:
              `(router_top_value, router_indices)`: the top-k softmax probabilities (divided by
              their p-norm over the k selected experts when `moe_normalize_expert_weights` is
              not None, with p equal to that value) and the indices of the selected experts.
          """
          router_probs = torch.nn.functional.softmax(router_logits, dim=-1, dtype=router_logits.dtype)
          router_top_value, router_indices = torch.topk(router_probs, self.top_k, dim=-1)
          if self.moe_normalize_expert_weights is not None:
              # `torch.norm` is deprecated; `torch.linalg.vector_norm` is its documented replacement.
              router_top_value = router_top_value / torch.linalg.vector_norm(
                  router_top_value, ord=self.moe_normalize_expert_weights, dim=-1, keepdim=True
              )
          return router_top_value, router_indices

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Route `hidden_states` through the router and the selected experts; returns the experts' output."""
          router_logits = self.router(hidden_states)
          top_k_weights, top_k_index = self.route_tokens_to_experts(router_logits)
          output = self.experts(hidden_states, top_k_index, top_k_weights)
          return output
  class DbrxNormAttentionNorm(nn.Module):
      """LayerNorm -> attention -> dropout -> residual add, followed by a second LayerNorm.

      `forward` returns both the post-attention residual stream and its normalized version.
      """

      def __init__(self, config: DbrxConfig, layer_idx: int | None = None):
          super().__init__()
          self.layer_idx = layer_idx
          self.resid_pdrop = config.resid_pdrop
          self.norm_1 = nn.LayerNorm(config.d_model, bias=False)
          self.attn = DbrxAttention(config=config, layer_idx=layer_idx)
          self.norm_2 = nn.LayerNorm(config.d_model, bias=False)

      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: torch.LongTensor,
          attention_mask: torch.Tensor | None = None,
          past_key_values: Cache | None = None,
          **kwargs: Any,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """Return `(residual_states, hidden_states)`.

          `residual_states` is the input plus the (dropout-regularized) attention output;
          `hidden_states` is `norm_2` applied to that sum. The attention weights returned by
          the attention module are discarded.
          """
          skip = hidden_states
          # Normalize, casting the result back to the dtype of the incoming activations.
          normed = self.norm_1(hidden_states).to(hidden_states.dtype)

          attn_out, _ = self.attn(
              hidden_states=normed,
              attention_mask=attention_mask,
              position_embeddings=position_embeddings,
              past_key_values=past_key_values,
              **kwargs,
          )
          attn_out = nn.functional.dropout(attn_out, p=self.resid_pdrop, training=self.training)

          residual_states = attn_out + skip
          hidden_states = self.norm_2(residual_states).to(residual_states.dtype)
          return residual_states, hidden_states
  class DbrxBlock(GradientCheckpointingLayer):
      """One DBRX decoder layer: norm/attention/norm followed by the FFN and a residual add."""

      def __init__(self, config: DbrxConfig, layer_idx: int):
          super().__init__()
          self.hidden_size = config.d_model
          self.resid_pdrop = config.resid_pdrop
          self.layer_idx = layer_idx
          self.norm_attn_norm = DbrxNormAttentionNorm(config=config, layer_idx=layer_idx)
          self.ffn = DbrxFFN(config=config)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor | None = None,
          position_embeddings: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          **kwargs: Any,
      ):
          """Run the layer and return the updated hidden states.

          All arguments other than `hidden_states` are forwarded to `norm_attn_norm`.
          """
          residual, normed = self.norm_attn_norm(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
              position_embeddings=position_embeddings,
              past_key_values=past_key_values,
              **kwargs,
          )

          ffn_out = self.ffn(normed)
          ffn_out = nn.functional.dropout(ffn_out, p=self.resid_pdrop, training=self.training)
          return residual + ffn_out
  class DbrxPreTrainedModel(PreTrainedModel):
      """Base class holding the DBRX-specific class-level settings and weight initialization."""

      config: DbrxConfig
      base_model_prefix = "transformer"
      supports_gradient_checkpointing = True
      _no_split_modules = ["DbrxBlock"]
      _skip_keys_device_placement = ["past_key_values"]
      _supports_flex_attn = True
      _supports_attention_backend = True
      _supports_flash_attn = True
      _supports_sdpa = True
      _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
      _can_record_outputs = {
          "hidden_states": DbrxBlock,
          "attentions": DbrxAttention,
      }

      @torch.no_grad()
      def _init_weights(self, module: nn.Module):
          """Run the parent initialization, then draw the stacked expert weights from N(0, initializer_range)."""
          super()._init_weights(module)
          std = self.config.initializer_range
          if not isinstance(module, DbrxExpertGLU):
              return
          for expert_weight in (module.w1, module.v1, module.w2):
              init.normal_(expert_weight, mean=0.0, std=std)
  @auto_docstring
  class DbrxModel(DbrxPreTrainedModel):
      """Transformer decoder consisting of *config.num_hidden_layers*. Each layer is a [`DbrxBlock`] layer.

      Args:
          config ([`DbrxConfig`]): Model configuration class with all parameters of the model.
              Initializing with a config file does not load the weights associated with the model, only the
              configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
      """

      def __init__(self, config: DbrxConfig):
          super().__init__(config)
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size
          self.emb_pdrop = config.emb_pdrop
          self.rotary_emb = DbrxRotaryEmbedding(config)

          self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
          self.blocks = nn.ModuleList([DbrxBlock(config, layer_idx) for layer_idx in range(config.n_layers)])
          self.norm_f = nn.LayerNorm(config.d_model, bias=False)
          self.gradient_checkpointing = False

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self) -> nn.Embedding:
          return self.wte

      def set_input_embeddings(self, value: nn.Embedding):
          self.wte = value

      @merge_with_config_defaults
      @capture_outputs
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> MoeModelOutputWithPast:
          # Exactly one of the two inputs must be given: reject both-missing and both-present.
          if (input_ids is None) == (inputs_embeds is None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)

          if inputs_embeds is None:
              inputs_embeds = self.wte(input_ids)

          if position_ids is None:
              # Continue counting positions after whatever the cache already holds.
              offset = 0 if past_key_values is None else past_key_values.get_seq_length()
              position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + offset
              position_ids = position_ids.unsqueeze(0)

          causal_mask = create_causal_mask(
              config=self.config,
              inputs_embeds=inputs_embeds,
              attention_mask=attention_mask,
              past_key_values=past_key_values,
              position_ids=position_ids,
          )

          hidden_states = inputs_embeds

          # create position embeddings to be shared across the decoder layers
          position_embeddings = self.rotary_emb(hidden_states, position_ids)

          active_layers = self.blocks[: self.config.num_hidden_layers]
          for block in active_layers:
              hidden_states = block(
                  hidden_states,
                  position_embeddings=position_embeddings,
                  attention_mask=causal_mask,
                  position_ids=position_ids,
                  past_key_values=past_key_values,
                  use_cache=use_cache,
                  **kwargs,
              )

          hidden_states = self.norm_f(hidden_states)

          return MoeModelOutputWithPast(  # only diff with Mistral is the output type, we need MoE
              last_hidden_state=hidden_states,
              past_key_values=past_key_values,
          )
  467. def load_balancing_loss_func(
  468. gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
  469. num_experts: int | None = None,
  470. top_k=2,
  471. attention_mask: torch.Tensor | None = None,
  472. ) -> torch.Tensor | int:
  473. r"""
  474. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  475. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  476. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  477. experts is too unbalanced.
  478. Args:
  479. gate_logits:
  480. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  481. shape [batch_size X sequence_length, num_experts].
  482. num_experts:
  483. Number of experts
  484. top_k:
  485. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  486. parameter.
  487. attention_mask (`torch.Tensor`, *optional*):
  488. The attention_mask used in forward function
  489. shape [batch_size X sequence_length] if not None.
  490. Returns:
  491. The auxiliary loss.
  492. """
  493. if gate_logits is None or not isinstance(gate_logits, tuple):
  494. return 0
  495. if isinstance(gate_logits, tuple):
  496. compute_device = gate_logits[0].device
  497. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  498. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  499. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  500. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  501. if attention_mask is None:
  502. # Compute the percentage of tokens routed to each experts
  503. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  504. # Compute the average probability of routing to these experts
  505. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  506. else:
  507. batch_size, sequence_length = attention_mask.shape
  508. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  509. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  510. expert_attention_mask = (
  511. attention_mask[None, :, :, None, None]
  512. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  513. .reshape(-1, top_k, num_experts)
  514. .to(compute_device)
  515. )
  516. # Compute the percentage of tokens routed to each experts
  517. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  518. expert_attention_mask, dim=0
  519. )
  520. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  521. router_per_expert_attention_mask = (
  522. attention_mask[None, :, :, None]
  523. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  524. .reshape(-1, num_experts)
  525. .to(compute_device)
  526. )
  527. # Compute the average probability of routing to these experts
  528. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  529. router_per_expert_attention_mask, dim=0
  530. )
  531. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
  532. return overall_loss * num_experts
  class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin):
      # The LM head weight is tied to the input embedding of the wrapped `DbrxModel`.
      _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
      _tp_plan = {"lm_head": "colwise_gather_output"}
      _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

      def __init__(self, config: DbrxConfig):
          super().__init__(config)
          self.transformer = DbrxModel(config)
          self.vocab_size = config.vocab_size
          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
          self.router_aux_loss_coef = config.ffn_config.moe_loss_weight
          self.num_experts = config.ffn_config.moe_num_experts
          self.num_experts_per_tok = config.ffn_config.moe_top_k

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self) -> nn.Embedding:
          return self.transformer.get_input_embeddings()

      def set_input_embeddings(self, value: nn.Embedding):
          self.transformer.set_input_embeddings(value)

      def get_output_embeddings(self) -> nn.Linear:
          return self.lm_head

      def set_output_embeddings(self, new_embeddings: nn.Linear):
          self.lm_head = new_embeddings

      def set_decoder(self, decoder: DbrxModel):
          self.transformer = decoder

      def get_decoder(self) -> DbrxModel:
          return self.transformer

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          use_cache: bool | None = None,
          output_router_logits: bool | None = None,
          logits_to_keep: int | torch.Tensor = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> MoeCausalLMOutputWithPast:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

          Example:

          ```python
          >>> from transformers import AutoTokenizer, DbrxForCausalLM

          >>> model = DbrxForCausalLM.from_pretrained("transformers-community/dbrx-instruct")
          >>> tokenizer = AutoTokenizer.from_pretrained("transformers-community/dbrx-instruct")

          >>> prompt = "Hey, are you conscious? Can you talk to me?"
          >>> inputs = tokenizer(prompt, return_tensors="pt")

          >>> # Generate
          >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
          >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
          "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
          ```
          """
          # An explicit argument wins; otherwise fall back to the config default.
          output_router_logits = (
              output_router_logits if output_router_logits is not None else self.config.output_router_logits
          )

          # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
          outputs: MoeModelOutputWithPast = self.transformer(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              output_router_logits=output_router_logits,
              **kwargs,
          )

          hidden_states = outputs.last_hidden_state
          # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
          slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
          logits = self.lm_head(hidden_states[:, slice_indices, :])

          loss = None
          if labels is not None:
              loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

          aux_loss = None
          if output_router_logits:
              aux_loss = load_balancing_loss_func(
                  outputs.router_logits,
                  self.num_experts,
                  self.num_experts_per_tok,
                  attention_mask,
              )
              if labels is not None:
                  # `load_balancing_loss_func` returns the plain int 0 when no router logits are
                  # available; an int has no `.to()`, so only tensors are moved to the loss device.
                  aux_term = aux_loss.to(loss.device) if isinstance(aux_loss, torch.Tensor) else aux_loss
                  loss = loss + self.router_aux_loss_coef * aux_term

          return MoeCausalLMOutputWithPast(
              loss=loss,
              aux_loss=aux_loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
              router_logits=outputs.router_logits,
          )
  # Public names exported by this module.
  __all__ = [
      "DbrxForCausalLM",
      "DbrxModel",
      "DbrxPreTrainedModel",
  ]