modular_starcoder2.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  4. # and OPT implementations in this library. It has been modified from its
  5. # original forms to accommodate minor architectural differences compared
  6. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """PyTorch Starcoder2 model."""
  20. from collections.abc import Callable
  21. import torch
  22. from torch import nn
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache
  25. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  26. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  27. from ...modeling_outputs import BaseModelOutputWithPast
  28. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  29. from ...processing_utils import Unpack
  30. from ...utils import TransformersKwargs, logging
  31. from ...utils.generic import merge_with_config_defaults
  32. from ...utils.output_capturing import capture_outputs
  33. from ..mistral.modeling_mistral import (
  34. MistralAttention,
  35. MistralDecoderLayer,
  36. MistralForCausalLM,
  37. MistralForSequenceClassification,
  38. MistralForTokenClassification,
  39. MistralModel,
  40. apply_rotary_pos_emb,
  41. eager_attention_forward,
  42. )
  43. from .configuration_starcoder2 import Starcoder2Config
  44. logger = logging.get_logger(__name__)
  45. class Starcoder2MLP(nn.Module):
  46. def __init__(self, config: Starcoder2Config):
  47. super().__init__()
  48. embed_dim = config.hidden_size
  49. self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias)
  50. self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias)
  51. self.act = ACT2FN[config.hidden_act]
  52. self.residual_dropout = config.residual_dropout
  53. def forward(self, hidden_states: tuple[torch.FloatTensor] | None) -> torch.FloatTensor:
  54. hidden_states = self.c_fc(hidden_states)
  55. hidden_states = self.act(hidden_states)
  56. hidden_states = self.c_proj(hidden_states)
  57. hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
  58. return hidden_states
  59. class Starcoder2Attention(MistralAttention):
  60. def __init__(self, config: Starcoder2Config, layer_idx: int | None = None):
  61. super().__init__(config=config, layer_idx=layer_idx)
  62. self.residual_dropout = config.residual_dropout
  63. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias)
  64. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
  65. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
  66. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
  67. def forward(
  68. self,
  69. hidden_states: torch.Tensor,
  70. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  71. attention_mask: torch.Tensor | None,
  72. past_key_values: Cache | None = None,
  73. **kwargs: Unpack[FlashAttentionKwargs],
  74. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  75. input_shape = hidden_states.shape[:-1]
  76. hidden_shape = (*input_shape, -1, self.head_dim)
  77. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  78. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  79. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  80. cos, sin = position_embeddings
  81. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  82. if past_key_values is not None:
  83. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  84. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  85. self.config._attn_implementation, eager_attention_forward
  86. )
  87. attn_output, attn_weights = attention_interface(
  88. self,
  89. query_states,
  90. key_states,
  91. value_states,
  92. attention_mask,
  93. dropout=0.0 if not self.training else self.attention_dropout,
  94. scaling=self.scaling,
  95. sliding_window=getattr(self.config, "sliding_window", None), # diff with Llama
  96. **kwargs,
  97. )
  98. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  99. attn_output = self.o_proj(attn_output)
  100. attn_output = nn.functional.dropout(
  101. attn_output, p=self.residual_dropout, training=self.training
  102. ) # diff with Llama
  103. return attn_output, attn_weights
  104. class Starcoder2DecoderLayer(MistralDecoderLayer):
  105. def __init__(self, config: Starcoder2Config, layer_idx: int):
  106. super().__init__(config, layer_idx)
  107. self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx)
  108. self.mlp = Starcoder2MLP(config)
  109. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  110. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  111. class Starcoder2Model(MistralModel):
  112. def __init__(self, config: Starcoder2Config):
  113. super().__init__(config)
  114. self.layers = nn.ModuleList(
  115. [Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  116. )
  117. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  118. self.embedding_dropout = config.embedding_dropout
  119. @merge_with_config_defaults
  120. @capture_outputs
  121. def forward(
  122. self,
  123. input_ids: torch.LongTensor | None = None,
  124. attention_mask: torch.Tensor | None = None,
  125. position_ids: torch.LongTensor | None = None,
  126. past_key_values: Cache | None = None,
  127. inputs_embeds: torch.FloatTensor | None = None,
  128. use_cache: bool | None = None,
  129. **kwargs: Unpack[TransformersKwargs],
  130. ) -> tuple | BaseModelOutputWithPast:
  131. if (input_ids is None) ^ (inputs_embeds is not None):
  132. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  133. if inputs_embeds is None:
  134. inputs_embeds = self.embed_tokens(input_ids)
  135. if use_cache and past_key_values is None:
  136. past_key_values = DynamicCache(config=self.config)
  137. if position_ids is None:
  138. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  139. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  140. position_ids = position_ids.unsqueeze(0)
  141. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  142. causal_mask = mask_function(
  143. config=self.config,
  144. inputs_embeds=inputs_embeds,
  145. attention_mask=attention_mask,
  146. past_key_values=past_key_values,
  147. position_ids=position_ids,
  148. )
  149. hidden_states = inputs_embeds
  150. hidden_states = nn.functional.dropout(
  151. hidden_states, p=self.embedding_dropout, training=self.training
  152. ) # main diff with Llama
  153. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  154. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  155. hidden_states = decoder_layer(
  156. hidden_states,
  157. attention_mask=causal_mask,
  158. position_ids=position_ids,
  159. past_key_values=past_key_values,
  160. use_cache=use_cache,
  161. position_embeddings=position_embeddings,
  162. **kwargs,
  163. )
  164. hidden_states = self.norm(hidden_states)
  165. return BaseModelOutputWithPast(
  166. last_hidden_state=hidden_states,
  167. past_key_values=past_key_values if use_cache else None,
  168. )
  169. class Starcoder2ForCausalLM(MistralForCausalLM):
  170. pass
  171. class Starcoder2ForSequenceClassification(MistralForSequenceClassification):
  172. pass
  173. class Starcoder2ForTokenClassification(MistralForTokenClassification):
  174. pass
  175. __all__ = [
  176. "Starcoder2ForCausalLM",
  177. "Starcoder2Model",
  178. "Starcoder2PreTrainedModel", # noqa: F822
  179. "Starcoder2ForSequenceClassification",
  180. "Starcoder2ForTokenClassification",
  181. ]