modular_dots1.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from huggingface_hub.dataclasses import strict
  16. from ...configuration_utils import PreTrainedConfig
  17. from ...modeling_outputs import CausalLMOutputWithPast
  18. from ...modeling_rope_utils import RopeParameters
  19. from ...processing_utils import Unpack
  20. from ...utils import auto_docstring, logging
  21. from ..deepseek_v3.modeling_deepseek_v3 import (
  22. DeepseekV3DecoderLayer,
  23. DeepseekV3MLP,
  24. DeepseekV3MoE,
  25. DeepseekV3PreTrainedModel,
  26. DeepseekV3TopkRouter,
  27. )
  28. from ..qwen3.modeling_qwen3 import (
  29. Qwen3Attention,
  30. Qwen3ForCausalLM,
  31. Qwen3Model,
  32. Qwen3RMSNorm,
  33. Qwen3RotaryEmbedding,
  34. TransformersKwargs,
  35. )
  36. logger = logging.get_logger(__name__)
  37. @auto_docstring(checkpoint="rednote-hilab/dots.llm1.base")
  38. @strict
  39. class Dots1Config(PreTrainedConfig):
  40. r"""
  41. n_group (`int`, *optional*, defaults to 1):
  42. Number of groups for routed experts.
  43. first_k_dense_replace (`int`, *optional*, defaults to 0):
  44. Number of dense layers at the beginning of the model before the first MoE layer.
  45. Examples:
  46. ```python
  47. >>> from transformers import Dots1Model, Dots1Config
  48. >>> # Initializing a Dots1 style configuration
  49. >>> configuration = Dots1Config()
  50. >>> # Accessing the model configuration
  51. >>> configuration = model.config
  52. ```
  53. """
  54. model_type = "dots1"
  55. keys_to_ignore_at_inference = ["past_key_values"]
  56. base_model_tp_plan = {
  57. "layers.*.self_attn.q_proj": "colwise",
  58. "layers.*.self_attn.k_proj": "colwise",
  59. "layers.*.self_attn.v_proj": "colwise",
  60. "layers.*.self_attn.o_proj": "rowwise",
  61. "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
  62. "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
  63. "layers.*.mlp.experts.gate_up_proj": "packed_colwise",
  64. "layers.*.mlp.experts.down_proj": "rowwise",
  65. "layers.*.mlp.experts": "moe_tp_experts",
  66. "layers.*.mlp.shared_experts.gate_proj": "colwise",
  67. "layers.*.mlp.shared_experts.up_proj": "colwise",
  68. "layers.*.mlp.shared_experts.down_proj": "rowwise",
  69. "layers.*.mlp.gate_proj": "colwise",
  70. "layers.*.mlp.up_proj": "colwise",
  71. "layers.*.mlp.down_proj": "rowwise",
  72. }
  73. base_model_pp_plan = {
  74. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  75. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  76. "norm": (["hidden_states"], ["hidden_states"]),
  77. }
  78. attribute_map = {
  79. "num_local_experts": "n_routed_experts",
  80. }
  81. vocab_size: int = 152064
  82. hidden_size: int = 4608
  83. intermediate_size: int = 10944
  84. moe_intermediate_size: int = 1408
  85. num_hidden_layers: int = 62
  86. num_attention_heads: int = 32
  87. num_key_value_heads: int | None = 32
  88. n_shared_experts: int | None = None
  89. n_routed_experts: int | None = None
  90. n_group: int | None = 1
  91. topk_group: int | None = 1
  92. num_experts_per_tok: int | None = None
  93. first_k_dense_replace: int | None = 0
  94. norm_topk_prob: bool | None = False
  95. hidden_act: str = "silu"
  96. max_position_embeddings: int = 2048
  97. initializer_range: float = 0.02
  98. rms_norm_eps: float = 1e-6
  99. use_cache: bool = True
  100. tie_word_embeddings: bool = False
  101. rope_parameters: RopeParameters | dict | None = None
  102. attention_bias: bool = False
  103. attention_dropout: float | int | None = 0.0
  104. routed_scaling_factor: float = 1.0
  105. sliding_window: int | None = 4096
  106. max_window_layers: int | None = 62
  107. layer_types: list[str] | None = None
  108. pad_token_id: int | None = None
  109. bos_token_id: int | None = None
  110. eos_token_id: int | list[int] | None = None
  111. def __post_init__(self, **kwargs):
  112. if self.num_key_value_heads is None:
  113. self.num_key_value_heads = self.num_attention_heads
  114. if self.layer_types is None:
  115. self.layer_types = [
  116. "sliding_attention"
  117. if self.sliding_window is not None and i >= self.max_window_layers
  118. else "full_attention"
  119. for i in range(self.num_hidden_layers)
  120. ]
  121. super().__post_init__(**kwargs)
  122. class Dots1RMSNorm(Qwen3RMSNorm):
  123. pass
  124. class Dots1RotaryEmbedding(Qwen3RotaryEmbedding):
  125. pass
  126. class Dots1Attention(Qwen3Attention):
  127. pass
  128. class Dots1MLP(DeepseekV3MLP):
  129. pass
  130. class Dots1TopkRouter(DeepseekV3TopkRouter):
  131. pass
  132. class Dots1MoE(DeepseekV3MoE):
  133. def route_tokens_to_experts(self, router_logits):
  134. router_logits = router_logits.sigmoid() # main diff with deepseekv3
  135. router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
  136. group_scores = (
  137. router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
  138. .topk(2, dim=-1)[0]
  139. .sum(dim=-1)
  140. )
  141. group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
  142. group_mask = torch.zeros_like(group_scores)
  143. group_mask.scatter_(1, group_idx, 1)
  144. score_mask = (
  145. group_mask.unsqueeze(-1)
  146. .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
  147. .reshape(-1, self.n_routed_experts)
  148. )
  149. scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
  150. topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
  151. topk_weights = router_logits.gather(1, topk_indices)
  152. if self.norm_topk_prob:
  153. denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
  154. topk_weights /= denominator
  155. topk_weights = topk_weights * self.routed_scaling_factor
  156. return topk_indices, topk_weights
  157. class Dots1DecoderLayer(DeepseekV3DecoderLayer):
  158. pass
  159. class Dots1PreTrainedModel(DeepseekV3PreTrainedModel):
  160. _keys_to_ignore_on_load_unexpected = None
  161. class Dots1Model(Qwen3Model):
  162. pass
  163. class Dots1ForCausalLM(Qwen3ForCausalLM):
  164. def forward(
  165. self,
  166. **super_kwargs: Unpack[TransformersKwargs],
  167. ) -> CausalLMOutputWithPast:
  168. r"""
  169. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  170. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  171. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  172. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  173. Example:
  174. ```python
  175. >>> from transformers import AutoTokenizer, Dots1ForCausalLM
  176. >>> model = Dots1ForCausalLM.from_pretrained("rednote-hilab/dots1.llm1.inst")
  177. >>> tokenizer = AutoTokenizer.from_pretrained("rednote-hilab/dots1.llm1.inst")
  178. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  179. >>> inputs = tokenizer(prompt, return_tensors="pt")
  180. >>> # Generate
  181. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  182. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  183. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  184. ```"""
  185. return super().forward(**super_kwargs)
  186. __all__ = [
  187. "Dots1Config",
  188. "Dots1PreTrainedModel",
  189. "Dots1Model",
  190. "Dots1ForCausalLM",
  191. ]