modeling_starcoder2.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/starcoder2/modular_starcoder2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_starcoder2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  10. # and OPT implementations in this library. It has been modified from its
  11. # original forms to accommodate minor architectural differences compared
  12. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  13. #
  14. # Licensed under the Apache License, Version 2.0 (the "License");
  15. # you may not use this file except in compliance with the License.
  16. # You may obtain a copy of the License at
  17. #
  18. # http://www.apache.org/licenses/LICENSE-2.0
  19. #
  20. # Unless required by applicable law or agreed to in writing, software
  21. # distributed under the License is distributed on an "AS IS" BASIS,
  22. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. # See the License for the specific language governing permissions and
  24. # limitations under the License.
  25. from collections.abc import Callable
  26. from typing import Optional
  27. import torch
  28. from torch import nn
  29. from ...activations import ACT2FN
  30. from ...cache_utils import Cache, DynamicCache
  31. from ...generation import GenerationMixin
  32. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  33. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  34. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  35. from ...modeling_layers import (
  36. GenericForSequenceClassification,
  37. GenericForTokenClassification,
  38. GradientCheckpointingLayer,
  39. )
  40. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  41. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  42. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  43. from ...processing_utils import Unpack
  44. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  45. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  46. from ...utils.output_capturing import capture_outputs
  47. from .configuration_starcoder2 import Starcoder2Config
  48. class Starcoder2MLP(nn.Module):
  49. def __init__(self, config: Starcoder2Config):
  50. super().__init__()
  51. embed_dim = config.hidden_size
  52. self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias)
  53. self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias)
  54. self.act = ACT2FN[config.hidden_act]
  55. self.residual_dropout = config.residual_dropout
  56. def forward(self, hidden_states: tuple[torch.FloatTensor] | None) -> torch.FloatTensor:
  57. hidden_states = self.c_fc(hidden_states)
  58. hidden_states = self.act(hidden_states)
  59. hidden_states = self.c_proj(hidden_states)
  60. hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
  61. return hidden_states
  62. def rotate_half(x):
  63. """Rotates half the hidden dims of the input."""
  64. x1 = x[..., : x.shape[-1] // 2]
  65. x2 = x[..., x.shape[-1] // 2 :]
  66. return torch.cat((-x2, x1), dim=-1)
  67. @use_kernel_func_from_hub("rotary_pos_emb")
  68. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  69. """Applies Rotary Position Embedding to the query and key tensors.
  70. Args:
  71. q (`torch.Tensor`): The query tensor.
  72. k (`torch.Tensor`): The key tensor.
  73. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  74. sin (`torch.Tensor`): The sine part of the rotary embedding.
  75. unsqueeze_dim (`int`, *optional*, defaults to 1):
  76. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  77. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  78. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  79. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  80. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  81. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  82. Returns:
  83. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  84. """
  85. cos = cos.unsqueeze(unsqueeze_dim)
  86. sin = sin.unsqueeze(unsqueeze_dim)
  87. q_embed = (q * cos) + (rotate_half(q) * sin)
  88. k_embed = (k * cos) + (rotate_half(k) * sin)
  89. return q_embed, k_embed
  90. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  91. """
  92. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  93. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  94. """
  95. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  96. if n_rep == 1:
  97. return hidden_states
  98. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  99. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  100. def eager_attention_forward(
  101. module: nn.Module,
  102. query: torch.Tensor,
  103. key: torch.Tensor,
  104. value: torch.Tensor,
  105. attention_mask: torch.Tensor | None,
  106. scaling: float,
  107. dropout: float = 0.0,
  108. **kwargs: Unpack[TransformersKwargs],
  109. ):
  110. key_states = repeat_kv(key, module.num_key_value_groups)
  111. value_states = repeat_kv(value, module.num_key_value_groups)
  112. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  113. if attention_mask is not None:
  114. attn_weights = attn_weights + attention_mask
  115. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  116. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  117. attn_output = torch.matmul(attn_weights, value_states)
  118. attn_output = attn_output.transpose(1, 2).contiguous()
  119. return attn_output, attn_weights
  120. @use_kernelized_func(apply_rotary_pos_emb)
  121. class Starcoder2Attention(nn.Module):
  122. """Multi-headed attention from 'Attention Is All You Need' paper"""
  123. def __init__(self, config: Starcoder2Config, layer_idx: int | None = None):
  124. super().__init__()
  125. self.config = config
  126. self.layer_idx = layer_idx
  127. self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  128. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  129. self.scaling = self.head_dim**-0.5
  130. self.attention_dropout = config.attention_dropout
  131. self.is_causal = True
  132. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias)
  133. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
  134. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
  135. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
  136. self.residual_dropout = config.residual_dropout
  137. def forward(
  138. self,
  139. hidden_states: torch.Tensor,
  140. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  141. attention_mask: torch.Tensor | None,
  142. past_key_values: Cache | None = None,
  143. **kwargs: Unpack[FlashAttentionKwargs],
  144. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  145. input_shape = hidden_states.shape[:-1]
  146. hidden_shape = (*input_shape, -1, self.head_dim)
  147. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  148. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  149. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  150. cos, sin = position_embeddings
  151. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  152. if past_key_values is not None:
  153. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  154. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  155. self.config._attn_implementation, eager_attention_forward
  156. )
  157. attn_output, attn_weights = attention_interface(
  158. self,
  159. query_states,
  160. key_states,
  161. value_states,
  162. attention_mask,
  163. dropout=0.0 if not self.training else self.attention_dropout,
  164. scaling=self.scaling,
  165. sliding_window=getattr(self.config, "sliding_window", None), # diff with Llama
  166. **kwargs,
  167. )
  168. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  169. attn_output = self.o_proj(attn_output)
  170. attn_output = nn.functional.dropout(
  171. attn_output, p=self.residual_dropout, training=self.training
  172. ) # diff with Llama
  173. return attn_output, attn_weights
  174. class Starcoder2DecoderLayer(GradientCheckpointingLayer):
  175. def __init__(self, config: Starcoder2Config, layer_idx: int):
  176. super().__init__()
  177. self.hidden_size = config.hidden_size
  178. self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx)
  179. self.mlp = Starcoder2MLP(config)
  180. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  181. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  182. def forward(
  183. self,
  184. hidden_states: torch.Tensor,
  185. attention_mask: torch.Tensor | None = None,
  186. position_ids: torch.LongTensor | None = None,
  187. past_key_values: Cache | None = None,
  188. use_cache: bool | None = False,
  189. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  190. **kwargs: Unpack[TransformersKwargs],
  191. ) -> torch.Tensor:
  192. residual = hidden_states
  193. hidden_states = self.input_layernorm(hidden_states)
  194. # Self Attention
  195. hidden_states, _ = self.self_attn(
  196. hidden_states=hidden_states,
  197. attention_mask=attention_mask,
  198. position_ids=position_ids,
  199. past_key_values=past_key_values,
  200. use_cache=use_cache,
  201. position_embeddings=position_embeddings,
  202. **kwargs,
  203. )
  204. hidden_states = residual + hidden_states
  205. # Fully Connected
  206. residual = hidden_states
  207. hidden_states = self.post_attention_layernorm(hidden_states)
  208. hidden_states = self.mlp(hidden_states)
  209. hidden_states = residual + hidden_states
  210. return hidden_states
