modular_doge.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659
  1. # Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # The Doge family of small language models is trained by SmallDoge Team.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch Doge model."""
  17. import math
  18. from collections.abc import Callable
  19. from typing import Union
  20. import torch
  21. import torch.nn.functional as F
  22. from huggingface_hub.dataclasses import strict
  23. from torch import nn
  24. from ... import initialization as init
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache
  27. from ...configuration_utils import PreTrainedConfig
  28. from ...integrations.flex_attention import compile_friendly_flex_attention
  29. from ...modeling_layers import GradientCheckpointingLayer
  30. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  31. from ...modeling_rope_utils import RopeParameters
  32. from ...modeling_utils import AttentionInterface, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging
  35. from ...utils.output_capturing import OutputRecorder
  36. from ..llama.modeling_llama import (
  37. LlamaForSequenceClassification,
  38. LlamaMLP,
  39. LlamaPreTrainedModel,
  40. LlamaRMSNorm,
  41. LlamaRotaryEmbedding,
  42. apply_rotary_pos_emb,
  43. eager_attention_forward,
  44. repeat_kv,
  45. )
  46. from ..mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel
  47. logger = logging.get_logger(__name__)
  48. if is_torch_flex_attn_available():
  49. from torch.nn.attention.flex_attention import BlockMask
