modeling_dots1.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/dots1/modular_dots1.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_dots1.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. import torch.nn.functional as F
  24. from torch import nn
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache
  28. from ...generation import GenerationMixin
  29. from ...integrations import (
  30. use_experts_implementation,
  31. use_kernel_forward_from_hub,
  32. use_kernel_func_from_hub,
  33. use_kernelized_func,
  34. )
  35. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  36. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  37. from ...modeling_layers import GradientCheckpointingLayer
  38. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  39. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  40. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  41. from ...processing_utils import Unpack
  42. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  43. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  44. from ...utils.output_capturing import capture_outputs
  45. from .configuration_dots1 import Dots1Config
  46. @use_kernel_forward_from_hub("RMSNorm")
  47. class Dots1RMSNorm(nn.Module):
  48. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  49. """
  50. Dots1RMSNorm is equivalent to T5LayerNorm
  51. """
  52. super().__init__()
  53. self.weight = nn.Parameter(torch.ones(hidden_size))
  54. self.variance_epsilon = eps
  55. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  56. input_dtype = hidden_states.dtype
  57. hidden_states = hidden_states.to(torch.float32)
  58. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  59. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  60. return self.weight * hidden_states.to(input_dtype)
  61. def extra_repr(self):
  62. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  63. class Dots1RotaryEmbedding(nn.Module):
  64. inv_freq: torch.Tensor # fix linting for `register_buffer`
  65. def __init__(self, config: Dots1Config, device=None):
  66. super().__init__()
  67. self.max_seq_len_cached = config.max_position_embeddings
  68. self.original_max_seq_len = config.max_position_embeddings
  69. self.config = config
  70. self.rope_type = self.config.rope_parameters["rope_type"]
  71. rope_init_fn: Callable = self.compute_default_rope_parameters
  72. if self.rope_type != "default":
  73. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  74. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  75. self.register_buffer("inv_freq", inv_freq, persistent=False)
  76. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  77. @staticmethod
  78. def compute_default_rope_parameters(
  79. config: Dots1Config | None = None,
  80. device: Optional["torch.device"] = None,
  81. seq_len: int | None = None,
  82. ) -> tuple["torch.Tensor", float]:
  83. """
  84. Computes the inverse frequencies according to the original RoPE implementation
  85. Args:
  86. config ([`~transformers.PreTrainedConfig`]):
  87. The model configuration.
  88. device (`torch.device`):
  89. The device to use for initialization of the inverse frequencies.
  90. seq_len (`int`, *optional*):
  91. The current sequence length. Unused for this type of RoPE.
  92. Returns:
  93. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  94. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  95. """
  96. base = config.rope_parameters["rope_theta"]
  97. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  98. attention_factor = 1.0 # Unused in this type of RoPE
  99. # Compute the inverse frequencies
  100. inv_freq = 1.0 / (
  101. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  102. )
  103. return inv_freq, attention_factor
  104. @torch.no_grad()
  105. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  106. def forward(self, x, position_ids):
  107. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  108. position_ids_expanded = position_ids[:, None, :].float()
  109. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  110. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  111. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  112. emb = torch.cat((freqs, freqs), dim=-1)
  113. cos = emb.cos() * self.attention_scaling
  114. sin = emb.sin() * self.attention_scaling
  115. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  116. def rotate_half(x):
  117. """Rotates half the hidden dims of the input."""
  118. x1 = x[..., : x.shape[-1] // 2]
  119. x2 = x[..., x.shape[-1] // 2 :]
  120. return torch.cat((-x2, x1), dim=-1)
  121. @use_kernel_func_from_hub("rotary_pos_emb")
  122. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  123. """Applies Rotary Position Embedding to the query and key tensors.
  124. Args:
  125. q (`torch.Tensor`): The query tensor.
  126. k (`torch.Tensor`): The key tensor.
  127. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  128. sin (`torch.Tensor`): The sine part of the rotary embedding.
  129. unsqueeze_dim (`int`, *optional*, defaults to 1):
  130. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  131. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  132. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  133. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  134. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  135. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  136. Returns:
  137. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  138. """
  139. cos = cos.unsqueeze(unsqueeze_dim)
  140. sin = sin.unsqueeze(unsqueeze_dim)
  141. q_embed = (q * cos) + (rotate_half(q) * sin)
  142. k_embed = (k * cos) + (rotate_half(k) * sin)
  143. return q_embed, k_embed
  144. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  145. """
  146. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  147. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  148. """
  149. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  150. if n_rep == 1:
  151. return hidden_states
  152. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  153. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  154. def eager_attention_forward(
  155. module: nn.Module,
  156. query: torch.Tensor,
  157. key: torch.Tensor,
  158. value: torch.Tensor,
  159. attention_mask: torch.Tensor | None,
  160. scaling: float,
  161. dropout: float = 0.0,
  162. **kwargs: Unpack[TransformersKwargs],
  163. ):
  164. key_states = repeat_kv(key, module.num_key_value_groups)
  165. value_states = repeat_kv(value, module.num_key_value_groups)
  166. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  167. if attention_mask is not None:
  168. attn_weights = attn_weights + attention_mask
  169. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  170. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  171. attn_output = torch.matmul(attn_weights, value_states)
  172. attn_output = attn_output.transpose(1, 2).contiguous()
  173. return attn_output, attn_weights
  174. @use_kernelized_func(apply_rotary_pos_emb)
  175. class Dots1Attention(nn.Module):
  176. """Multi-headed attention from 'Attention Is All You Need' paper"""
  177. def __init__(self, config: Dots1Config, layer_idx: int):
  178. super().__init__()
  179. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  180. self.config = config
  181. self.layer_idx = layer_idx
  182. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  183. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  184. self.scaling = self.head_dim**-0.5
  185. self.attention_dropout = config.attention_dropout
  186. self.is_causal = True
  187. self.q_proj = nn.Linear(
  188. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  189. )
  190. self.k_proj = nn.Linear(
  191. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  192. )
  193. self.v_proj = nn.Linear(
  194. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  195. )
  196. self.o_proj = nn.Linear(
  197. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  198. )
  199. self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
  200. self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
  201. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  202. def forward(
  203. self,
  204. hidden_states: torch.Tensor,
  205. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  206. attention_mask: torch.Tensor | None,
  207. past_key_values: Cache | None = None,
  208. **kwargs: Unpack[FlashAttentionKwargs],
  209. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  210. input_shape = hidden_states.shape[:-1]
  211. hidden_shape = (*input_shape, -1, self.head_dim)
  212. query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  213. key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  214. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  215. cos, sin = position_embeddings
  216. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  217. if past_key_values is not None:
  218. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  219. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  220. self.config._attn_implementation, eager_attention_forward
  221. )
  222. attn_output, attn_weights = attention_interface(
  223. self,
  224. query_states,
  225. key_states,
  226. value_states,
  227. attention_mask,
  228. dropout=0.0 if not self.training else self.attention_dropout,
  229. scaling=self.scaling,
  230. sliding_window=self.sliding_window, # diff with Llama
  231. **kwargs,
  232. )
  233. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  234. attn_output = self.o_proj(attn_output)
  235. return attn_output, attn_weights
  236. class Dots1MLP(nn.Module):
  237. def __init__(self, config, intermediate_size=None):
  238. super().__init__()
  239. self.config = config
  240. self.hidden_size = config.hidden_size
  241. self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
  242. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  243. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  244. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  245. self.act_fn = ACT2FN[config.hidden_act]
  246. def forward(self, x):
  247. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  248. return down_proj
  249. class Dots1TopkRouter(nn.Module):
  250. def __init__(self, config):
  251. super().__init__()
  252. self.config = config
  253. self.n_routed_experts = config.n_routed_experts
  254. self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
  255. self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts))
  256. def forward(self, hidden_states):
  257. hidden_states = hidden_states.view(-1, self.config.hidden_size)
  258. router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
  259. return router_logits
  260. @use_experts_implementation
  261. class Dots1NaiveMoe(nn.Module):
  262. """Collection of expert weights stored as 3D tensors."""
  263. def __init__(self, config):
  264. super().__init__()
  265. self.num_experts = config.num_local_experts
  266. self.hidden_dim = config.hidden_size
  267. self.intermediate_dim = config.moe_intermediate_size
  268. self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
  269. self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
  270. self.act_fn = ACT2FN[config.hidden_act]
  271. def forward(
  272. self,
  273. hidden_states: torch.Tensor,
  274. top_k_index: torch.Tensor,
  275. top_k_weights: torch.Tensor,
  276. ) -> torch.Tensor:
  277. final_hidden_states = torch.zeros_like(hidden_states)
  278. with torch.no_grad():
  279. expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
  280. expert_mask = expert_mask.permute(2, 1, 0)
  281. expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
  282. for expert_idx in expert_hit:
  283. expert_idx = expert_idx[0]
  284. if expert_idx == self.num_experts:
  285. continue
  286. top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
  287. current_state = hidden_states[token_idx]
  288. gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
  289. current_hidden_states = self.act_fn(gate) * up
  290. current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
  291. current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
  292. final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
  293. return final_hidden_states
  294. class Dots1MoE(nn.Module):
  295. """
  296. A mixed expert module containing shared experts.
  297. """
  298. def __init__(self, config):
  299. super().__init__()
  300. self.config = config
  301. self.experts = Dots1NaiveMoe(config)
  302. self.gate = Dots1TopkRouter(config)
  303. self.shared_experts = Dots1MLP(
  304. config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
  305. )
  306. self.n_routed_experts = config.n_routed_experts
  307. self.n_group = config.n_group
  308. self.topk_group = config.topk_group
  309. self.norm_topk_prob = config.norm_topk_prob
  310. self.routed_scaling_factor = config.routed_scaling_factor
  311. self.top_k = config.num_experts_per_tok
  312. def route_tokens_to_experts(self, router_logits):
  313. router_logits = router_logits.sigmoid() # main diff with deepseekv3
  314. router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
  315. group_scores = (
  316. router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
  317. .topk(2, dim=-1)[0]
  318. .sum(dim=-1)
  319. )
  320. group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
  321. group_mask = torch.zeros_like(group_scores)
  322. group_mask.scatter_(1, group_idx, 1)
  323. score_mask = (
  324. group_mask.unsqueeze(-1)
  325. .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
  326. .reshape(-1, self.n_routed_experts)
  327. )
  328. scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
  329. topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
  330. topk_weights = router_logits.gather(1, topk_indices)
  331. if self.norm_topk_prob:
  332. denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
  333. topk_weights /= denominator
  334. topk_weights = topk_weights * self.routed_scaling_factor
  335. return topk_indices, topk_weights
  336. def forward(self, hidden_states):
  337. residuals = hidden_states
  338. orig_shape = hidden_states.shape
  339. router_logits = self.gate(hidden_states)
  340. topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
  341. hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
  342. hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
  343. hidden_states = hidden_states + self.shared_experts(residuals)
  344. return hidden_states
  345. class Dots1DecoderLayer(GradientCheckpointingLayer):
  346. def __init__(self, config: Dots1Config, layer_idx: int):
  347. super().__init__()
  348. self.hidden_size = config.hidden_size
  349. self.self_attn = Dots1Attention(config=config, layer_idx=layer_idx)
  350. if layer_idx >= config.first_k_dense_replace:
  351. self.mlp = Dots1MoE(config)
  352. else:
  353. self.mlp = Dots1MLP(config)
  354. self.input_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  355. self.post_attention_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  356. def forward(
  357. self,
  358. hidden_states: torch.Tensor,
  359. attention_mask: torch.Tensor | None = None,
  360. position_ids: torch.LongTensor | None = None,
  361. past_key_values: Cache | None = None,
  362. use_cache: bool | None = False,
  363. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  364. **kwargs: Unpack[TransformersKwargs],
  365. ) -> torch.Tensor:
  366. residual = hidden_states
  367. hidden_states = self.input_layernorm(hidden_states)
  368. # Self Attention
  369. hidden_states, _ = self.self_attn(
  370. hidden_states=hidden_states,
  371. attention_mask=attention_mask,
  372. position_ids=position_ids,
  373. past_key_values=past_key_values,
  374. use_cache=use_cache,
  375. position_embeddings=position_embeddings,
  376. **kwargs,
  377. )
  378. hidden_states = residual + hidden_states
  379. # Fully Connected
  380. residual = hidden_states
  381. hidden_states = self.post_attention_layernorm(hidden_states)
  382. hidden_states = self.mlp(hidden_states)
  383. hidden_states = residual + hidden_states
  384. return hidden_states
  385. @auto_docstring
  386. class Dots1PreTrainedModel(PreTrainedModel):
  387. config: Dots1Config
  388. base_model_prefix = "model"
  389. supports_gradient_checkpointing = True
  390. _no_split_modules = ["Dots1DecoderLayer"]
  391. _skip_keys_device_placement = ["past_key_values"]
  392. _supports_flash_attn = True
  393. _supports_sdpa = True
  394. _supports_flex_attn = True
  395. _can_compile_fullgraph = True
  396. _supports_attention_backend = True
  397. _can_record_outputs = {
  398. "hidden_states": Dots1DecoderLayer,
  399. "attentions": Dots1Attention,
  400. }
  401. _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
  402. _keys_to_ignore_on_load_unexpected = None
  403. @torch.no_grad()
  404. def _init_weights(self, module):
  405. super()._init_weights(module)
  406. if isinstance(module, Dots1TopkRouter):
  407. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  408. init.zeros_(module.e_score_correction_bias)
  409. elif isinstance(module, Dots1NaiveMoe):
  410. init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
  411. init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
  412. @auto_docstring
  413. class Dots1Model(Dots1PreTrainedModel):
  414. def __init__(self, config: Dots1Config):
  415. super().__init__(config)
  416. self.padding_idx = config.pad_token_id
  417. self.vocab_size = config.vocab_size
  418. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  419. self.layers = nn.ModuleList(
  420. [Dots1DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  421. )
  422. self.norm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  423. self.rotary_emb = Dots1RotaryEmbedding(config=config)
  424. self.gradient_checkpointing = False
  425. self.has_sliding_layers = "sliding_attention" in self.config.layer_types
  426. # Initialize weights and apply final processing
  427. self.post_init()
  428. @merge_with_config_defaults
  429. @capture_outputs
  430. @auto_docstring
  431. def forward(
  432. self,
  433. input_ids: torch.LongTensor | None = None,
  434. attention_mask: torch.Tensor | None = None,
  435. position_ids: torch.LongTensor | None = None,
  436. past_key_values: Cache | None = None,
  437. inputs_embeds: torch.FloatTensor | None = None,
  438. use_cache: bool | None = None,
  439. **kwargs: Unpack[TransformersKwargs],
  440. ) -> BaseModelOutputWithPast:
  441. if (input_ids is None) ^ (inputs_embeds is not None):
  442. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  443. if inputs_embeds is None:
  444. inputs_embeds = self.embed_tokens(input_ids)
  445. if use_cache and past_key_values is None:
  446. past_key_values = DynamicCache(config=self.config)
  447. if position_ids is None:
  448. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  449. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  450. position_ids = position_ids.unsqueeze(0)
  451. # It may already have been prepared by e.g. `generate`
  452. if not isinstance(causal_mask_mapping := attention_mask, dict):
  453. # Prepare mask arguments
  454. mask_kwargs = {
  455. "config": self.config,
  456. "inputs_embeds": inputs_embeds,
  457. "attention_mask": attention_mask,
  458. "past_key_values": past_key_values,
  459. "position_ids": position_ids,
  460. }
  461. # Create the masks
  462. causal_mask_mapping = {
  463. "full_attention": create_causal_mask(**mask_kwargs),
  464. }
  465. # The sliding window alternating layers are not always activated depending on the config
  466. if self.has_sliding_layers:
  467. causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
  468. hidden_states = inputs_embeds
  469. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  470. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  471. hidden_states = decoder_layer(
  472. hidden_states,
  473. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  474. position_embeddings=position_embeddings,
  475. position_ids=position_ids,
  476. past_key_values=past_key_values,
  477. use_cache=use_cache,
  478. **kwargs,
  479. )
  480. hidden_states = self.norm(hidden_states)
  481. return BaseModelOutputWithPast(
  482. last_hidden_state=hidden_states,
  483. past_key_values=past_key_values if use_cache else None,
  484. )
  485. @auto_docstring
  486. class Dots1ForCausalLM(Dots1PreTrainedModel, GenerationMixin):
  487. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  488. _tp_plan = {"lm_head": "colwise_gather_output"}
  489. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  490. def __init__(self, config):
  491. super().__init__(config)
  492. self.model = Dots1Model(config)
  493. self.vocab_size = config.vocab_size
  494. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  495. # Initialize weights and apply final processing
  496. self.post_init()
  497. @can_return_tuple
  498. @auto_docstring
  499. def forward(
  500. self,
  501. input_ids: torch.LongTensor | None = None,
  502. attention_mask: torch.Tensor | None = None,
  503. position_ids: torch.LongTensor | None = None,
  504. past_key_values: Cache | None = None,
  505. inputs_embeds: torch.FloatTensor | None = None,
  506. labels: torch.LongTensor | None = None,
  507. use_cache: bool | None = None,
  508. logits_to_keep: int | torch.Tensor = 0,
  509. **kwargs: Unpack[TransformersKwargs],
  510. ) -> CausalLMOutputWithPast:
  511. r"""
  512. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  513. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  514. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  515. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  516. Example:
  517. ```python
  518. >>> from transformers import AutoTokenizer, Dots1ForCausalLM
  519. >>> model = Dots1ForCausalLM.from_pretrained("rednote-hilab/dots1.llm1.inst")
  520. >>> tokenizer = AutoTokenizer.from_pretrained("rednote-hilab/dots1.llm1.inst")
  521. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  522. >>> inputs = tokenizer(prompt, return_tensors="pt")
  523. >>> # Generate
  524. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  525. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  526. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  527. ```"""
  528. outputs: BaseModelOutputWithPast = self.model(
  529. input_ids=input_ids,
  530. attention_mask=attention_mask,
  531. position_ids=position_ids,
  532. past_key_values=past_key_values,
  533. inputs_embeds=inputs_embeds,
  534. use_cache=use_cache,
  535. **kwargs,
  536. )
  537. hidden_states = outputs.last_hidden_state
  538. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  539. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  540. logits = self.lm_head(hidden_states[:, slice_indices, :])
  541. loss = None
  542. if labels is not None:
  543. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  544. return CausalLMOutputWithPast(
  545. loss=loss,
  546. logits=logits,
  547. past_key_values=outputs.past_key_values,
  548. hidden_states=outputs.hidden_states,
  549. attentions=outputs.attentions,
  550. )
# Names exported by `from <module> import *`.
__all__ = ["Dots1PreTrainedModel", "Dots1Model", "Dots1ForCausalLM"]