# Shared base class for every Starcoder2 model head in this file; holds the class-level flags
# that the PreTrainedModel machinery reads.
@auto_docstring
class Starcoder2PreTrainedModel(PreTrainedModel):
    config: Starcoder2Config
    # Attribute name under which subclasses store the base Starcoder2Model (the heads below use `self.model`).
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Presumably modules that device-placement logic must keep on a single device — verify against PreTrainedModel.
    _no_split_modules = ["Starcoder2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Attention backends advertised as supported; Starcoder2Attention dispatches through ALL_ATTENTION_FUNCTIONS.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Maps output field names to the module classes whose outputs fill them; presumably consumed by
    # `capture_outputs` on Starcoder2Model.forward — confirm before changing keys.
    _can_record_outputs = {
        "hidden_states": Starcoder2DecoderLayer,
        "attentions": Starcoder2Attention,
    }
  227. class Starcoder2RotaryEmbedding(nn.Module):
  228. inv_freq: torch.Tensor # fix linting for `register_buffer`
  229. def __init__(self, config: Starcoder2Config, device=None):
  230. super().__init__()
  231. self.max_seq_len_cached = config.max_position_embeddings
  232. self.original_max_seq_len = config.max_position_embeddings
  233. self.config = config
  234. self.rope_type = self.config.rope_parameters["rope_type"]
  235. rope_init_fn: Callable = self.compute_default_rope_parameters
  236. if self.rope_type != "default":
  237. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  238. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  239. self.register_buffer("inv_freq", inv_freq, persistent=False)
  240. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  241. @staticmethod
  242. def compute_default_rope_parameters(
  243. config: Starcoder2Config | None = None,
  244. device: Optional["torch.device"] = None,
  245. seq_len: int | None = None,
  246. ) -> tuple["torch.Tensor", float]:
  247. """
  248. Computes the inverse frequencies according to the original RoPE implementation
  249. Args:
  250. config ([`~transformers.PreTrainedConfig`]):
  251. The model configuration.
  252. device (`torch.device`):
  253. The device to use for initialization of the inverse frequencies.
  254. seq_len (`int`, *optional*):
  255. The current sequence length. Unused for this type of RoPE.
  256. Returns:
  257. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  258. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  259. """
  260. base = config.rope_parameters["rope_theta"]
  261. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  262. attention_factor = 1.0 # Unused in this type of RoPE
  263. # Compute the inverse frequencies
  264. inv_freq = 1.0 / (
  265. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  266. )
  267. return inv_freq, attention_factor
  268. @torch.no_grad()
  269. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  270. def forward(self, x, position_ids):
  271. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  272. position_ids_expanded = position_ids[:, None, :].float()
  273. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  274. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  275. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  276. emb = torch.cat((freqs, freqs), dim=-1)
  277. cos = emb.cos() * self.attention_scaling
  278. sin = emb.sin() * self.attention_scaling
  279. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  280. @auto_docstring
  281. class Starcoder2Model(Starcoder2PreTrainedModel):
  282. def __init__(self, config: Starcoder2Config):
  283. super().__init__(config)
  284. self.padding_idx = config.pad_token_id
  285. self.vocab_size = config.vocab_size
  286. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  287. self.layers = nn.ModuleList(
  288. [Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  289. )
  290. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  291. self.rotary_emb = Starcoder2RotaryEmbedding(config=config)
  292. self.gradient_checkpointing = False
  293. self.embedding_dropout = config.embedding_dropout
  294. # Initialize weights and apply final processing
  295. self.post_init()
  296. @merge_with_config_defaults
  297. @capture_outputs
  298. def forward(
  299. self,
  300. input_ids: torch.LongTensor | None = None,
  301. attention_mask: torch.Tensor | None = None,
  302. position_ids: torch.LongTensor | None = None,
  303. past_key_values: Cache | None = None,
  304. inputs_embeds: torch.FloatTensor | None = None,
  305. use_cache: bool | None = None,
  306. **kwargs: Unpack[TransformersKwargs],
  307. ) -> tuple | BaseModelOutputWithPast:
  308. if (input_ids is None) ^ (inputs_embeds is not None):
  309. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  310. if inputs_embeds is None:
  311. inputs_embeds = self.embed_tokens(input_ids)
  312. if use_cache and past_key_values is None:
  313. past_key_values = DynamicCache(config=self.config)
  314. if position_ids is None:
  315. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  316. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  317. position_ids = position_ids.unsqueeze(0)
  318. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  319. causal_mask = mask_function(
  320. config=self.config,
  321. inputs_embeds=inputs_embeds,
  322. attention_mask=attention_mask,
  323. past_key_values=past_key_values,
  324. position_ids=position_ids,
  325. )
  326. hidden_states = inputs_embeds
  327. hidden_states = nn.functional.dropout(
  328. hidden_states, p=self.embedding_dropout, training=self.training
  329. ) # main diff with Llama
  330. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  331. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  332. hidden_states = decoder_layer(
  333. hidden_states,
  334. attention_mask=causal_mask,
  335. position_ids=position_ids,
  336. past_key_values=past_key_values,
  337. use_cache=use_cache,
  338. position_embeddings=position_embeddings,
  339. **kwargs,
  340. )
  341. hidden_states = self.norm(hidden_states)
  342. return BaseModelOutputWithPast(
  343. last_hidden_state=hidden_states,
  344. past_key_values=past_key_values if use_cache else None,
  345. )
  346. @auto_docstring
  347. class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
  348. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  349. _tp_plan = {"lm_head": "colwise_gather_output"}
  350. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  351. def __init__(self, config):
  352. super().__init__(config)
  353. self.model = Starcoder2Model(config)
  354. self.vocab_size = config.vocab_size
  355. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  356. # Initialize weights and apply final processing
  357. self.post_init()
  358. @can_return_tuple
  359. @auto_docstring
  360. def forward(
  361. self,
  362. input_ids: torch.LongTensor | None = None,
  363. attention_mask: torch.Tensor | None = None,
  364. position_ids: torch.LongTensor | None = None,
  365. past_key_values: Cache | None = None,
  366. inputs_embeds: torch.FloatTensor | None = None,
  367. labels: torch.LongTensor | None = None,
  368. use_cache: bool | None = None,
  369. logits_to_keep: int | torch.Tensor = 0,
  370. **kwargs: Unpack[TransformersKwargs],
  371. ) -> CausalLMOutputWithPast:
  372. r"""
  373. Example:
  374. ```python
  375. >>> from transformers import AutoTokenizer, Starcoder2ForCausalLM
  376. >>> model = Starcoder2ForCausalLM.from_pretrained("meta-starcoder2/Starcoder2-2-7b-hf")
  377. >>> tokenizer = AutoTokenizer.from_pretrained("meta-starcoder2/Starcoder2-2-7b-hf")
  378. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  379. >>> inputs = tokenizer(prompt, return_tensors="pt")
  380. >>> # Generate
  381. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  382. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  383. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  384. ```"""
  385. outputs: BaseModelOutputWithPast = self.model(
  386. input_ids=input_ids,
  387. attention_mask=attention_mask,
  388. position_ids=position_ids,
  389. past_key_values=past_key_values,
  390. inputs_embeds=inputs_embeds,
  391. use_cache=use_cache,
  392. **kwargs,
  393. )
  394. hidden_states = outputs.last_hidden_state
  395. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  396. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  397. logits = self.lm_head(hidden_states[:, slice_indices, :])
  398. loss = None
  399. if labels is not None:
  400. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  401. return CausalLMOutputWithPast(
  402. loss=loss,
  403. logits=logits,
  404. past_key_values=outputs.past_key_values,
  405. hidden_states=outputs.hidden_states,
  406. attentions=outputs.attentions,
  407. )
  408. class Starcoder2ForSequenceClassification(GenericForSequenceClassification, Starcoder2PreTrainedModel):
  409. pass
  410. class Starcoder2ForTokenClassification(GenericForTokenClassification, Starcoder2PreTrainedModel):
  411. pass
# Public API of this module. Kept as a plain list literal so that tools reading it statically can parse it.
__all__ = [
    "Starcoder2ForCausalLM",
    "Starcoder2Model",
    "Starcoder2PreTrainedModel",
    "Starcoder2ForSequenceClassification",
    "Starcoder2ForTokenClassification",
]