modular_granite.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import torch
  16. from torch import nn
  17. from ...cache_utils import Cache, DynamicCache
  18. from ...masking_utils import create_causal_mask
  19. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  20. from ...processing_utils import Unpack
  21. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  22. from ...utils.generic import merge_with_config_defaults
  23. from ...utils.output_capturing import capture_outputs
  24. from ..llama.modeling_llama import (
  25. LlamaAttention,
  26. LlamaDecoderLayer,
  27. LlamaForCausalLM,
  28. LlamaModel,
  29. LlamaPreTrainedModel,
  30. )
  31. from .configuration_granite import GraniteConfig
  # Module-level logger, obtained through the library's own `logging` helper
  # and keyed on this module's import name.
  logger = logging.get_logger(__name__)
  class GraniteAttention(LlamaAttention):
      """Multi-headed attention from the 'Attention Is All You Need' paper.

      Behaves like `LlamaAttention` except that, once the parent constructor has
      run, `self.scaling` is overwritten with `config.attention_multiplier`.
      """

      def __init__(self, config: GraniteConfig, layer_idx: int | None = None):
          super().__init__(config, layer_idx)
          # Granite takes its scaling factor straight from the config; this
          # replaces any value the parent constructor stored in `self.scaling`.
          self.scaling = config.attention_multiplier
  class GraniteDecoderLayer(LlamaDecoderLayer):
      # Llama decoder layer that uses `GraniteAttention` and multiplies the output
      # of both sub-blocks (attention and MLP) by `config.residual_multiplier`
      # before adding it back onto the residual stream.

      def __init__(self, config: GraniteConfig, layer_idx: int):
          super().__init__(config, layer_idx)
          # Overrides any `self_attn` assigned by the parent constructor.
          self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx)
          self.residual_multiplier = config.residual_multiplier

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          use_cache: bool | None = False,
          position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.Tensor:
          """
          Apply layer-normed self-attention and then a layer-normed MLP, adding each result, scaled by
          `self.residual_multiplier`, back onto its input.

          Args:
              hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
              attention_mask (`torch.FloatTensor`, *optional*):
                  attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                  query_sequence_length, key_sequence_length)` if default attention is used.
              position_ids (`torch.LongTensor`, *optional*):
                  Passed through unchanged to the attention module.
              past_key_values (`Cache`, *optional*): cached past key and value projection states
              use_cache (`bool`, *optional*):
                  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                  (see `past_key_values`).
              position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                  Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                  with `head_dim` being the embedding dimension of each attention head.
              kwargs (`dict`, *optional*):
                  Extra keyword arguments; they are forwarded as-is to the attention module.

          Returns:
              `torch.Tensor`: the hidden states after the attention and MLP sub-blocks.
          """
          # --- self-attention sub-block -------------------------------------
          normed_input = self.input_layernorm(hidden_states)
          attn_output, _ = self.self_attn(
              hidden_states=normed_input,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              use_cache=use_cache,
              position_embeddings=position_embeddings,
              **kwargs,
          )
          # The attention branch is scaled before it rejoins the residual stream.
          hidden_states = hidden_states + attn_output * self.residual_multiplier

          # --- feed-forward sub-block ---------------------------------------
          mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
          # Same scaling for the MLP branch.
          return hidden_states + mlp_output * self.residual_multiplier
  class GranitePreTrainedModel(LlamaPreTrainedModel):
      # Associates each recordable output name with the Granite module class
      # whose outputs provide it.
      # NOTE(review): presumably read by the output-capturing machinery behind
      # `capture_outputs` -- confirm against the base class.
      _can_record_outputs = {
          "hidden_states": GraniteDecoderLayer,
          "attentions": GraniteAttention,
      }
  class GraniteModel(LlamaModel):
      # Llama backbone whose layer stack is rebuilt from `GraniteDecoderLayer`
      # and whose input embeddings are multiplied by `config.embedding_multiplier`.

      def __init__(self, config: GraniteConfig):
          super().__init__(config)
          # One Granite decoder layer per configured hidden layer.
          granite_layers = [GraniteDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
          self.layers = nn.ModuleList(granite_layers)
          self.embedding_multiplier = config.embedding_multiplier

      @merge_with_config_defaults
      @capture_outputs
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPast:
          # Reject both "neither given" and "both given": exactly one input form is allowed.
          if (input_ids is None) == (inputs_embeds is None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)
          # The multiplier is applied whether the embeddings were looked up here
          # or handed in by the caller.
          inputs_embeds = inputs_embeds * self.embedding_multiplier

          # Create a fresh cache only when caching is requested and none was passed in.
          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)

          if position_ids is None:
              # Default positions continue from however many tokens the cache reports.
              if past_key_values is None:
                  offset = 0
              else:
                  offset = past_key_values.get_seq_length()
              num_tokens = inputs_embeds.shape[1]
              position_ids = (torch.arange(num_tokens, device=inputs_embeds.device) + offset).unsqueeze(0)

          causal_mask = create_causal_mask(
              config=self.config,
              inputs_embeds=inputs_embeds,
              attention_mask=attention_mask,
              past_key_values=past_key_values,
              position_ids=position_ids,
          )

          # Rotary embeddings are computed once and handed to every layer.
          position_embeddings = self.rotary_emb(inputs_embeds, position_ids=position_ids)

          hidden_states = inputs_embeds
          for layer in self.layers[: self.config.num_hidden_layers]:
              hidden_states = layer(
                  hidden_states,
                  attention_mask=causal_mask,
                  position_ids=position_ids,
                  past_key_values=past_key_values,
                  use_cache=use_cache,
                  position_embeddings=position_embeddings,
                  **kwargs,
              )

          return BaseModelOutputWithPast(
              last_hidden_state=self.norm(hidden_states),
              past_key_values=past_key_values,
          )
  class GraniteForCausalLM(LlamaForCausalLM):
      # Causal-LM head on top of the Granite backbone; the only change in this
      # forward pass is that the logits are divided by `config.logits_scaling`.

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          use_cache: bool | None = None,
          logits_to_keep: int | torch.Tensor = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> CausalLMOutputWithPast:
          backbone_out: BaseModelOutputWithPast = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              **kwargs,
          )

          # An int selects the trailing `logits_to_keep` positions (0 gives
          # `slice(0, None)`, i.e. every position); any other value is used
          # directly as the index along the sequence dimension.
          if isinstance(logits_to_keep, int):
              kept_positions = slice(-logits_to_keep, None)
          else:
              kept_positions = logits_to_keep

          last_hidden = backbone_out.last_hidden_state
          # main diff with Llama: the head output is divided by `logits_scaling`
          logits = self.lm_head(last_hidden[:, kept_positions, :]) / self.config.logits_scaling

          # The loss is only computed when labels are supplied.
          if labels is None:
              loss = None
          else:
              loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

          return CausalLMOutputWithPast(
              loss=loss,
              logits=logits,
              past_key_values=backbone_out.past_key_values,
              hidden_states=backbone_out.hidden_states,
              attentions=backbone_out.attentions,
          )
  # Names exported by `from <module> import *`.
  __all__ = [
      "GraniteForCausalLM",
      "GraniteModel",
      "GranitePreTrainedModel",
  ]