modular_dbrx.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. # Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Modular components for DBRX model."""
  15. from collections.abc import Callable
  16. from typing import Any
  17. import torch
  18. from torch import nn
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...generation import GenerationMixin
  23. from ...masking_utils import create_causal_mask
  24. from ...modeling_layers import (
  25. GradientCheckpointingLayer,
  26. )
  27. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  28. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  29. from ...processing_utils import Unpack
  30. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  31. from ...utils.generic import merge_with_config_defaults
  32. from ...utils.output_capturing import capture_outputs
  33. from ..llama.modeling_llama import (
  34. LlamaRotaryEmbedding,
  35. apply_rotary_pos_emb,
  36. eager_attention_forward,
  37. )
  38. from ..mixtral.modeling_mixtral import load_balancing_loss_func
  39. from .configuration_dbrx import DbrxConfig
  class DbrxRotaryEmbedding(LlamaRotaryEmbedding):
      """Rotary position embeddings for DBRX.

      Behaves exactly like `LlamaRotaryEmbedding`; the subclass only exposes it under a DBRX-specific name.
      """
  class DbrxAttention(nn.Module):
      """Modular DBRX attention component that can be reused across different model architectures.

      Queries, keys and values come from one fused `Wqkv` projection. That projection is optionally clipped to
      `[-clip_qkv, clip_qkv]`, rotary embeddings are applied to queries and keys, and the attention kernel is
      looked up in `ALL_ATTENTION_FUNCTIONS` by `config._attn_implementation` (falling back to
      `eager_attention_forward`). Keys/values use `attn_config.kv_n_heads` heads, which may be fewer than
      `n_heads` (`num_key_value_groups` query heads share each key/value head).

      Args:
          config: Top-level DBRX config. Reads `d_model`, `n_heads`, `max_seq_len` and the nested `attn_config`
              (`attn_pdrop`, `clip_qkv`, `kv_n_heads`, `rope_theta`).
          layer_idx: Index of the owning decoder layer; passed to `past_key_values.update`.
      """

      def __init__(
          self,
          config,
          layer_idx: int | None = None,
          **kwargs,
      ):
          super().__init__()
          self.config = config
          self.hidden_size = config.d_model
          self.num_heads = config.n_heads
          self.head_dim = self.hidden_size // self.num_heads
          self.max_position_embeddings = config.max_seq_len
          self.layer_idx = layer_idx

          attn_config = config.attn_config
          self.attention_dropout = attn_config.attn_pdrop
          self.clip_qkv = attn_config.clip_qkv
          self.num_key_value_heads = attn_config.kv_n_heads
          self.num_key_value_groups = self.num_heads // self.num_key_value_heads
          self.scaling = self.head_dim**-0.5
          self.rope_theta = attn_config.rope_theta
          self.is_causal = True

          # Fused projection: `hidden_size` query features, then keys and values for the KV heads.
          self.Wqkv = nn.Linear(
              self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False
          )
          self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor | None = None,
          position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
          past_key_values: Cache | None = None,
          **kwargs,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """Run self-attention over `hidden_states`.

          Args:
              hidden_states: Input activations whose last dim is `d_model`. The split along dim 2 and the
                  `transpose(1, 2)` below require a 3-D `(batch, seq_len, d_model)` layout.
              attention_mask: Forwarded unchanged to the attention kernel.
              position_embeddings: `(cos, sin)` pair for `apply_rotary_pos_emb`. Required despite the `None`
                  default, because it is unpacked unconditionally.
              past_key_values: Optional cache. When given, the new keys/values are passed through
                  `past_key_values.update` for `self.layer_idx` and its result is attended over.
              **kwargs: Forwarded to the attention kernel.

          Returns:
              `(attn_output, attn_weights)`: the attention output after `out_proj`, and whatever weights the
              selected kernel returns.
          """
          input_shape = hidden_states.shape[:-1]
          hidden_shape = (*input_shape, -1, self.head_dim)

          qkv_states = self.Wqkv(hidden_states)
          # `Tensor.clamp` raises when both bounds are None, so only clip when `clip_qkv` is configured.
          if self.clip_qkv is not None:
              qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)

          kv_width = self.num_key_value_heads * self.head_dim
          query_states, key_states, value_states = qkv_states.split([self.hidden_size, kv_width, kv_width], dim=2)

          # Split features into heads and move the head axis in front of the sequence axis.
          query_states = query_states.view(hidden_shape).transpose(1, 2)
          key_states = key_states.view(hidden_shape).transpose(1, 2)
          value_states = value_states.view(hidden_shape).transpose(1, 2)

          cos, sin = position_embeddings
          query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

          if past_key_values is not None:
              key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

          attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
              self.config._attn_implementation, eager_attention_forward
          )
          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.attention_dropout,
              scaling=self.scaling,
              **kwargs,
          )

          attn_output = attn_output.reshape(*input_shape, -1).contiguous()
          attn_output = self.out_proj(attn_output)
          return attn_output, attn_weights
  class DbrxExpertGLU(nn.Module):
      """Gated linear unit whose weights hold every DBRX expert stacked along dim 0.

      `w1`, `v1` and `w2` each have shape `(moe_num_experts * ffn_hidden_size, hidden_size)`. `forward` does not
      slice them itself; the caller passes the per-expert matrices in explicitly.
      """

      def __init__(self, config):
          super().__init__()
          self.hidden_size = config.hidden_size
          self.ffn_hidden_size = config.ffn_hidden_size
          self.moe_num_experts = config.moe_num_experts

          stacked_shape = (self.moe_num_experts * self.ffn_hidden_size, self.hidden_size)
          self.w1 = nn.Parameter(torch.empty(stacked_shape))
          self.v1 = nn.Parameter(torch.empty(stacked_shape))
          self.w2 = nn.Parameter(torch.empty(stacked_shape))

          self.activation_fn = ACT2FN[config.ffn_act_fn.get("name", "silu")]

      def forward(
          self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor
      ) -> torch.Tensor:
          """Compute `act(x @ expert_w1) * (x @ expert_v1) @ expert_w2.T`."""
          gated = self.activation_fn(x @ expert_w1)
          return (gated * (x @ expert_v1)) @ expert_w2.t()
  class DbrxExperts(nn.Module):
      """Send each token to its selected experts and sum their weighted outputs.

      The expert weights live in `self.mlp` (a `DbrxExpertGLU`) as stacked matrices. Each active expert's slice
      is taken through a `(-1, ffn_hidden_size, hidden_size)` view.

      NOTE(review): token features are reshaped to `ffn_hidden_size` columns and the result is viewed back to
      that width, so this assumes the incoming feature size equals `ffn_hidden_size`. Confirm against the config.
      """

      def __init__(self, config):
          super().__init__()
          self.mlp = DbrxExpertGLU(config)
          self.hidden_size = config.hidden_size
          self.ffn_hidden_size = config.ffn_hidden_size
          self.num_experts = config.moe_num_experts

      def forward(
          self,
          hidden_states: torch.Tensor,
          top_k_index: torch.Tensor,
          top_k_weights: torch.Tensor,
      ) -> torch.Tensor:
          """Mix each token's top-k expert outputs.

          Args:
              hidden_states: Token features. Dim 0 is kept as the leading dim of the result.
              top_k_index: `(num_tokens, top_k)` expert ids per flattened token.
              top_k_weights: `(num_tokens, top_k)` weight of each selected expert.

          Returns:
              The mixed outputs, viewed as `(batch_size, -1, ffn_hidden_size)`.
          """
          batch_size = hidden_states.shape[0]
          flat_states = hidden_states.reshape(-1, self.ffn_hidden_size)
          combined = torch.zeros_like(flat_states)

          # Routing bookkeeping needs no gradient.
          with torch.no_grad():
              # one_hot -> (num_tokens, top_k, num_experts); permute -> (num_experts, top_k, num_tokens).
              routing_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
              # Experts that received at least one token, as an (n_active, 1) index tensor.
              active_experts = (routing_mask.sum(dim=(-1, -2)) > 0).nonzero()

          per_expert_shape = (-1, self.ffn_hidden_size, self.hidden_size)
          for hit in active_experts:
              expert = hit[0]
              with torch.no_grad():
                  # `slot` is the top-k position, `token_positions` the flat token row.
                  slot, token_positions = torch.where(routing_mask[expert])
              w1 = self.mlp.w1.view(per_expert_shape)[expert]
              v1 = self.mlp.v1.view(per_expert_shape)[expert]
              w2 = self.mlp.w2.view(per_expert_shape)[expert]
              expert_out = self.mlp(flat_states[token_positions], w1, v1, w2)
              weighted = expert_out.view(-1, self.ffn_hidden_size) * top_k_weights[token_positions, slot, None]
              combined.index_add_(0, token_positions, weighted)

          return combined.view(batch_size, -1, self.ffn_hidden_size)
  class DbrxRouter(nn.Module):
      """Linear router that scores every token against every expert.

      Args:
          config: FFN sub-config. Reads `ffn_hidden_size` (used as the router input width), `moe_num_experts`
              and `moe_jitter_eps`.
      """

      def __init__(self, config):
          super().__init__()
          self.hidden_size = config.ffn_hidden_size
          self.moe_jitter_eps = config.moe_jitter_eps
          self.layer = nn.Linear(self.hidden_size, config.moe_num_experts, bias=False)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Return router logits of shape `(num_tokens, moe_num_experts)`, with all leading dims flattened.

          In training mode with `moe_jitter_eps` set, the router input is multiplied by uniform noise in
          `[1 - eps, 1 + eps]`. The multiply is out of place, so the caller's tensor is left untouched and can
          still be fed un-jittered to the experts.
          """
          if self.training and self.moe_jitter_eps is not None:
              noise = torch.empty_like(hidden_states).uniform_(1.0 - self.moe_jitter_eps, 1.0 + self.moe_jitter_eps)
              hidden_states = hidden_states * noise
          # `reshape` (unlike `view`) also accepts non-contiguous inputs.
          hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
          return self.layer(hidden_states)
  class DbrxFFN(nn.Module):
      """Modular DBRX MLP/FFN component with MoE support.

      Routes each token to its top-k experts and returns the weighted mix of their outputs.

      Args:
          config: Top-level DBRX config; the router and experts are built from `config.ffn_config`.
      """

      def __init__(self, config, **kwargs):
          super().__init__()
          ffn_config = config.ffn_config
          self.router = DbrxRouter(ffn_config)
          self.experts = DbrxExperts(ffn_config)
          self.moe_normalize_expert_weights = ffn_config.moe_normalize_expert_weights
          self.top_k = ffn_config.moe_top_k

      def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
          """Pick the top-k experts per token and compute their mixing weights.

          Args:
              router_logits: Raw router scores with the expert axis last.

          Returns:
              `(top_k_weights, top_k_index)`. The weights are the softmax probabilities of the chosen experts.
              When `moe_normalize_expert_weights` is set, they are divided by their p-norm with p equal to that
              value.
          """
          # Softmax over the expert axis; it must match the axis `topk` selects along.
          router_probs = torch.nn.functional.softmax(router_logits, dim=-1, dtype=router_logits.dtype)
          top_k_weights, top_k_index = torch.topk(router_probs, self.top_k, dim=-1)
          if self.moe_normalize_expert_weights is not None:
              top_k_weights = top_k_weights / torch.norm(
                  top_k_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
              )
          return top_k_weights, top_k_index

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Return the mixture-of-experts output for `hidden_states` (a single tensor)."""
          router_logits = self.router(hidden_states)
          top_k_weights, top_k_index = self.route_tokens_to_experts(router_logits)
          return self.experts(hidden_states, top_k_index, top_k_weights)
  class DbrxNormAttentionNorm(nn.Module):
      """Pre-norm attention sub-block of a DBRX layer.

      Normalises with `norm_1`, runs self-attention, applies residual dropout and the residual add, then
      normalises the updated stream with `norm_2` for the feed-forward block.
      """

      def __init__(self, config: DbrxConfig, layer_idx: int | None = None):
          super().__init__()
          self.layer_idx = layer_idx
          self.resid_pdrop = config.resid_pdrop
          self.norm_1 = nn.LayerNorm(config.d_model, bias=False)
          self.attn = DbrxAttention(config=config, layer_idx=layer_idx)
          self.norm_2 = nn.LayerNorm(config.d_model, bias=False)

      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: torch.LongTensor,
          attention_mask: torch.Tensor | None = None,
          past_key_values: Cache | None = None,
          **kwargs: Any,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """Return `(residual, normed)`: the post-attention residual stream and that stream after `norm_2`."""
          input_dtype = hidden_states.dtype
          attn_input = self.norm_1(hidden_states).to(input_dtype)
          attn_output, _ = self.attn(
              hidden_states=attn_input,
              attention_mask=attention_mask,
              position_embeddings=position_embeddings,
              past_key_values=past_key_values,
              **kwargs,
          )
          attn_output = nn.functional.dropout(attn_output, p=self.resid_pdrop, training=self.training)
          residual = attn_output + hidden_states
          return residual, self.norm_2(residual).to(residual.dtype)
  class DbrxBlock(GradientCheckpointingLayer):
      """One DBRX decoder layer: the norm/attention/norm sub-block followed by the MoE feed-forward network."""

      def __init__(self, config: DbrxConfig, layer_idx: int):
          super().__init__()
          self.hidden_size = config.d_model
          self.resid_pdrop = config.resid_pdrop
          self.layer_idx = layer_idx
          self.norm_attn_norm = DbrxNormAttentionNorm(config=config, layer_idx=layer_idx)
          self.ffn = DbrxFFN(config=config)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor | None = None,
          position_embeddings: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          **kwargs: Any,
      ):
          """Apply attention then the MoE FFN, each feeding a residual connection; returns the new hidden states."""
          residual, ffn_input = self.norm_attn_norm(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
              position_embeddings=position_embeddings,
              past_key_values=past_key_values,
              **kwargs,
          )
          ffn_output = nn.functional.dropout(self.ffn(ffn_input), p=self.resid_pdrop, training=self.training)
          return residual + ffn_output
  class DbrxPreTrainedModel(PreTrainedModel):
      """Shared base for DBRX models: backend capability flags, output recording and weight initialisation."""

      config: DbrxConfig
      base_model_prefix = "transformer"
      supports_gradient_checkpointing = True
      _no_split_modules = ["DbrxBlock"]
      _skip_keys_device_placement = ["past_key_values"]
      _supports_flex_attn = True
      _supports_attention_backend = True
      _supports_flash_attn = True
      _supports_sdpa = True
      _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
      # Modules whose outputs are collected into the model output (presumably by `capture_outputs` when the
      # matching `output_*` flag is set). `DbrxForCausalLM` reads `outputs.router_logits` for the
      # load-balancing loss, so the router output has to be recorded as well.
      _can_record_outputs = {
          "hidden_states": DbrxBlock,
          "attentions": DbrxAttention,
          "router_logits": DbrxRouter,
      }

      @torch.no_grad()
      def _init_weights(self, module: nn.Module):
          """Run the base initialisation, then draw the stacked expert matrices from N(0, initializer_range)."""
          super()._init_weights(module)
          std = self.config.initializer_range
          if isinstance(module, DbrxExpertGLU):
              init.normal_(module.w1, mean=0.0, std=std)
              init.normal_(module.v1, mean=0.0, std=std)
              init.normal_(module.w2, mean=0.0, std=std)
  @auto_docstring
  class DbrxModel(DbrxPreTrainedModel):
      """Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DbrxBlock`] layer.

      Args:
          config ([`DbrxConfig`]): Model configuration class with all parameters of the model.
              Initializing with a config file does not load the weights associated with the model, only the
              configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
      """

      def __init__(self, config: DbrxConfig):
          super().__init__(config)
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size
          # Dropout probability applied to the input embeddings in `forward` (active only in training mode).
          self.emb_pdrop = config.emb_pdrop

          self.rotary_emb = DbrxRotaryEmbedding(config)
          self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
          self.blocks = nn.ModuleList([DbrxBlock(config, layer_idx) for layer_idx in range(config.n_layers)])
          self.norm_f = nn.LayerNorm(config.d_model, bias=False)
          self.gradient_checkpointing = False

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self) -> nn.Embedding:
          return self.wte

      def set_input_embeddings(self, value: nn.Embedding):
          self.wte = value

      @merge_with_config_defaults
      @capture_outputs
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> MoeModelOutputWithPast:
          # Exactly one of `input_ids` / `inputs_embeds` must be given.
          if (input_ids is None) ^ (inputs_embeds is not None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)

          if inputs_embeds is None:
              inputs_embeds = self.wte(input_ids)

          if position_ids is None:
              # Continue numbering after the tokens already held in the cache.
              past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
              position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
              position_ids = position_ids.unsqueeze(0)

          causal_mask = create_causal_mask(
              config=self.config,
              inputs_embeds=inputs_embeds,
              attention_mask=attention_mask,
              past_key_values=past_key_values,
              position_ids=position_ids,
          )

          # `emb_pdrop` was read from the config but never applied; dropout is a no-op outside training.
          hidden_states = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)

          # create position embeddings to be shared across the decoder layers
          position_embeddings = self.rotary_emb(hidden_states, position_ids)

          for decoder_layer in self.blocks[: self.config.num_hidden_layers]:
              hidden_states = decoder_layer(
                  hidden_states,
                  position_embeddings=position_embeddings,
                  attention_mask=causal_mask,
                  position_ids=position_ids,
                  past_key_values=past_key_values,
                  use_cache=use_cache,
                  **kwargs,
              )

          hidden_states = self.norm_f(hidden_states)

          return MoeModelOutputWithPast(  # only diff with Mistral is the output type, we need MoE
              last_hidden_state=hidden_states,
              past_key_values=past_key_values,
          )
  class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin):
      """DBRX decoder (`self.transformer`) topped with a linear LM head, plus an optional router auxiliary loss."""

      _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
      _tp_plan = {"lm_head": "colwise_gather_output"}
      _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

      def __init__(self, config: DbrxConfig):
          super().__init__(config)
          self.transformer = DbrxModel(config)
          self.vocab_size = config.vocab_size
          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
          # Settings used by the router load-balancing loss.
          self.router_aux_loss_coef = config.ffn_config.moe_loss_weight
          self.num_experts = config.ffn_config.moe_num_experts
          self.num_experts_per_tok = config.ffn_config.moe_top_k
          self.post_init()

      def get_input_embeddings(self) -> nn.Embedding:
          return self.transformer.get_input_embeddings()

      def set_input_embeddings(self, value: nn.Embedding):
          self.transformer.set_input_embeddings(value)

      def get_output_embeddings(self) -> nn.Linear:
          return self.lm_head

      def set_output_embeddings(self, new_embeddings: nn.Linear):
          self.lm_head = new_embeddings

      def set_decoder(self, decoder: DbrxModel):
          self.transformer = decoder

      def get_decoder(self) -> DbrxModel:
          return self.transformer

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          use_cache: bool | None = None,
          output_router_logits: bool | None = None,
          logits_to_keep: int | torch.Tensor = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> MoeCausalLMOutputWithPast:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

          Example:

          ```python
          >> from transformers import AutoTokenizer, DbrxForCausalLM

          >> model = DbrxForCausalLM.from_pretrained("transformers-community/dbrx-instruct")
          >> tokenizer = AutoTokenizer.from_pretrained("transformers-community/dbrx-instruct")

          >> prompt = "Hey, are you conscious? Can you talk to me?"
          >> inputs = tokenizer(prompt, return_tensors="pt")

          >> # Generate
          >> generate_ids = model.generate(inputs.input_ids, max_length=30)
          >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
          "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
          ```
          """
          output_router_logits = (
              output_router_logits if output_router_logits is not None else self.config.output_router_logits
          )

          # Run the decoder; `output_router_logits` is forwarded so router logits are available for the aux loss.
          outputs: MoeModelOutputWithPast = self.transformer(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              output_router_logits=output_router_logits,
              **kwargs,
          )

          hidden_states = outputs.last_hidden_state
          # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
          slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
          logits = self.lm_head(hidden_states[:, slice_indices, :])

          loss = None
          if labels is not None:
              loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

          aux_loss = None
          if output_router_logits:
              aux_loss = load_balancing_loss_func(
                  outputs.router_logits,
                  self.num_experts,
                  self.num_experts_per_tok,
                  attention_mask,
              )
              if labels is not None:
                  # NOTE(review): `aux_loss` may not be a tensor (presumably a plain 0 when no router logits were
                  # collected), so only tensors are moved to the loss device.
                  aux_term = aux_loss.to(loss.device) if isinstance(aux_loss, torch.Tensor) else aux_loss
                  loss += self.router_aux_loss_coef * aux_term

          return MoeCausalLMOutputWithPast(
              loss=loss,
              aux_loss=aux_loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
              router_logits=outputs.router_logits,
          )
  # Public classes exported by this module.
  __all__ = [
      "DbrxForCausalLM",
      "DbrxModel",
      "DbrxPreTrainedModel",
  ]