modeling_olmoe.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/olmoe/modular_olmoe.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_olmoe.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Licensed under the Apache License, Version 2.0 (the "License");
  8. # you may not use this file except in compliance with the License.
  9. # You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing, software
  14. # distributed under the License is distributed on an "AS IS" BASIS,
  15. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. # See the License for the specific language governing permissions and
  17. # limitations under the License.
  18. from collections.abc import Callable
  19. from typing import Optional
  20. import torch
  21. import torch.nn.functional as F
  22. from torch import nn
  23. from ... import initialization as init
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...generation import GenerationMixin
  27. from ...integrations import (
  28. use_experts_implementation,
  29. use_kernel_forward_from_hub,
  30. use_kernel_func_from_hub,
  31. use_kernelized_func,
  32. )
  33. from ...masking_utils import create_causal_mask
  34. from ...modeling_layers import GradientCheckpointingLayer
  35. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  36. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  37. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  38. from ...processing_utils import Unpack
  39. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  40. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  41. from ...utils.output_capturing import OutputRecorder, capture_outputs
  42. from .configuration_olmoe import OlmoeConfig
  43. @use_kernel_forward_from_hub("RMSNorm")
  44. class OlmoeRMSNorm(nn.Module):
  45. def __init__(self, hidden_size, eps=1e-5) -> None:
  46. """
  47. OlmoeRMSNorm is equivalent to T5LayerNorm
  48. """
  49. super().__init__()
  50. self.weight = nn.Parameter(torch.ones(hidden_size))
  51. self.variance_epsilon = eps
  52. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  53. input_dtype = hidden_states.dtype
  54. hidden_states = hidden_states.to(torch.float32)
  55. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  56. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  57. return self.weight * hidden_states.to(input_dtype)
  58. def extra_repr(self):
  59. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class OlmoeRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) module that turns position ids into `(cos, sin)` tables.

    The inverse frequencies come from `compute_default_rope_parameters` when
    `config.rope_parameters["rope_type"] == "default"`, otherwise from `ROPE_INIT_FUNCTIONS[rope_type]`.
    The chosen function also returns `attention_scaling`, which multiplies both cos and sin in `forward`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: OlmoeConfig, device=None):
        super().__init__()
        # NOTE(review): neither field is read inside this class; presumably consumed by `dynamic_rope_update`
        # for dynamic RoPE variants — confirm in modeling_rope_utils.
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # persistent=False: neither buffer is stored in the state dict; both are rebuilt from the config here.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Untouched copy of the initial frequencies (NOTE(review): presumably used to restore `inv_freq` after a
        # dynamic update — confirm in `dynamic_rope_update`).
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: OlmoeConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        # Rotary width: the config's `head_dim` when present and truthy, else hidden_size // num_attention_heads.
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies: 1 / base ** (j / dim) for j = 0, 2, 4, ... < dim
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # In this body `x` is read only for its device and dtype; `position_ids` is indexed as (batch, seq).
        # inv_freq: (n_freq,) -> (batch, n_freq, 1)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        # (batch, seq) -> (batch, 1, seq)
        position_ids_expanded = position_ids[:, None, :].float()

        # Device type handed to autocast; "mps" and non-string device types fall back to "cpu".
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            # (batch, n_freq, 1) @ (batch, 1, seq) -> (batch, n_freq, seq) -> (batch, seq, n_freq)
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the angles so they cover both halves of the last axis: (batch, seq, 2 * n_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        # Tables are computed in float32 and returned in the dtype of `x`.
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  113. class OlmoeMLP(nn.Module):
  114. def __init__(self, config):
  115. super().__init__()
  116. self.config = config
  117. self.hidden_size = config.hidden_size
  118. self.intermediate_size = config.intermediate_size
  119. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  120. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  121. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  122. self.act_fn = ACT2FN[config.hidden_act]
  123. def forward(self, x):
  124. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  125. return down_proj
  126. def rotate_half(x):
  127. """Rotates half the hidden dims of the input."""
  128. x1 = x[..., : x.shape[-1] // 2]
  129. x2 = x[..., x.shape[-1] // 2 :]
  130. return torch.cat((-x2, x1), dim=-1)
  131. @use_kernel_func_from_hub("rotary_pos_emb")
  132. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  133. """Applies Rotary Position Embedding to the query and key tensors.
  134. Args:
  135. q (`torch.Tensor`): The query tensor.
  136. k (`torch.Tensor`): The key tensor.
  137. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  138. sin (`torch.Tensor`): The sine part of the rotary embedding.
  139. unsqueeze_dim (`int`, *optional*, defaults to 1):
  140. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  141. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  142. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  143. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  144. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  145. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  146. Returns:
  147. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  148. """
  149. cos = cos.unsqueeze(unsqueeze_dim)
  150. sin = sin.unsqueeze(unsqueeze_dim)
  151. q_embed = (q * cos) + (rotate_half(q) * sin)
  152. k_embed = (k * cos) + (rotate_half(k) * sin)
  153. return q_embed, k_embed
  154. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  155. """
  156. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  157. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  158. """
  159. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  160. if n_rep == 1:
  161. return hidden_states
  162. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  163. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  164. def eager_attention_forward(
  165. module: nn.Module,
  166. query: torch.Tensor,
  167. key: torch.Tensor,
  168. value: torch.Tensor,
  169. attention_mask: torch.Tensor | None,
  170. scaling: float,
  171. dropout: float = 0.0,
  172. **kwargs: Unpack[TransformersKwargs],
  173. ):
  174. key_states = repeat_kv(key, module.num_key_value_groups)
  175. value_states = repeat_kv(value, module.num_key_value_groups)
  176. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  177. if attention_mask is not None:
  178. attn_weights = attn_weights + attention_mask
  179. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  180. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  181. attn_output = torch.matmul(attn_weights, value_states)
  182. attn_output = attn_output.transpose(1, 2).contiguous()
  183. return attn_output, attn_weights
  184. @use_kernelized_func(apply_rotary_pos_emb)
  185. class OlmoeAttention(nn.Module):
  186. """Multi-headed attention from 'Attention Is All You Need' paper"""
  187. def __init__(self, config: OlmoeConfig, layer_idx: int | None = None):
  188. super().__init__()
  189. self.config = config
  190. self.layer_idx = layer_idx
  191. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  192. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  193. self.scaling = self.head_dim**-0.5
  194. self.attention_dropout = config.attention_dropout
  195. self.is_causal = True
  196. self.q_proj = nn.Linear(
  197. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  198. )
  199. self.k_proj = nn.Linear(
  200. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  201. )
  202. self.v_proj = nn.Linear(
  203. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  204. )
  205. self.o_proj = nn.Linear(
  206. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  207. )
  208. self.q_norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  209. self.k_norm = OlmoeRMSNorm(
  210. (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps
  211. )
  212. def forward(
  213. self,
  214. hidden_states: torch.Tensor,
  215. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  216. attention_mask: torch.Tensor | None,
  217. past_key_values: Cache | None = None,
  218. **kwargs: Unpack[TransformersKwargs],
  219. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  220. input_shape = hidden_states.shape[:-1]
  221. hidden_shape = (*input_shape, -1, self.head_dim)
  222. query_states = self.q_norm(self.q_proj(hidden_states))
  223. key_states = self.k_norm(self.k_proj(hidden_states))
  224. value_states = self.v_proj(hidden_states)
  225. if self.config.clip_qkv is not None: # Diff with llama
  226. query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  227. key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  228. value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  229. query_states = query_states.view(*hidden_shape).transpose(1, 2)
  230. key_states = key_states.view(*hidden_shape).transpose(1, 2)
  231. value_states = value_states.view(*hidden_shape).transpose(1, 2)
  232. cos, sin = position_embeddings
  233. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  234. if past_key_values is not None:
  235. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  236. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  237. self.config._attn_implementation, eager_attention_forward
  238. )
  239. attn_output, attn_weights = attention_interface(
  240. self,
  241. query_states,
  242. key_states,
  243. value_states,
  244. attention_mask,
  245. dropout=0.0 if not self.training else self.attention_dropout,
  246. scaling=self.scaling,
  247. sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
  248. **kwargs,
  249. )
  250. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  251. attn_output = self.o_proj(attn_output)
  252. return attn_output, attn_weights
@use_experts_implementation
class OlmoeExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors.

    `gate_up_proj[e]` has shape `(2 * intermediate_dim, hidden_dim)` and fuses expert `e`'s gate projection
    (first half of the output, per `.chunk(2, dim=-1)` in `forward`) and up projection (second half).
    `down_proj[e]` has shape `(hidden_dim, intermediate_dim)` and maps back to `hidden_dim`. Both are allocated
    with `torch.empty`, so their values must come from weight initialization or a loaded checkpoint.
    """

    def __init__(self, config: OlmoeConfig):
        super().__init__()
        # NOTE(review): the router and OlmoeForCausalLM read `config.num_experts`, while this reads
        # `config.num_local_experts`; assumes the config exposes both with the same value — confirm.
        self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.intermediate_size
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Sum each token's selected expert outputs, weighted by the router scores.

        NOTE(review): `@use_experts_implementation` may dispatch to an alternative implementation; this body is
        the reference per-expert loop.

        Args:
            hidden_states: token activations whose rows are indexed per token, presumably `(num_tokens, hidden_dim)`
                (OlmoeSparseMoeBlock passes a flattened view) — TODO confirm for other callers.
            top_k_index: selected expert ids, indexed as `[token, slot]`.
            top_k_weights: routing weights, indexed as `[token, slot]` like `top_k_index`.

        Returns:
            Tensor with the shape and dtype of `hidden_states`.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # one_hot -> [token, slot, expert]; permute -> [expert, slot, token]
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Ids (as rows of shape (1,)) of the experts chosen by at least one (token, slot); only these run.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            # NOTE(review): expert_hit indexes axis 0 of a one_hot with num_classes=num_experts, so expert_idx is
            # always < num_experts and this branch looks unreachable here; presumably a guard for sentinel ids
            # (e.g. expert-parallel setups) — confirm before removing.
            if expert_idx == self.num_experts:
                continue
            # (slot, token) pairs that routed to this expert.
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            # Scatter-add into the token rows; a token routed to several experts accumulates their weighted sum.
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states
  287. class OlmoeTopKRouter(nn.Module):
  288. def __init__(self, config):
  289. super().__init__()
  290. self.top_k = config.num_experts_per_tok
  291. self.num_experts = config.num_experts
  292. self.norm_topk_prob = config.norm_topk_prob
  293. self.hidden_dim = config.hidden_size
  294. self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
  295. def forward(self, hidden_states):
  296. hidden_states = hidden_states.reshape(-1, self.hidden_dim)
  297. router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
  298. router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
  299. router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
  300. if self.norm_topk_prob:
  301. router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
  302. router_top_value = router_top_value.to(router_logits.dtype)
  303. router_scores = router_top_value
  304. return router_logits, router_scores, router_indices
  305. class OlmoeSparseMoeBlock(nn.Module):
  306. def __init__(self, config):
  307. super().__init__()
  308. self.gate = OlmoeTopKRouter(config)
  309. self.experts = OlmoeExperts(config)
  310. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  311. batch_size, sequence_length, hidden_dim = hidden_states.shape
  312. hidden_states = hidden_states.view(-1, hidden_dim)
  313. _, top_k_weights, top_k_index = self.gate(hidden_states)
  314. final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
  315. batch_size, sequence_length, hidden_dim
  316. )
  317. return final_hidden_states
  318. class OlmoeDecoderLayer(GradientCheckpointingLayer):
  319. def __init__(self, config: OlmoeConfig, layer_idx: int):
  320. super().__init__()
  321. self.hidden_size = config.hidden_size
  322. self.self_attn = OlmoeAttention(config=config, layer_idx=layer_idx)
  323. self.mlp = OlmoeSparseMoeBlock(config)
  324. self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  325. self.post_attention_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  326. def forward(
  327. self,
  328. hidden_states: torch.Tensor,
  329. attention_mask: torch.Tensor | None = None,
  330. position_ids: torch.LongTensor | None = None,
  331. past_key_values: Cache | None = None,
  332. use_cache: bool | None = False,
  333. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  334. **kwargs: Unpack[TransformersKwargs],
  335. ) -> torch.Tensor:
  336. residual = hidden_states
  337. hidden_states = self.input_layernorm(hidden_states)
  338. # Self Attention
  339. hidden_states, _ = self.self_attn(
  340. hidden_states=hidden_states,
  341. attention_mask=attention_mask,
  342. position_ids=position_ids,
  343. past_key_values=past_key_values,
  344. use_cache=use_cache,
  345. position_embeddings=position_embeddings,
  346. **kwargs,
  347. )
  348. hidden_states = residual + hidden_states
  349. # Fully Connected
  350. residual = hidden_states
  351. hidden_states = self.post_attention_layernorm(hidden_states)
  352. hidden_states = self.mlp(hidden_states)
  353. hidden_states = residual + hidden_states
  354. return hidden_states
@auto_docstring
class OlmoePreTrainedModel(PreTrainedModel):
    # Shared base for the OLMoE models: framework capability flags plus the custom weight init below.
    config: OlmoeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # NOTE: consumed by PreTrainedModel; presumably the module class that device-map placement must not split.
    _no_split_modules = ["OlmoeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    # Output name -> module whose outputs are recorded. For the router, index 0 of OlmoeTopKRouter.forward's
    # returned tuple is recorded.
    _can_record_outputs = {
        "router_logits": OutputRecorder(OlmoeTopKRouter, index=0),
        "hidden_states": OlmoeDecoderLayer,
        "attentions": OlmoeAttention,
    }
    _supports_attention_backend = True

    @torch.no_grad()
    def _init_weights(self, module):
        # Run the generic initialization first, then draw the MoE-specific parameters (created with
        # torch.empty / torch.zeros in their modules) from N(0, initializer_range).
        PreTrainedModel._init_weights(self, module)
        if isinstance(module, OlmoeExperts):
            init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
            init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, OlmoeTopKRouter):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
@auto_docstring
class OlmoeModel(OlmoePreTrainedModel):
    # Bare OLMoE decoder stack: token embeddings -> decoder layers -> final RMSNorm (no LM head).
    def __init__(self, config: OlmoeConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [OlmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # A single rotary module; its cos/sin are computed once per forward and passed to every layer.
        self.rotary_emb = OlmoeRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one of input_ids / inputs_embeds must be given (the XOR is true when both or neither are).
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # Create a cache on demand when caching is requested but none was passed in.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Default positions continue after the tokens already held by the cache; shape (1, seq).
        if position_ids is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        causal_mask = create_causal_mask(  # diff with mixtral: no sliding
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        # NOTE(review): per-layer hidden states / attentions / router logits are not assembled here; presumably
        # `@capture_outputs` fills them from `_can_record_outputs` — confirm.
        return MoeModelOutputWithPast(  # only diff with Mistral is the output type, we need MoE
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
  441. def load_balancing_loss_func(
  442. gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
  443. num_experts: int | None = None,
  444. top_k=2,
  445. attention_mask: torch.Tensor | None = None,
  446. ) -> torch.Tensor | int:
  447. r"""
  448. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  449. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  450. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  451. experts is too unbalanced.
  452. Args:
  453. gate_logits:
  454. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  455. shape [batch_size X sequence_length, num_experts].
  456. num_experts:
  457. Number of experts
  458. top_k:
  459. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  460. parameter.
  461. attention_mask (`torch.Tensor`, *optional*):
  462. The attention_mask used in forward function
  463. shape [batch_size X sequence_length] if not None.
  464. Returns:
  465. The auxiliary loss.
  466. """
  467. if gate_logits is None or not isinstance(gate_logits, tuple):
  468. return 0
  469. if isinstance(gate_logits, tuple):
  470. compute_device = gate_logits[0].device
  471. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  472. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  473. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  474. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  475. if attention_mask is None:
  476. # Compute the percentage of tokens routed to each experts
  477. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  478. # Compute the average probability of routing to these experts
  479. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  480. else:
  481. batch_size, sequence_length = attention_mask.shape
  482. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  483. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  484. expert_attention_mask = (
  485. attention_mask[None, :, :, None, None]
  486. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  487. .reshape(-1, top_k, num_experts)
  488. .to(compute_device)
  489. )
  490. # Compute the percentage of tokens routed to each experts
  491. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  492. expert_attention_mask, dim=0
  493. )
  494. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  495. router_per_expert_attention_mask = (
  496. attention_mask[None, :, :, None]
  497. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  498. .reshape(-1, num_experts)
  499. .to(compute_device)
  500. )
  501. # Compute the average probability of routing to these experts
  502. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  503. router_per_expert_attention_mask, dim=0
  504. )
  505. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
  506. return overall_loss * num_experts
@auto_docstring
class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
    # OLMoE decoder plus a language-modeling head. When router logits are requested, the router
    # load-balancing loss is also computed and, if labels are given, added to the LM loss.
    # Maps lm_head.weight to the input embedding weight for weight tying (NOTE(review): whether tying is applied
    # presumably depends on the config — confirm).
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = OlmoeModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Settings for the auxiliary load-balancing loss computed in forward.
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_router_logits: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OlmoeForCausalLM

        >>> model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
        >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
        ```
        """
        # Fall back to the config when the caller does not say whether router logits are wanted.
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        # Decoder outputs: a MoeModelOutputWithPast (last hidden state, cache, and any recorded outputs).
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # An int keeps the last `logits_to_keep` positions (0 gives slice(0, None), i.e. every position);
        # a tensor is used directly as the index along the sequence axis.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                # NOTE(review): assumes aux_loss is a tensor. load_balancing_loss_func returns the int 0 when the
                # recorded router logits are missing or not a tuple, and `.to` would then raise.
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
  595. __all__ = ["OlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]