modeling_llama.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  4. # and OPT implementations in this library. It has been modified from its
  5. # original forms to accommodate minor architectural differences compared
  6. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. from collections.abc import Callable
  20. from typing import Optional
  21. import torch
  22. from torch import nn
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache
  25. from ...generation import GenerationMixin
  26. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  27. from ...masking_utils import create_causal_mask
  28. from ...modeling_layers import (
  29. GenericForQuestionAnswering,
  30. GenericForSequenceClassification,
  31. GenericForTokenClassification,
  32. GradientCheckpointingLayer,
  33. )
  34. from ...modeling_outputs import (
  35. BaseModelOutputWithPast,
  36. CausalLMOutputWithPast,
  37. )
  38. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  39. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  40. from ...processing_utils import Unpack
  41. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  42. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  43. from ...utils.output_capturing import capture_outputs
  44. from .configuration_llama import LlamaConfig
  # Module-level logger named after this module, obtained from the package's own `logging` utilities.
  logger = logging.get_logger(__name__)
  @use_kernel_forward_from_hub("RMSNorm")
  class LlamaRMSNorm(nn.Module):
      """Root-mean-square normalization over the last dimension with a learned per-channel scale.

      Equivalent to T5LayerNorm: there is no mean subtraction and no bias. The mean of squares is computed in
      float32, the normalized values are cast back to the input dtype, and then the weight is applied.
      """

      def __init__(self, hidden_size, eps: float = 1e-6) -> None:
          """Create a scale vector of ones with `hidden_size` entries and store `eps` for numerical stability."""
          super().__init__()
          self.weight = nn.Parameter(torch.ones(hidden_size))
          self.variance_epsilon = eps

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          original_dtype = hidden_states.dtype
          as_fp32 = hidden_states.to(torch.float32)
          mean_square = as_fp32.pow(2).mean(dim=-1, keepdim=True)
          normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
          # Cast back before scaling, so the multiply happens in the input dtype (promoted with the weight's).
          return self.weight * normalized.to(original_dtype)

      def extra_repr(self):
          return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  class LlamaRotaryEmbedding(nn.Module):
      """Produce the (cos, sin) rotary tables consumed by `apply_rotary_pos_emb`.

      The inverse frequencies come from `compute_default_rope_parameters` when
      `config.rope_parameters["rope_type"]` is "default". Otherwise they come from the matching entry of
      `ROPE_INIT_FUNCTIONS`. The active frequencies and an untouched clone are both registered as
      non-persistent buffers.
      """

      inv_freq: torch.Tensor  # declared for linters; the value is installed via `register_buffer`

      def __init__(self, config: LlamaConfig, device=None):
          super().__init__()
          self.max_seq_len_cached = config.max_position_embeddings
          self.original_max_seq_len = config.max_position_embeddings
          self.config = config
          self.rope_type = self.config.rope_parameters["rope_type"]

          if self.rope_type == "default":
              init_fn: Callable = self.compute_default_rope_parameters
          else:
              init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
          frequencies, self.attention_scaling = init_fn(self.config, device)

          self.register_buffer("inv_freq", frequencies, persistent=False)
          self.register_buffer("original_inv_freq", frequencies.clone(), persistent=False)

      @staticmethod
      def compute_default_rope_parameters(
          config: LlamaConfig | None = None,
          device: Optional["torch.device"] = None,
          seq_len: int | None = None,
      ) -> tuple["torch.Tensor", float]:
          """Compute the inverse frequencies of the original (unscaled) RoPE formulation.

          Args:
              config ([`~transformers.PreTrainedConfig`]):
                  Model configuration. `rope_parameters["rope_theta"]` is read, along with `head_dim`. When
                  `head_dim` is missing or falsy, `hidden_size // num_attention_heads` is used instead.
              device (`torch.device`):
                  Device on which the inverse frequencies are created.
              seq_len (`int`, *optional*):
                  Current sequence length. Ignored by this RoPE type.

          Returns:
              A tuple `(inv_freq, attention_factor)`. `inv_freq[i] = 1 / theta ** (2i / dim)` and
              `attention_factor` is always 1.0, since this RoPE type applies no cos/sin post-scaling.
          """
          theta = config.rope_parameters["rope_theta"]
          dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
          even_indices = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
          inv_freq = 1.0 / (theta ** (even_indices / dim))
          return inv_freq, 1.0

      @torch.no_grad()
      @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
      def forward(self, x, position_ids):
          batch_size = position_ids.shape[0]
          inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
          position_ids_expanded = position_ids[:, None, :].float()

          # Autocast is disabled below so the frequency math stays in float32.
          # MPS devices and non-string device types use "cpu" as the autocast device type.
          device_type = x.device.type
          if not isinstance(device_type, str) or device_type == "mps":
              device_type = "cpu"

          with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
              angles = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
              emb = torch.cat((angles, angles), dim=-1)
              cos = emb.cos() * self.attention_scaling
              sin = emb.sin() * self.attention_scaling

          return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  def rotate_half(x):
      """Split the last dimension into halves (a, b) and return `cat(-b, a)` along that dimension."""
      half = x.shape[-1] // 2
      first, second = x[..., :half], x[..., half:]
      return torch.cat((-second, first), dim=-1)
  @use_kernel_func_from_hub("rotary_pos_emb")
  def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
      """Rotate the query and key tensors with rotary position embeddings.

      Args:
          q (`torch.Tensor`): Query tensor.
          k (`torch.Tensor`): Key tensor.
          cos (`torch.Tensor`): Cosine table, e.g. of shape [batch_size, seq_len, head_dim].
          sin (`torch.Tensor`): Sine table with the same shape as `cos`.
          unsqueeze_dim (`int`, *optional*, defaults to 1):
              Axis inserted into `cos`/`sin` so they broadcast against `q` and `k`. Use 1 when q/k are laid out
              as [batch_size, heads, seq_len, head_dim], and 2 when they are [batch_size, seq_len, heads, head_dim].

      Returns:
          `tuple(torch.Tensor)`: The rotated `(q, k)` pair, each computed as `t * cos + rotate_half(t) * sin`.
      """
      cos_b = cos.unsqueeze(unsqueeze_dim)
      sin_b = sin.unsqueeze(unsqueeze_dim)
      q_rotated = q * cos_b + rotate_half(q) * sin_b
      k_rotated = k * cos_b + rotate_half(k) * sin_b
      return q_rotated, k_rotated
  class LlamaMLP(nn.Module):
      """Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

      def __init__(self, config):
          super().__init__()
          self.config = config
          self.hidden_size = config.hidden_size
          self.intermediate_size = config.intermediate_size
          use_bias = config.mlp_bias
          # Registration order (gate, up, down) is kept so parameter ordering is unchanged.
          self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
          self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
          self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
          self.act_fn = ACT2FN[config.hidden_act]

      def forward(self, x):
          gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
          return self.down_proj(gated)
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
      """Repeat every key/value head `n_rep` times along dimension 1.

      Equivalent to `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`. The shape goes from
      (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim).
      The input must be 4-D. When `n_rep == 1`, the input tensor itself is returned.
      """
      # Unpack first so a non-4-D input is rejected even when n_rep == 1.
      batch, kv_heads, seq_len, head_dim = hidden_states.shape
      if n_rep == 1:
          return hidden_states
      tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
      return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
  def eager_attention_forward(
      module: nn.Module,
      query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
      attention_mask: torch.Tensor | None,
      scaling: float,
      dropout: float = 0.0,
      **kwargs: Unpack[TransformersKwargs],
  ):
      """Reference (pure PyTorch) scaled dot-product attention.

      Keys and values are expanded with `repeat_kv` by `module.num_key_value_groups`. Scores are
      `query @ key^T * scaling`, and the additive `attention_mask` is added to them when given. Softmax is
      computed in float32 and cast back to the query dtype. Dropout is active only when `module.training`.
      Layout assumption (TODO confirm against callers): query is (batch, heads, q_len, head_dim), so the returned
      output is (batch, q_len, heads, head_dim) after the final transpose.

      Returns:
          `(attn_output, attn_weights)`: the attention weights are post-dropout.
      """
      groups = module.num_key_value_groups
      key_states = repeat_kv(key, groups)
      value_states = repeat_kv(value, groups)

      scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
      if attention_mask is not None:
          scores = scores + attention_mask

      probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
      probs = nn.functional.dropout(probs, p=dropout, training=module.training)

      context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
      return context, probs
  @use_kernelized_func(apply_rotary_pos_emb)
  class LlamaAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper.

      Supports grouped key/value heads: queries use `config.num_attention_heads` heads, while keys and values use
      `config.num_key_value_heads`. `num_key_value_groups` is the ratio between the two. Rotary embeddings are
      applied to queries and keys. The attention kernel is looked up from `ALL_ATTENTION_FUNCTIONS` using
      `config._attn_implementation`, with `eager_attention_forward` as the fallback.
      """

      def __init__(self, config: LlamaConfig, layer_idx: int):
          super().__init__()
          self.config = config
          # Passed to `past_key_values.update` together with this layer's new keys/values.
          self.layer_idx = layer_idx
          # `head_dim` may be absent from the config or explicitly None. Both cases fall back to
          # hidden_size // num_attention_heads, the same rule used by
          # `LlamaRotaryEmbedding.compute_default_rope_parameters`. This keeps the projection widths and the
          # rotary tables in agreement. A plain `getattr(config, "head_dim", default)` would return None for an
          # explicit None and crash on `head_dim**-0.5`.
          self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
          self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
          self.scaling = self.head_dim**-0.5
          self.attention_dropout = config.attention_dropout
          # Not read in this class; presumably consulted by the pluggable attention backends (TODO confirm).
          self.is_causal = True

          self.q_proj = nn.Linear(
              config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
          )
          self.k_proj = nn.Linear(
              config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
          )
          self.v_proj = nn.Linear(
              config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
          )
          self.o_proj = nn.Linear(
              config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
          )

      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
          attention_mask: torch.Tensor | None = None,
          past_key_values: Cache | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """Project, rotate, optionally cache, attend, and project back.

          `position_embeddings` is the `(cos, sin)` pair produced by `LlamaRotaryEmbedding`. It is required in
          practice despite the `None` default, because it is unpacked unconditionally.

          Returns:
              `(attn_output, attn_weights)`: the output after `o_proj`, and whatever weights the selected
              attention backend returns.
          """
          # Assumed layout: hidden_states is (batch, seq, hidden) (TODO confirm). The last dim is split into
          # (heads, head_dim), and transpose(1, 2) moves the heads axis ahead of the sequence axis.
          input_shape = hidden_states.shape[:-1]
          hidden_shape = (*input_shape, -1, self.head_dim)

          query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
          key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
          value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

          cos, sin = position_embeddings
          query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

          if past_key_values is not None:
              # The cache returns the keys/values to attend over for this layer.
              key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

          attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
              self.config._attn_implementation, eager_attention_forward
          )

          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.attention_dropout,
              scaling=self.scaling,
              **kwargs,
          )

          # Merge the per-head outputs back into one hidden vector per position before the output projection.
          attn_output = attn_output.reshape(*input_shape, -1).contiguous()
          attn_output = self.o_proj(attn_output)
          return attn_output, attn_weights
  class LlamaDecoderLayer(GradientCheckpointingLayer):
      """Pre-norm decoder block: `x + attn(norm(x))`, followed by `x + mlp(norm(x))`."""

      def __init__(self, config: LlamaConfig, layer_idx: int):
          super().__init__()
          self.hidden_size = config.hidden_size
          self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
          self.mlp = LlamaMLP(config)
          self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          use_cache: bool | None = False,
          position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.Tensor:
          # Self-attention sub-block with a residual connection. The attention weights are discarded.
          # `position_ids` and `use_cache` travel into the attention module's **kwargs.
          attn_out, _ = self.self_attn(
              hidden_states=self.input_layernorm(hidden_states),
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              use_cache=use_cache,
              position_embeddings=position_embeddings,
              **kwargs,
          )
          hidden_states = hidden_states + attn_out

          # Feed-forward sub-block with a residual connection.
          mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
          return hidden_states + mlp_out
  @auto_docstring
  class LlamaPreTrainedModel(PreTrainedModel):
      config: LlamaConfig
      # Attribute name under which heads hold the bare model (e.g. `LlamaForCausalLM.model`).
      base_model_prefix = "model"

      # Placement hints, presumably interpreted by `PreTrainedModel` (not visible here): decoder layers are
      # kept whole, and the `past_key_values` argument is excluded from automatic device placement.
      _no_split_modules = ["LlamaDecoderLayer"]
      _skip_keys_device_placement = ["past_key_values"]

      # Capability flags: gradient checkpointing, attention backends, and full-graph compilation.
      supports_gradient_checkpointing = True
      _supports_flash_attn = True
      _supports_sdpa = True
      _supports_flex_attn = True
      _supports_attention_backend = True
      _can_compile_fullgraph = True

      # Maps each recordable output name to the module class whose outputs feed it.
      _can_record_outputs = {
          "hidden_states": LlamaDecoderLayer,
          "attentions": LlamaAttention,
      }
  @auto_docstring
  class LlamaModel(LlamaPreTrainedModel):
      def __init__(self, config: LlamaConfig):
          super().__init__(config)
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size

          # Submodules are created in the original order: embeddings, layers, final norm, rotary tables.
          self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
          self.layers = nn.ModuleList(
              [LlamaDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
          )
          self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.rotary_emb = LlamaRotaryEmbedding(config=config)
          self.gradient_checkpointing = False

          # Initialize weights and apply final processing
          self.post_init()

      @merge_with_config_defaults
      @capture_outputs
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPast:
          # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
          has_ids = input_ids is not None
          has_embeds = inputs_embeds is not None
          if has_ids == has_embeds:
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
          if not has_embeds:
              inputs_embeds = self.embed_tokens(input_ids)

          # Caching was requested but no cache was passed in, so start an empty dynamic cache.
          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)

          # Default positions continue from however many tokens the cache already holds (batch dim of 1).
          if position_ids is None:
              offset = 0 if past_key_values is None else past_key_values.get_seq_length()
              seq_len = inputs_embeds.shape[1]
              position_ids = (torch.arange(seq_len, device=inputs_embeds.device) + offset).unsqueeze(0)

          causal_mask = create_causal_mask(
              config=self.config,
              inputs_embeds=inputs_embeds,
              attention_mask=attention_mask,
              past_key_values=past_key_values,
              position_ids=position_ids,
          )

          hidden_states = inputs_embeds
          # The (cos, sin) tables are computed once here and shared by every layer.
          position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

          for layer in self.layers[: self.config.num_hidden_layers]:
              hidden_states = layer(
                  hidden_states,
                  attention_mask=causal_mask,
                  position_embeddings=position_embeddings,
                  position_ids=position_ids,
                  past_key_values=past_key_values,
                  use_cache=use_cache,
                  **kwargs,
              )

          return BaseModelOutputWithPast(
              last_hidden_state=self.norm(hidden_states),
              past_key_values=past_key_values,
          )
  @auto_docstring
  class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
      # The lm_head weight is declared as tied to the input embedding table.
      _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
      # Tensor/pipeline-parallel plans for the output projection (interpreted outside this file).
      _tp_plan = {"lm_head": "colwise_gather_output"}
      _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

      def __init__(self, config):
          super().__init__(config)
          self.model = LlamaModel(config)
          self.vocab_size = config.vocab_size
          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          use_cache: bool | None = None,
          logits_to_keep: int | torch.Tensor = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> CausalLMOutputWithPast:
          r"""
          Example:

          ```python
          >>> from transformers import AutoTokenizer, LlamaForCausalLM

          >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
          >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

          >>> prompt = "Hey, are you conscious? Can you talk to me?"
          >>> inputs = tokenizer(prompt, return_tensors="pt")

          >>> # Generate
          >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
          >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
          "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
          ```"""
          outputs: BaseModelOutputWithPast = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              **kwargs,
          )

          # Project only the selected positions to vocabulary logits; no float upcast happens here.
          # An int N keeps the last N positions (0 keeps all, since slice(-0, None) == slice(0, None)).
          # A tensor is used directly as the sequence index.
          if isinstance(logits_to_keep, int):
              keep = slice(-logits_to_keep, None)
          else:
              keep = logits_to_keep
          logits = self.lm_head(outputs.last_hidden_state[:, keep, :])

          loss = (
              self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
              if labels is not None
              else None
          )

          return CausalLMOutputWithPast(
              loss=loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  class LlamaForSequenceClassification(GenericForSequenceClassification, LlamaPreTrainedModel):
      # All behavior is inherited from the generic sequence-classification head combined with the Llama base.
      pass
  class LlamaForQuestionAnswering(GenericForQuestionAnswering, LlamaPreTrainedModel):
      # Overrides the base prefix for backward compatibility: this head historically stored the bare
      # model under `transformer` rather than `model`.
      base_model_prefix = "transformer"
  class LlamaForTokenClassification(GenericForTokenClassification, LlamaPreTrainedModel):
      # All behavior is inherited from the generic token-classification head combined with the Llama base.
      pass
  # Public names exported by `from <this module> import *`, in their original order.
  __all__ = [
      "LlamaForCausalLM",
      "LlamaModel",
      "LlamaPreTrainedModel",
      "LlamaForSequenceClassification",
      "LlamaForQuestionAnswering",
      "LlamaForTokenClassification",
  ]