@auto_docstring(checkpoint="SmallDoge/Doge-320M")
@strict
class DogeConfig(PreTrainedConfig):
    r"""
    keep_window_size (`int`, *optional*, defaults to 2048):
        The number of key tokens each query keeps after dynamic masking. Dynamic masking is only performed when the key sequence length exceeds this value; the keys with the highest dynamic-mask scores are kept.
    is_moe (`bool`, *optional*, defaults to `False`):
        Whether to use the Cross Domain Mixture of Experts instead of the dense MLP. If `True`, the MoE's shared expert uses the same `gate_proj`/`up_proj`/`down_proj` layout as the MLP, so it can be initialized from it.

    ```python
    >>> from transformers import DogeConfig, DogeModel

    >>> # Initializing a Doge-320M style configuration
    >>> configuration = DogeConfig()

    >>> # Initializing a model from the Doge-320M style configuration
    >>> model = DogeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "doge"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `DogeModel`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.dt_proj": "rowwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
        # The entries below only match when `is_moe=True` (DogeCDMoE modules).
        "layers.*.mlp.router_gate": "colwise_gather_output",
        "layers.*.mlp.down_embed": "rowwise_split_input",
        "layers.*.mlp.up_embed": "rowwise_split_input",
    }
    # Default pipeline parallel plan for base model `DogeModel`
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    # Model size / embedding settings
    vocab_size: int = 32768
    hidden_size: int = 1024
    intermediate_size: int = 2048
    num_hidden_layers: int = 32
    # Dropout applied to each sub-block's output before the residual add (DogeDecoderLayer)
    hidden_dropout: float | int = 0.0
    hidden_act: str = "silu"
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-06
    use_cache: bool = True
    tie_word_embeddings: bool = False
    max_position_embeddings: int = 2048
    rope_parameters: RopeParameters | dict | None = None

    # Attention settings
    num_attention_heads: int = 8
    # When left as None, __post_init__ sets it to num_attention_heads
    num_key_value_heads: int | None = None
    attention_bias: bool = False
    attention_dropout: float | None = 0.0
    mlp_bias: bool = False
    sliding_window: int | None = None
    # Maximum number of keys kept per query by the dynamic mask (DogeAttention.prepare_dynamic_mask)
    keep_window_size: int = 2048

    # Cross Domain MoE settings (the DogeCDMoE layer is only built when is_moe=True)
    is_moe: bool = False
    # DogeCDMoE routes over floor(sqrt(num_experts)) keys per routing axis
    num_experts: int = 16384
    num_experts_per_tok: int = 64
    norm_topk_prob: bool = False
    output_router_logits: bool = False
    # Weight of the load balancing loss that DogeForCausalLM adds to the LM loss
    router_aux_loss_coef: float = 0.001

    # Special token ids
    pad_token_id: int | None = None
    bos_token_id: int | None = None
    eos_token_id: int | list[int] | None = None

    def __post_init__(self, **kwargs):
        # for backward compatibility: without an explicit value, fall back to one key/value head
        # per attention head (plain multi-head attention)
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        super().__post_init__(**kwargs)
# Doge reuses LlamaRMSNorm without changes.
class DogeRMSNorm(LlamaRMSNorm):
    pass
# Doge reuses LlamaRotaryEmbedding without changes.
class DogeRotaryEmbedding(LlamaRotaryEmbedding):
    pass
  125. def flex_attention_forward(
  126. module: nn.Module,
  127. query: torch.Tensor,
  128. key: torch.Tensor,
  129. value: torch.Tensor,
  130. attention_mask: Union[torch.Tensor, "BlockMask"],
  131. scaling: float | None = None,
  132. softcap: float | None = None,
  133. **kwargs,
  134. ) -> tuple[torch.Tensor, torch.Tensor]:
  135. block_mask = None
  136. causal_mask = None
  137. if isinstance(attention_mask, BlockMask):
  138. block_mask = attention_mask
  139. else:
  140. causal_mask = attention_mask
  141. if causal_mask is not None:
  142. causal_mask = causal_mask[:, :, :, : key.shape[-2]]
  143. def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
  144. if softcap is not None:
  145. score = softcap * torch.tanh(score / softcap)
  146. if causal_mask is not None:
  147. score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx]
  148. return score
  149. attn_output, attention_weights = compile_friendly_flex_attention(
  150. query,
  151. key,
  152. value,
  153. score_mod=score_mod,
  154. block_mask=block_mask,
  155. enable_gqa=True,
  156. scale=scaling,
  157. # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
  158. # For simplification, we thus always return it as no additional computations are introduced.
  159. return_lse=True,
  160. )
  161. # lse is returned in float32
  162. attention_weights = attention_weights.to(value.dtype)
  163. attn_output = attn_output.transpose(1, 2).contiguous()
  164. return attn_output, attention_weights
# Module-local attention registry. It additionally maps "doge_flex_attention" to the Doge-specific
# `flex_attention_forward`; DogeAttention resolves `config._attn_implementation` through it,
# falling back to `eager_attention_forward`.
ALL_ATTENTION_FUNCTIONS = AttentionInterface()
ALL_ATTENTION_FUNCTIONS["doge_flex_attention"] = flex_attention_forward
  167. class DogeAttention(nn.Module):
  168. def __init__(self, config: DogeConfig, layer_idx: int | None = None):
  169. super().__init__()
  170. self.config = config
  171. self.layer_idx = layer_idx
  172. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  173. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  174. self.scaling = self.head_dim**-0.5
  175. self.attention_dropout = config.attention_dropout
  176. self.keep_window_size = config.keep_window_size
  177. self.q_proj = nn.Linear(
  178. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  179. )
  180. self.k_proj = nn.Linear(
  181. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  182. )
  183. self.v_proj = nn.Linear(
  184. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  185. )
  186. # dynamic mask for the QK^T attention weights matrix
  187. self.A = nn.Parameter(torch.zeros(config.num_key_value_heads))
  188. self.dt_proj = nn.Linear(
  189. config.num_key_value_heads * self.head_dim, config.num_key_value_heads, bias=config.attention_bias
  190. )
  191. self.o_proj = nn.Linear(
  192. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  193. )
  194. self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  195. self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  196. def forward(
  197. self,
  198. hidden_states: torch.Tensor,
  199. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  200. attention_mask: torch.Tensor | None = None,
  201. past_key_values: Cache | None = None,
  202. **kwargs,
  203. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  204. input_shape = hidden_states.shape[:-1]
  205. hidden_shape = (*input_shape, -1, self.head_dim)
  206. query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  207. key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  208. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  209. cos, sin = position_embeddings
  210. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  211. if past_key_values is not None:
  212. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  213. # calculate dynamic mask from value_states
  214. dt_states = self.dt_proj(
  215. value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
  216. )
  217. dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
  218. attn_mask = self.prepare_dynamic_mask(
  219. hidden_states=hidden_states,
  220. dt_states=dt_states,
  221. keep_window_size=self.keep_window_size,
  222. attention_mask=attention_mask,
  223. )
  224. attn_mask = repeat_kv(attn_mask, self.num_key_value_groups)
  225. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  226. self.config._attn_implementation, eager_attention_forward
  227. )
  228. attn_output, attn_weights = attention_interface(
  229. self,
  230. query_states,
  231. key_states,
  232. value_states,
  233. attention_mask=attn_mask,
  234. dropout=0.0 if not self.training else self.attention_dropout,
  235. scaling=self.scaling,
  236. **kwargs,
  237. )
  238. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  239. attn_output = self.o_proj(attn_output)
  240. return attn_output, attn_weights
  241. def prepare_dynamic_mask(
  242. self,
  243. hidden_states: torch.Tensor,
  244. dt_states: torch.Tensor,
  245. keep_window_size: int = 2048,
  246. attention_mask: torch.Tensor | None = None,
  247. ):
  248. """
  249. The core idea of DMA is to calculate the dynamic attention mask to mask the tokens that should be masked, so as to form sparse attention.
  250. Combine `dt_states` with `attention_mask` to generate the final `attn_mask`.
  251. Args:
  252. hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
  253. dt_states (`torch.Tensor`): dt_states of shape `(batch_size, num_heads, key_sequence_length)`.
  254. keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value.
  255. attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
  256. """
  257. min_dtype = torch.finfo(hidden_states.dtype).min
  258. dtype = hidden_states.dtype
  259. attn_mask = dt_states[:, :, None, :].expand(
  260. -1, -1, hidden_states.shape[1], -1
  261. ) # [batch_size, num_heads, query_len, key_len]
  262. if attention_mask is not None and not isinstance(attention_mask, BlockMask):
  263. if attention_mask.dtype == torch.bool:
  264. dtype = hidden_states.dtype
  265. attention_mask = torch.where(
  266. attention_mask, torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype
  267. )
  268. attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : attn_mask.shape[-1]] != 0, min_dtype)
  269. if attn_mask.shape[-1] > keep_window_size:
  270. active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device)
  271. topk_indices = torch.topk(attn_mask, keep_window_size, dim=-1, largest=True, sorted=False).indices
  272. active_mask = active_mask.scatter(-1, topk_indices, 1.0)
  273. attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)
  274. return attn_mask
# Dense gated MLP used by DogeDecoderLayer when `config.is_moe` is False; identical to LlamaMLP.
class DogeMLP(LlamaMLP):
    pass
  277. class DogeCDMoE(nn.Module):
  278. def __init__(self, config: DogeConfig):
  279. super().__init__()
  280. self.hidden_size = config.hidden_size
  281. self.intermediate_size = config.intermediate_size
  282. self.act_fn = ACT2FN[config.hidden_act]
  283. self.num_experts = config.num_experts
  284. self.num_keys = math.floor(math.sqrt(self.num_experts))
  285. self.top_k = config.num_experts_per_tok
  286. self.norm_topk_prob = config.norm_topk_prob
  287. # shared expert
  288. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  289. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  290. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  291. # router gate for retrieval experts
  292. self.router_gate = nn.Linear(self.hidden_size, self.num_keys * 2, bias=False)
  293. # routed experts
  294. self.down_embed = nn.Embedding(self.num_experts, self.hidden_size)
  295. self.up_embed = nn.Embedding(self.num_experts, self.hidden_size)
  296. def forward(
  297. self,
  298. hidden_states: torch.Tensor,
  299. **kwargs,
  300. ) -> torch.Tensor:
  301. bsz, seq_len, _ = hidden_states.shape
  302. # get routing logits with router gate
  303. router_logits = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)
  304. # get experts with the highest routing logits
  305. (scores_x, scores_y), (indices_x, indices_y) = router_logits.topk(self.num_keys, dim=-1)
  306. all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
  307. all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
  308. all_scores = all_scores.view(*all_scores.shape[:-2], -1)
  309. all_indices = all_indices.view(*all_indices.shape[:-2], -1)
  310. scores, position_indices = all_scores.topk(self.top_k, dim=-1)
  311. indices = all_indices.gather(-1, position_indices)
  312. routing_weights = F.softmax(scores, dim=-1)
  313. if self.norm_topk_prob:
  314. routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
  315. # mix routed experts states with shared expert states
  316. down_embed = self.down_embed(indices)
  317. up_embed = self.up_embed(indices)
  318. experts_weights = torch.matmul(down_embed, hidden_states.view(bsz * seq_len, -1, 1)).view(bsz * seq_len, -1)
  319. experts_weights = self.act_fn(experts_weights) * routing_weights
  320. experts_states = torch.matmul(experts_weights.view(bsz * seq_len, 1, -1), up_embed).view(bsz, seq_len, -1)
  321. hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
  322. hidden_states = hidden_states + experts_states
  323. return hidden_states, router_logits
  324. class DogeDecoderLayer(GradientCheckpointingLayer):
  325. def __init__(self, config: DogeConfig, layer_idx: int | None = None):
  326. super().__init__()
  327. self.hidden_dropout = config.hidden_dropout
  328. self.input_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  329. self.self_attn = DogeAttention(config=config, layer_idx=layer_idx)
  330. self.input_residual = nn.Parameter(torch.ones(config.hidden_size))
  331. self.post_attention_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  332. self.mlp = DogeMLP(config) if not config.is_moe else DogeCDMoE(config)
  333. self.post_attention_residual = nn.Parameter(torch.ones(config.hidden_size))
  334. def forward(
  335. self,
  336. hidden_states: torch.Tensor,
  337. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  338. attention_mask: torch.Tensor | None = None,
  339. position_ids: torch.LongTensor | None = None,
  340. past_key_values: Cache | None = None,
  341. use_cache: bool | None = False,
  342. **kwargs: Unpack[TransformersKwargs],
  343. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  344. # sequence transformation
  345. residual = hidden_states
  346. hidden_states = self.input_layernorm(hidden_states)
  347. hidden_states, self_attn_weights = self.self_attn(
  348. hidden_states=hidden_states,
  349. position_embeddings=position_embeddings,
  350. attention_mask=attention_mask,
  351. position_ids=position_ids,
  352. past_key_values=past_key_values,
  353. use_cache=use_cache,
  354. **kwargs,
  355. )
  356. hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
  357. hidden_states = self.input_residual * residual + hidden_states
  358. # state transformation
  359. residual = hidden_states
  360. hidden_states = self.post_attention_layernorm(hidden_states)
  361. hidden_states = self.mlp(hidden_states)
  362. hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
  363. hidden_states = self.post_attention_residual * residual + hidden_states
  364. return hidden_states
  365. class DogePreTrainedModel(LlamaPreTrainedModel):
  366. _supports_flash_attn = False
  367. _can_compile_fullgraph = False
  368. _can_record_outputs = {
  369. "router_logits": OutputRecorder(DogeCDMoE, index=1),
  370. "hidden_states": DogeDecoderLayer,
  371. "attentions": DogeAttention,
  372. }
  373. @torch.no_grad()
  374. def _init_weights(self, module):
  375. """Initialize the weights"""
  376. PreTrainedModel._init_weights(self, module)
  377. if isinstance(module, DogeAttention):
  378. if hasattr(module, "A"):
  379. init.zeros_(module.A)
  380. elif isinstance(module, DogeDecoderLayer):
  381. if hasattr(module, "input_residual"):
  382. init.ones_(module.input_residual)
  383. if hasattr(module, "post_attention_residual"):
  384. init.ones_(module.post_attention_residual)
# Doge decoder stack; reuses MixtralModel without changes. DogeForCausalLM wraps it.
class DogeModel(MixtralModel):
    pass
  387. def load_balancing_loss_func(
  388. gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
  389. num_experts: int | None = None,
  390. num_keys: int | None = None,
  391. top_k: int = 2,
  392. attention_mask: torch.Tensor | None = None,
  393. ) -> torch.Tensor | int:
  394. r"""
  395. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  396. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  397. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  398. experts is too unbalanced.
  399. Args:
  400. gate_logits:
  401. Logits from the `router_gate`, should be a tuple of model.config.num_hidden_layers tensors of
  402. shape [2, batch_size * sequence_length, num_keys].
  403. num_experts:
  404. Number of experts
  405. num_keys:
  406. Number of keys
  407. top_k:
  408. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  409. parameter.
  410. attention_mask (`torch.Tensor`, *optional*):
  411. The attention_mask used in forward function
  412. shape [batch_size X sequence_length] if not None.
  413. Returns:
  414. The auxiliary loss.
  415. """
  416. if gate_logits is None or not isinstance(gate_logits, tuple):
  417. return 0
  418. compute_dtype = gate_logits[0].dtype
  419. compute_device = gate_logits[0].device
  420. all_expert_indices = []
  421. all_routing_weights = []
  422. for layer_gate_logits in gate_logits:
  423. layer_gate_logits = layer_gate_logits.to(compute_device)
  424. (scores_x, scores_y), (indices_x, indices_y) = layer_gate_logits.topk(num_keys, dim=-1)
  425. all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
  426. all_indices = indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)
  427. all_scores = all_scores.view(*all_scores.shape[:-2], -1)
  428. all_indices = all_indices.view(*all_indices.shape[:-2], -1)
  429. _, position_indices = all_scores.topk(top_k, dim=-1)
  430. expert_indices = all_indices.gather(-1, position_indices)
  431. routing_weights = F.softmax(all_scores, dim=-1)
  432. all_expert_indices.append(expert_indices)
  433. all_routing_weights.append(routing_weights)
  434. all_expert_indices = torch.cat(all_expert_indices, dim=0)
  435. all_routing_weights = torch.cat(all_routing_weights, dim=0)
  436. if attention_mask is None:
  437. # Compute the percentage of tokens routed to each experts
  438. all_expert_indices = all_expert_indices.view(-1)
  439. tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
  440. pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
  441. tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / all_expert_indices.shape[0]
  442. # Compute the average probability of routing to these experts
  443. router_prob_per_expert = torch.mean(all_routing_weights, dim=0)
  444. else:
  445. batch_size, sequence_length = attention_mask.shape
  446. num_hidden_layers = len(gate_logits)
  447. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  448. expert_attention_mask = (
  449. attention_mask[None, :, :, None]
  450. .expand((num_hidden_layers, batch_size, sequence_length, top_k))
  451. .reshape(-1)
  452. .to(compute_device)
  453. )
  454. all_expert_indices = all_expert_indices.view(-1)[expert_attention_mask.bool()]
  455. # Compute the percentage of tokens routed to each experts
  456. tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
  457. pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
  458. tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / torch.sum(
  459. expert_attention_mask
  460. )
  461. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  462. router_per_expert_attention_mask = (
  463. attention_mask[None, :, :, None]
  464. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  465. .reshape(-1, num_experts)
  466. .to(compute_device)
  467. )
  468. # Compute the average probability of routing to these experts
  469. router_prob_per_expert = torch.sum(all_routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  470. router_per_expert_attention_mask, dim=0
  471. )
  472. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert)
  473. return overall_loss * num_experts
  474. class DogeForCausalLM(MixtralForCausalLM):
  475. def __init__(self, config):
  476. super().__init__(config)
  477. self.model = DogeModel(config)
  478. self.num_experts = config.num_experts
  479. def forward(
  480. self,
  481. input_ids: torch.LongTensor | None = None,
  482. attention_mask: torch.Tensor | None = None,
  483. position_ids: torch.LongTensor | None = None,
  484. past_key_values: Cache | None = None,
  485. inputs_embeds: torch.FloatTensor | None = None,
  486. labels: torch.LongTensor | None = None,
  487. use_cache: bool | None = None,
  488. logits_to_keep: int | torch.Tensor = 0,
  489. output_router_logits: bool | None = None,
  490. **kwargs: Unpack[TransformersKwargs],
  491. ) -> MoeCausalLMOutputWithPast:
  492. r"""
  493. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  494. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  495. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  496. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  497. Example:
  498. ```python
  499. >>> from transformers import AutoTokenizer, DogeForCausalLM
  500. >>> model = DogeForCausalLM.from_pretrained("SmallDoge/Doge-320M")
  501. >>> tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M")
  502. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  503. >>> inputs = tokenizer(prompt, return_tensors="pt")
  504. >>> # Generate
  505. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  506. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  507. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  508. ```"""
  509. output_router_logits = (
  510. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  511. )
  512. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  513. outputs: MoeModelOutputWithPast = self.model(
  514. input_ids=input_ids,
  515. attention_mask=attention_mask,
  516. position_ids=position_ids,
  517. past_key_values=past_key_values,
  518. inputs_embeds=inputs_embeds,
  519. use_cache=use_cache,
  520. **kwargs,
  521. )
  522. hidden_states = outputs.last_hidden_state
  523. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  524. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  525. logits = self.lm_head(hidden_states[:, slice_indices, :])
  526. loss = None
  527. if labels is not None:
  528. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  529. aux_loss = None
  530. if output_router_logits:
  531. aux_loss = load_balancing_loss_func(
  532. outputs.router_logits,
  533. self.num_experts,
  534. math.floor(math.sqrt(self.num_experts)),
  535. self.num_experts_per_tok,
  536. attention_mask,
  537. )
  538. if labels is not None:
  539. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  540. return MoeCausalLMOutputWithPast(
  541. loss=loss,
  542. aux_loss=aux_loss,
  543. logits=logits,
  544. past_key_values=outputs.past_key_values,
  545. hidden_states=outputs.hidden_states,
  546. attentions=outputs.attentions,
  547. router_logits=outputs.router_logits,
  548. )
# Sequence classification head for Doge; identical to LlamaForSequenceClassification.
class DogeForSequenceClassification(LlamaForSequenceClassification):
    pass
# Public names exported by this module.
__all__ = [
    "DogeConfig",
    "DogeForCausalLM",
    "DogeModel",
    "DogePreTrainedModel",
    "DogeForSequenceClassification",
]