modeling_doge.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/doge/modular_doge.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_doge.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # The Doge family of small language models is trained by SmallDoge Team.
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. import math
  23. from collections.abc import Callable
  24. from typing import Optional, Union
  25. import torch
  26. import torch.nn.functional as F
  27. from torch import nn
  28. from ... import initialization as init
  29. from ...activations import ACT2FN
  30. from ...cache_utils import Cache, DynamicCache
  31. from ...generation import GenerationMixin
  32. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
  33. from ...integrations.flex_attention import compile_friendly_flex_attention
  34. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  35. from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
  36. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import AttentionInterface, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available
  41. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  42. from ...utils.output_capturing import OutputRecorder, capture_outputs
  43. from .configuration_doge import DogeConfig
  44. if is_torch_flex_attn_available():
  45. from torch.nn.attention.flex_attention import BlockMask
  46. @use_kernel_forward_from_hub("RMSNorm")
  47. class DogeRMSNorm(nn.Module):
  48. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  49. """
  50. DogeRMSNorm is equivalent to T5LayerNorm
  51. """
  52. super().__init__()
  53. self.weight = nn.Parameter(torch.ones(hidden_size))
  54. self.variance_epsilon = eps
  55. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  56. input_dtype = hidden_states.dtype
  57. hidden_states = hidden_states.to(torch.float32)
  58. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  59. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  60. return self.weight * hidden_states.to(input_dtype)
  61. def extra_repr(self):
  62. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  63. class DogeRotaryEmbedding(nn.Module):
  64. inv_freq: torch.Tensor # fix linting for `register_buffer`
  65. def __init__(self, config: DogeConfig, device=None):
  66. super().__init__()
  67. self.max_seq_len_cached = config.max_position_embeddings
  68. self.original_max_seq_len = config.max_position_embeddings
  69. self.config = config
  70. self.rope_type = self.config.rope_parameters["rope_type"]
  71. rope_init_fn: Callable = self.compute_default_rope_parameters
  72. if self.rope_type != "default":
  73. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  74. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  75. self.register_buffer("inv_freq", inv_freq, persistent=False)
  76. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  77. @staticmethod
  78. def compute_default_rope_parameters(
  79. config: DogeConfig | None = None,
  80. device: Optional["torch.device"] = None,
  81. seq_len: int | None = None,
  82. ) -> tuple["torch.Tensor", float]:
  83. """
  84. Computes the inverse frequencies according to the original RoPE implementation
  85. Args:
  86. config ([`~transformers.PreTrainedConfig`]):
  87. The model configuration.
  88. device (`torch.device`):
  89. The device to use for initialization of the inverse frequencies.
  90. seq_len (`int`, *optional*):
  91. The current sequence length. Unused for this type of RoPE.
  92. Returns:
  93. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  94. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  95. """
  96. base = config.rope_parameters["rope_theta"]
  97. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  98. attention_factor = 1.0 # Unused in this type of RoPE
  99. # Compute the inverse frequencies
  100. inv_freq = 1.0 / (
  101. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  102. )
  103. return inv_freq, attention_factor
  104. @torch.no_grad()
  105. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  106. def forward(self, x, position_ids):
  107. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  108. position_ids_expanded = position_ids[:, None, :].float()
  109. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  110. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  111. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  112. emb = torch.cat((freqs, freqs), dim=-1)
  113. cos = emb.cos() * self.attention_scaling
  114. sin = emb.sin() * self.attention_scaling
  115. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  116. def rotate_half(x):
  117. """Rotates half the hidden dims of the input."""
  118. x1 = x[..., : x.shape[-1] // 2]
  119. x2 = x[..., x.shape[-1] // 2 :]
  120. return torch.cat((-x2, x1), dim=-1)
  121. @use_kernel_func_from_hub("rotary_pos_emb")
  122. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  123. """Applies Rotary Position Embedding to the query and key tensors.
  124. Args:
  125. q (`torch.Tensor`): The query tensor.
  126. k (`torch.Tensor`): The key tensor.
  127. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  128. sin (`torch.Tensor`): The sine part of the rotary embedding.
  129. unsqueeze_dim (`int`, *optional*, defaults to 1):
  130. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  131. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  132. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  133. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  134. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  135. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  136. Returns:
  137. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  138. """
  139. cos = cos.unsqueeze(unsqueeze_dim)
  140. sin = sin.unsqueeze(unsqueeze_dim)
  141. q_embed = (q * cos) + (rotate_half(q) * sin)
  142. k_embed = (k * cos) + (rotate_half(k) * sin)
  143. return q_embed, k_embed
  144. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  145. """
  146. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  147. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  148. """
  149. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  150. if n_rep == 1:
  151. return hidden_states
  152. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  153. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  154. def eager_attention_forward(
  155. module: nn.Module,
  156. query: torch.Tensor,
  157. key: torch.Tensor,
  158. value: torch.Tensor,
  159. attention_mask: torch.Tensor | None,
  160. scaling: float,
  161. dropout: float = 0.0,
  162. **kwargs: Unpack[TransformersKwargs],
  163. ):
  164. key_states = repeat_kv(key, module.num_key_value_groups)
  165. value_states = repeat_kv(value, module.num_key_value_groups)
  166. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  167. if attention_mask is not None:
  168. attn_weights = attn_weights + attention_mask
  169. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  170. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  171. attn_output = torch.matmul(attn_weights, value_states)
  172. attn_output = attn_output.transpose(1, 2).contiguous()
  173. return attn_output, attn_weights
def flex_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Union[torch.Tensor, "BlockMask"],
    scaling: float | None = None,
    softcap: float | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Attention through `compile_friendly_flex_attention`, accepting a `BlockMask` or a dense additive mask.

    Args:
        module: the calling attention module; not used in the body (kept for the shared interface signature).
        query, key, value: attention inputs; `key.shape[-2]` is used as the key length when slicing a dense mask.
        attention_mask: either a `BlockMask`, passed through as `block_mask`, or a dense tensor that is sliced to
            the key length and whose entries are added to every score inside `score_mod`.
        scaling: forwarded as `scale`.
        softcap: when set, each score is squashed as `softcap * tanh(score / softcap)` before the mask is added.

    Returns:
        `(attn_output, attention_weights)`. `attn_output` has axes 1 and 2 swapped and is contiguous.
        `attention_weights` is the log-sum-exp returned because of `return_lse=True`, cast to `value.dtype` --
        it is NOT a softmax probability matrix.
    """
    block_mask = None
    causal_mask = None
    # NOTE(review): `BlockMask` is only imported when `is_torch_flex_attn_available()` is true (module imports),
    # so this check raises NameError if this function is ever reached without flex attention.
    if isinstance(attention_mask, BlockMask):
        block_mask = attention_mask
    else:
        causal_mask = attention_mask

    if causal_mask is not None:
        # Trim the dense mask to the actual key length.
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
        # Optional tanh soft-capping of the raw score, then the dense mask value for this (b, h, q, kv) entry.
        if softcap is not None:
            score = softcap * torch.tanh(score / softcap)
        if causal_mask is not None:
            score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx]
        return score

    attn_output, attention_weights = compile_friendly_flex_attention(
        query,
        key,
        value,
        score_mod=score_mod,
        block_mask=block_mask,
        # Presumably lets fewer key/value heads serve more query heads (grouped-query attention).
        enable_gqa=True,
        scale=scaling,
        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
        # For simplification, we thus always return it as no additional computations are introduced.
        return_lse=True,
    )
    # lse is returned in float32
    attention_weights = attention_weights.to(value.dtype)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attention_weights
# Module-local attention registry. `DogeAttention.forward` looks up `config._attn_implementation` here and falls
# back to `eager_attention_forward`; the extra "doge_flex_attention" entry routes to `flex_attention_forward`.
# (Whether `AttentionInterface()` is pre-populated with the library's standard backends is not visible here.)
ALL_ATTENTION_FUNCTIONS = AttentionInterface()
ALL_ATTENTION_FUNCTIONS["doge_flex_attention"] = flex_attention_forward
  216. class DogeAttention(nn.Module):
  217. def __init__(self, config: DogeConfig, layer_idx: int | None = None):
  218. super().__init__()
  219. self.config = config
  220. self.layer_idx = layer_idx
  221. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  222. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  223. self.scaling = self.head_dim**-0.5
  224. self.attention_dropout = config.attention_dropout
  225. self.keep_window_size = config.keep_window_size
  226. self.q_proj = nn.Linear(
  227. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  228. )
  229. self.k_proj = nn.Linear(
  230. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  231. )
  232. self.v_proj = nn.Linear(
  233. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  234. )
  235. # dynamic mask for the QK^T attention weights matrix
  236. self.A = nn.Parameter(torch.zeros(config.num_key_value_heads))
  237. self.dt_proj = nn.Linear(
  238. config.num_key_value_heads * self.head_dim, config.num_key_value_heads, bias=config.attention_bias
  239. )
  240. self.o_proj = nn.Linear(
  241. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  242. )
  243. self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  244. self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  245. def forward(
  246. self,
  247. hidden_states: torch.Tensor,
  248. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  249. attention_mask: torch.Tensor | None = None,
  250. past_key_values: Cache | None = None,
  251. **kwargs,
  252. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  253. input_shape = hidden_states.shape[:-1]
  254. hidden_shape = (*input_shape, -1, self.head_dim)
  255. query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  256. key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  257. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  258. cos, sin = position_embeddings
  259. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  260. if past_key_values is not None:
  261. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  262. # calculate dynamic mask from value_states
  263. dt_states = self.dt_proj(
  264. value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
  265. )
  266. dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
  267. attn_mask = self.prepare_dynamic_mask(
  268. hidden_states=hidden_states,
  269. dt_states=dt_states,
  270. keep_window_size=self.keep_window_size,
  271. attention_mask=attention_mask,
  272. )
  273. attn_mask = repeat_kv(attn_mask, self.num_key_value_groups)
  274. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  275. self.config._attn_implementation, eager_attention_forward
  276. )
  277. attn_output, attn_weights = attention_interface(
  278. self,
  279. query_states,
  280. key_states,
  281. value_states,
  282. attention_mask=attn_mask,
  283. dropout=0.0 if not self.training else self.attention_dropout,
  284. scaling=self.scaling,
  285. **kwargs,
  286. )
  287. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  288. attn_output = self.o_proj(attn_output)
  289. return attn_output, attn_weights
  290. def prepare_dynamic_mask(
  291. self,
  292. hidden_states: torch.Tensor,
  293. dt_states: torch.Tensor,
  294. keep_window_size: int = 2048,
  295. attention_mask: torch.Tensor | None = None,
  296. ):
  297. """
  298. The core idea of DMA is to calculate the dynamic attention mask to mask the tokens that should be masked, so as to form sparse attention.
  299. Combine `dt_states` with `attention_mask` to generate the final `attn_mask`.
  300. Args:
  301. hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
  302. dt_states (`torch.Tensor`): dt_states of shape `(batch_size, num_heads, key_sequence_length)`.
  303. keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value.
  304. attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
  305. """
  306. min_dtype = torch.finfo(hidden_states.dtype).min
  307. dtype = hidden_states.dtype
  308. attn_mask = dt_states[:, :, None, :].expand(
  309. -1, -1, hidden_states.shape[1], -1
  310. ) # [batch_size, num_heads, query_len, key_len]
  311. if attention_mask is not None and not isinstance(attention_mask, BlockMask):
  312. if attention_mask.dtype == torch.bool:
  313. dtype = hidden_states.dtype
  314. attention_mask = torch.where(
  315. attention_mask, torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype
  316. )
  317. attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : attn_mask.shape[-1]] != 0, min_dtype)
  318. if attn_mask.shape[-1] > keep_window_size:
  319. active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device)
  320. topk_indices = torch.topk(attn_mask, keep_window_size, dim=-1, largest=True, sorted=False).indices
  321. active_mask = active_mask.scatter(-1, topk_indices, 1.0)
  322. attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)
  323. return attn_mask
  324. class DogeMLP(nn.Module):
  325. def __init__(self, config):
  326. super().__init__()
  327. self.config = config
  328. self.hidden_size = config.hidden_size
  329. self.intermediate_size = config.intermediate_size
  330. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  331. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  332. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  333. self.act_fn = ACT2FN[config.hidden_act]
  334. def forward(self, x):
  335. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  336. return down_proj
  337. class DogeCDMoE(nn.Module):
  338. def __init__(self, config: DogeConfig):
  339. super().__init__()
  340. self.hidden_size = config.hidden_size
  341. self.intermediate_size = config.intermediate_size
  342. self.act_fn = ACT2FN[config.hidden_act]
  343. self.num_experts = config.num_experts
  344. self.num_keys = math.floor(math.sqrt(self.num_experts))
  345. self.top_k = config.num_experts_per_tok
  346. self.norm_topk_prob = config.norm_topk_prob
  347. # shared expert
  348. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  349. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  350. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  351. # router gate for retrieval experts
  352. self.router_gate = nn.Linear(self.hidden_size, self.num_keys * 2, bias=False)
  353. # routed experts
  354. self.down_embed = nn.Embedding(self.num_experts, self.hidden_size)
  355. self.up_embed = nn.Embedding(self.num_experts, self.hidden_size)
  356. def forward(
  357. self,
  358. hidden_states: torch.Tensor,
  359. **kwargs,
  360. ) -> torch.Tensor:
  361. bsz, seq_len, _ = hidden_states.shape
  362. # get routing logits with router gate
  363. router_logits = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)
  364. # get experts with the highest routing logits
  365. (scores_x, scores_y), (indices_x, indices_y) = router_logits.topk(self.num_keys, dim=-1)
  366. all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
  367. all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
  368. all_scores = all_scores.view(*all_scores.shape[:-2], -1)
  369. all_indices = all_indices.view(*all_indices.shape[:-2], -1)
  370. scores, position_indices = all_scores.topk(self.top_k, dim=-1)
  371. indices = all_indices.gather(-1, position_indices)
  372. routing_weights = F.softmax(scores, dim=-1)
  373. if self.norm_topk_prob:
  374. routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
  375. # mix routed experts states with shared expert states
  376. down_embed = self.down_embed(indices)
  377. up_embed = self.up_embed(indices)
  378. experts_weights = torch.matmul(down_embed, hidden_states.view(bsz * seq_len, -1, 1)).view(bsz * seq_len, -1)
  379. experts_weights = self.act_fn(experts_weights) * routing_weights
  380. experts_states = torch.matmul(experts_weights.view(bsz * seq_len, 1, -1), up_embed).view(bsz, seq_len, -1)
  381. hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
  382. hidden_states = hidden_states + experts_states
  383. return hidden_states, router_logits
  384. class DogeDecoderLayer(GradientCheckpointingLayer):
  385. def __init__(self, config: DogeConfig, layer_idx: int | None = None):
  386. super().__init__()
  387. self.hidden_dropout = config.hidden_dropout
  388. self.input_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  389. self.self_attn = DogeAttention(config=config, layer_idx=layer_idx)
  390. self.input_residual = nn.Parameter(torch.ones(config.hidden_size))
  391. self.post_attention_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  392. self.mlp = DogeMLP(config) if not config.is_moe else DogeCDMoE(config)
  393. self.post_attention_residual = nn.Parameter(torch.ones(config.hidden_size))
  394. def forward(
  395. self,
  396. hidden_states: torch.Tensor,
  397. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  398. attention_mask: torch.Tensor | None = None,
  399. position_ids: torch.LongTensor | None = None,
  400. past_key_values: Cache | None = None,
  401. use_cache: bool | None = False,
  402. **kwargs: Unpack[TransformersKwargs],
  403. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  404. # sequence transformation
  405. residual = hidden_states
  406. hidden_states = self.input_layernorm(hidden_states)
  407. hidden_states, self_attn_weights = self.self_attn(
  408. hidden_states=hidden_states,
  409. position_embeddings=position_embeddings,
  410. attention_mask=attention_mask,
  411. position_ids=position_ids,
  412. past_key_values=past_key_values,
  413. use_cache=use_cache,
  414. **kwargs,
  415. )
  416. hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
  417. hidden_states = self.input_residual * residual + hidden_states
  418. # state transformation
  419. residual = hidden_states
  420. hidden_states = self.post_attention_layernorm(hidden_states)
  421. hidden_states = self.mlp(hidden_states)
  422. hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
  423. hidden_states = self.post_attention_residual * residual + hidden_states
  424. return hidden_states
  425. @auto_docstring
  426. class DogePreTrainedModel(PreTrainedModel):
  427. config: DogeConfig
  428. base_model_prefix = "model"
  429. supports_gradient_checkpointing = True
  430. _no_split_modules = ["DogeDecoderLayer"]
  431. _skip_keys_device_placement = ["past_key_values"]
  432. _supports_flash_attn = False
  433. _supports_sdpa = True
  434. _supports_flex_attn = True
  435. _can_compile_fullgraph = False
  436. _supports_attention_backend = True
  437. _can_record_outputs = {
  438. "router_logits": OutputRecorder(DogeCDMoE, index=1),
  439. "hidden_states": DogeDecoderLayer,
  440. "attentions": DogeAttention,
  441. }
  442. @torch.no_grad()
  443. def _init_weights(self, module):
  444. """Initialize the weights"""
  445. super()._init_weights(module)
  446. if isinstance(module, DogeAttention):
  447. if hasattr(module, "A"):
  448. init.zeros_(module.A)
  449. elif isinstance(module, DogeDecoderLayer):
  450. if hasattr(module, "input_residual"):
  451. init.ones_(module.input_residual)
  452. if hasattr(module, "post_attention_residual"):
  453. init.ones_(module.post_attention_residual)
  454. @auto_docstring
  455. class DogeModel(DogePreTrainedModel):
  456. def __init__(self, config: DogeConfig):
  457. super().__init__(config)
  458. self.padding_idx = config.pad_token_id
  459. self.vocab_size = config.vocab_size
  460. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  461. self.layers = nn.ModuleList(
  462. [DogeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  463. )
  464. self.norm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  465. self.rotary_emb = DogeRotaryEmbedding(config=config)
  466. self.gradient_checkpointing = False
  467. # Initialize weights and apply final processing
  468. self.post_init()
  469. @merge_with_config_defaults
  470. @capture_outputs
  471. @auto_docstring
  472. def forward(
  473. self,
  474. input_ids: torch.LongTensor | None = None,
  475. attention_mask: torch.Tensor | None = None,
  476. position_ids: torch.LongTensor | None = None,
  477. past_key_values: Cache | None = None,
  478. inputs_embeds: torch.FloatTensor | None = None,
  479. use_cache: bool | None = None,
  480. **kwargs: Unpack[TransformersKwargs],
  481. ) -> MoeModelOutputWithPast:
  482. if (input_ids is None) ^ (inputs_embeds is not None):
  483. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  484. if use_cache and past_key_values is None:
  485. past_key_values = DynamicCache(config=self.config)
  486. if inputs_embeds is None:
  487. inputs_embeds = self.embed_tokens(input_ids)
  488. if position_ids is None:
  489. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  490. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  491. position_ids = position_ids.unsqueeze(0)
  492. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  493. causal_mask = mask_function(
  494. config=self.config,
  495. inputs_embeds=inputs_embeds,
  496. attention_mask=attention_mask,
  497. past_key_values=past_key_values,
  498. position_ids=position_ids,
  499. )
  500. hidden_states = inputs_embeds
  501. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  502. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  503. hidden_states = decoder_layer(
  504. hidden_states,
  505. attention_mask=causal_mask,
  506. position_ids=position_ids,
  507. past_key_values=past_key_values,
  508. use_cache=use_cache,
  509. position_embeddings=position_embeddings,
  510. **kwargs,
  511. )
  512. hidden_states = self.norm(hidden_states)
  513. return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
  514. last_hidden_state=hidden_states,
  515. past_key_values=past_key_values,
  516. )
  517. def load_balancing_loss_func(
  518. gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
  519. num_experts: int | None = None,
  520. num_keys: int | None = None,
  521. top_k: int = 2,
  522. attention_mask: torch.Tensor | None = None,
  523. ) -> torch.Tensor | int:
  524. r"""
  525. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  526. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  527. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  528. experts is too unbalanced.
  529. Args:
  530. gate_logits:
  531. Logits from the `router_gate`, should be a tuple of model.config.num_hidden_layers tensors of
  532. shape [2, batch_size * sequence_length, num_keys].
  533. num_experts:
  534. Number of experts
  535. num_keys:
  536. Number of keys
  537. top_k:
  538. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  539. parameter.
  540. attention_mask (`torch.Tensor`, *optional*):
  541. The attention_mask used in forward function
  542. shape [batch_size X sequence_length] if not None.
  543. Returns:
  544. The auxiliary loss.
  545. """
  546. if gate_logits is None or not isinstance(gate_logits, tuple):
  547. return 0
  548. compute_dtype = gate_logits[0].dtype
  549. compute_device = gate_logits[0].device
  550. all_expert_indices = []
  551. all_routing_weights = []
  552. for layer_gate_logits in gate_logits:
  553. layer_gate_logits = layer_gate_logits.to(compute_device)
  554. (scores_x, scores_y), (indices_x, indices_y) = layer_gate_logits.topk(num_keys, dim=-1)
  555. all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
  556. all_indices = indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)
  557. all_scores = all_scores.view(*all_scores.shape[:-2], -1)
  558. all_indices = all_indices.view(*all_indices.shape[:-2], -1)
  559. _, position_indices = all_scores.topk(top_k, dim=-1)
  560. expert_indices = all_indices.gather(-1, position_indices)
  561. routing_weights = F.softmax(all_scores, dim=-1)
  562. all_expert_indices.append(expert_indices)
  563. all_routing_weights.append(routing_weights)
  564. all_expert_indices = torch.cat(all_expert_indices, dim=0)
  565. all_routing_weights = torch.cat(all_routing_weights, dim=0)
  566. if attention_mask is None:
  567. # Compute the percentage of tokens routed to each experts
  568. all_expert_indices = all_expert_indices.view(-1)
  569. tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
  570. pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
  571. tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / all_expert_indices.shape[0]
  572. # Compute the average probability of routing to these experts
  573. router_prob_per_expert = torch.mean(all_routing_weights, dim=0)
  574. else:
  575. batch_size, sequence_length = attention_mask.shape
  576. num_hidden_layers = len(gate_logits)
  577. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  578. expert_attention_mask = (
  579. attention_mask[None, :, :, None]
  580. .expand((num_hidden_layers, batch_size, sequence_length, top_k))
  581. .reshape(-1)
  582. .to(compute_device)
  583. )
  584. all_expert_indices = all_expert_indices.view(-1)[expert_attention_mask.bool()]
  585. # Compute the percentage of tokens routed to each experts
  586. tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
  587. pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
  588. tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / torch.sum(
  589. expert_attention_mask
  590. )
  591. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  592. router_per_expert_attention_mask = (
  593. attention_mask[None, :, :, None]
  594. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  595. .reshape(-1, num_experts)
  596. .to(compute_device)
  597. )
  598. # Compute the average probability of routing to these experts
  599. router_prob_per_expert = torch.sum(all_routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  600. router_per_expert_attention_mask, dim=0
  601. )
  602. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert)
  603. return overall_loss * num_experts
  604. @auto_docstring
  605. class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
  606. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  607. _tp_plan = {"lm_head": "colwise_gather_output"}
  608. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  609. def __init__(self, config):
  610. super().__init__(config)
  611. self.model = DogeModel(config)
  612. self.vocab_size = config.vocab_size
  613. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  614. self.router_aux_loss_coef = config.router_aux_loss_coef
  615. self.num_experts = config.num_experts
  616. self.num_experts_per_tok = config.num_experts_per_tok
  617. # Initialize weights and apply final processing
  618. self.post_init()
  619. @can_return_tuple
  620. @auto_docstring
  621. def forward(
  622. self,
  623. input_ids: torch.LongTensor | None = None,
  624. attention_mask: torch.Tensor | None = None,
  625. position_ids: torch.LongTensor | None = None,
  626. past_key_values: Cache | None = None,
  627. inputs_embeds: torch.FloatTensor | None = None,
  628. labels: torch.LongTensor | None = None,
  629. use_cache: bool | None = None,
  630. logits_to_keep: int | torch.Tensor = 0,
  631. output_router_logits: bool | None = None,
  632. **kwargs: Unpack[TransformersKwargs],
  633. ) -> MoeCausalLMOutputWithPast:
  634. r"""
  635. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  636. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  637. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  638. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  639. Example:
  640. ```python
  641. >>> from transformers import AutoTokenizer, DogeForCausalLM
  642. >>> model = DogeForCausalLM.from_pretrained("SmallDoge/Doge-320M")
  643. >>> tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M")
  644. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  645. >>> inputs = tokenizer(prompt, return_tensors="pt")
  646. >>> # Generate
  647. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  648. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  649. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  650. ```"""
  651. output_router_logits = (
  652. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  653. )
  654. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  655. outputs: MoeModelOutputWithPast = self.model(
  656. input_ids=input_ids,
  657. attention_mask=attention_mask,
  658. position_ids=position_ids,
  659. past_key_values=past_key_values,
  660. inputs_embeds=inputs_embeds,
  661. use_cache=use_cache,
  662. **kwargs,
  663. )
  664. hidden_states = outputs.last_hidden_state
  665. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  666. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  667. logits = self.lm_head(hidden_states[:, slice_indices, :])
  668. loss = None
  669. if labels is not None:
  670. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  671. aux_loss = None
  672. if output_router_logits:
  673. aux_loss = load_balancing_loss_func(
  674. outputs.router_logits,
  675. self.num_experts,
  676. math.floor(math.sqrt(self.num_experts)),
  677. self.num_experts_per_tok,
  678. attention_mask,
  679. )
  680. if labels is not None:
  681. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  682. return MoeCausalLMOutputWithPast(
  683. loss=loss,
  684. aux_loss=aux_loss,
  685. logits=logits,
  686. past_key_values=outputs.past_key_values,
  687. hidden_states=outputs.hidden_states,
  688. attentions=outputs.attentions,
  689. router_logits=outputs.router_logits,
  690. )
  691. class DogeForSequenceClassification(GenericForSequenceClassification, DogePreTrainedModel):
  692. pass
  693. __all__ = ["DogeForCausalLM", "DogeModel", "DogePreTrainedModel", "DogeForSequenceClassification"]