modeling_falcon.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239
  1. # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Falcon model."""
  15. import math
  16. from collections.abc import Callable
  17. from typing import Optional
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
  21. from torch.nn import functional as F
  22. from ... import initialization as init
  23. from ...activations import get_activation
  24. from ...cache_utils import Cache, DynamicCache
  25. from ...generation import GenerationMixin
  26. from ...masking_utils import create_causal_mask
  27. from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import (
  30. BaseModelOutputWithPastAndCrossAttentions,
  31. CausalLMOutputWithCrossAttentions,
  32. QuestionAnsweringModelOutput,
  33. SequenceClassifierOutputWithPast,
  34. TokenClassifierOutput,
  35. )
  36. from ...modeling_rope_utils import (
  37. ROPE_INIT_FUNCTIONS,
  38. dynamic_rope_update,
  39. )
  40. from ...modeling_utils import PreTrainedModel
  41. from ...utils import (
  42. auto_docstring,
  43. logging,
  44. )
  45. from ...utils.generic import maybe_autocast
  46. from .configuration_falcon import FalconConfig
  47. if is_flash_attn_available():
  48. from ...modeling_flash_attention_utils import _flash_attention_forward
  49. logger = logging.get_logger(__name__)
  50. # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
  51. # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
  52. class FalconLinear(nn.Linear):
  53. def forward(self, input: torch.Tensor) -> torch.Tensor:
  54. hidden_states = input @ self.weight.T
  55. if self.bias is None:
  56. return hidden_states
  57. return hidden_states + self.bias
  58. # Copied from transformers.models.llama.modeling_llama.rotate_half
  59. def rotate_half(x):
  60. """Rotates half the hidden dims of the input."""
  61. x1 = x[..., : x.shape[-1] // 2]
  62. x2 = x[..., x.shape[-1] // 2 :]
  63. return torch.cat((-x2, x1), dim=-1)
  64. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  65. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  66. """Applies Rotary Position Embedding to the query and key tensors.
  67. Args:
  68. q (`torch.Tensor`): The query tensor.
  69. k (`torch.Tensor`): The key tensor.
  70. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  71. sin (`torch.Tensor`): The sine part of the rotary embedding.
  72. unsqueeze_dim (`int`, *optional*, defaults to 1):
  73. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  74. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  75. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  76. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  77. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  78. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  79. Returns:
  80. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  81. """
  82. cos = cos.unsqueeze(unsqueeze_dim)
  83. sin = sin.unsqueeze(unsqueeze_dim)
  84. q_embed = (q * cos) + (rotate_half(q) * sin)
  85. k_embed = (k * cos) + (rotate_half(k) * sin)
  86. return q_embed, k_embed
  87. # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
  88. class FalconRotaryEmbedding(nn.Module):
  89. inv_freq: torch.Tensor # fix linting for `register_buffer`
  90. def __init__(self, config: FalconConfig, device=None):
  91. super().__init__()
  92. self.max_seq_len_cached = config.max_position_embeddings
  93. self.original_max_seq_len = config.max_position_embeddings
  94. self.config = config
  95. self.rope_type = self.config.rope_parameters["rope_type"]
  96. rope_init_fn: Callable = self.compute_default_rope_parameters
  97. if self.rope_type != "default":
  98. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  99. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  100. self.register_buffer("inv_freq", inv_freq, persistent=False)
  101. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  102. @staticmethod
  103. def compute_default_rope_parameters(
  104. config: FalconConfig | None = None,
  105. device: Optional["torch.device"] = None,
  106. seq_len: int | None = None,
  107. ) -> tuple["torch.Tensor", float]:
  108. """
  109. Computes the inverse frequencies according to the original RoPE implementation
  110. Args:
  111. config ([`~transformers.PreTrainedConfig`]):
  112. The model configuration.
  113. device (`torch.device`):
  114. The device to use for initialization of the inverse frequencies.
  115. seq_len (`int`, *optional*):
  116. The current sequence length. Unused for this type of RoPE.
  117. Returns:
  118. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  119. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  120. """
  121. base = config.rope_parameters["rope_theta"]
  122. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  123. attention_factor = 1.0 # Unused in this type of RoPE
  124. # Compute the inverse frequencies
  125. inv_freq = 1.0 / (
  126. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  127. )
  128. return inv_freq, attention_factor
  129. @torch.no_grad()
  130. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  131. def forward(self, x, position_ids):
  132. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  133. position_ids_expanded = position_ids[:, None, :].float()
  134. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  135. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  136. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  137. emb = torch.cat((freqs, freqs), dim=-1)
  138. cos = emb.cos() * self.attention_scaling
  139. sin = emb.sin() * self.attention_scaling
  140. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  141. def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
  142. batch_size, seq_length = attention_mask.shape
  143. closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
  144. base = torch.tensor(
  145. 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
  146. )
  147. powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
  148. slopes = torch.pow(base, powers)
  149. if closest_power_of_2 != num_heads:
  150. extra_base = torch.tensor(
  151. 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
  152. )
  153. num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
  154. extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
  155. slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
  156. # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
  157. # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
  158. # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
  159. # => the query_length dimension will then be broadcasted correctly
  160. # This is more or less identical to T5's relative position bias:
  161. # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
  162. arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
  163. alibi = slopes[..., None].bfloat16() * arange_tensor
  164. return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
  165. # Copied from transformers.models.bloom.modeling_bloom.dropout_add
  166. def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
  167. """
  168. Dropout add function
  169. Args:
  170. x (`torch.tensor`):
  171. input tensor
  172. residual (`torch.tensor`):
  173. residual tensor
  174. prob (`float`):
  175. dropout probability
  176. training (`bool`):
  177. training mode
  178. """
  179. out = F.dropout(x, p=prob, training=training)
  180. out = residual + out
  181. return out
  182. class FalconAttention(nn.Module):
  183. def __init__(self, config: FalconConfig, layer_idx=None):
  184. super().__init__()
  185. self.config = config
  186. self.hidden_size = config.hidden_size
  187. self.num_heads = config.num_attention_heads
  188. self.head_dim = self.hidden_size // self.num_heads
  189. self.split_size = self.hidden_size
  190. self.hidden_dropout = config.hidden_dropout
  191. self.max_position_embeddings = config.max_position_embeddings
  192. self.is_causal = True
  193. self.layer_idx = layer_idx
  194. if layer_idx is None:
  195. logger.warning_once(
  196. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  197. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  198. "when creating this class."
  199. )
  200. if self.head_dim * self.num_heads != self.hidden_size:
  201. raise ValueError(
  202. f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
  203. f" {self.num_heads})."
  204. )
  205. # Layer-wise attention scaling
  206. self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
  207. self.beta = self.inv_norm_factor
  208. if config.new_decoder_architecture:
  209. qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
  210. elif config.multi_query:
  211. qkv_out_dim = self.hidden_size + 2 * self.head_dim
  212. else:
  213. qkv_out_dim = 3 * self.hidden_size
  214. self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
  215. self.new_decoder_architecture = config.new_decoder_architecture
  216. self.multi_query = config.multi_query
  217. self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
  218. self.attention_dropout = nn.Dropout(config.attention_dropout)
  219. self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
  220. def _split_heads(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  221. """
  222. Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
  223. Args:
  224. fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]
  225. Returns:
  226. query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
  227. value: [batch_size, seq_length, num_heads, head_dim]
  228. """
  229. if self.new_decoder_architecture:
  230. batch, seq_len, _ = fused_qkv.shape
  231. qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
  232. query = qkv[:, :, :, :-2]
  233. key = qkv[:, :, :, [-2]]
  234. value = qkv[:, :, :, [-1]]
  235. key = torch.broadcast_to(key, query.shape)
  236. value = torch.broadcast_to(value, query.shape)
  237. query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
  238. return query, key, value
  239. elif not self.multi_query:
  240. batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
  241. fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
  242. return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
  243. else:
  244. batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
  245. fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
  246. return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
  247. # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
  248. def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
  249. """
  250. Merge heads together over the last dimension
  251. Args:
  252. x (`torch.tensor`): [batch_size * num_heads, seq_length, head_dim]
  253. Returns:
  254. torch.tensor: [batch_size, seq_length, num_heads * head_dim]
  255. """
  256. # What we want to achieve is:
  257. # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
  258. batch_size_and_num_heads, seq_length, _ = x.shape
  259. batch_size = batch_size_and_num_heads // self.num_heads
  260. # First view to decompose the batch size
  261. # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
  262. x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
  263. # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
  264. x = x.permute(0, 2, 1, 3)
  265. # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
  266. return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
  267. def forward(
  268. self,
  269. hidden_states: torch.Tensor,
  270. alibi: torch.Tensor | None,
  271. attention_mask: torch.Tensor,
  272. position_ids: torch.LongTensor | None = None,
  273. layer_past: Cache | None = None,
  274. use_cache: bool = False,
  275. output_attentions: bool = False,
  276. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  277. **kwargs,
  278. ):
  279. fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
  280. num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
  281. # 3 x [batch_size, seq_length, num_heads, head_dim]
  282. (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
  283. batch_size, query_length, _, _ = query_layer.shape
  284. query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
  285. key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
  286. value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
  287. if alibi is None:
  288. cos, sin = position_embeddings
  289. query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
  290. if layer_past is not None:
  291. key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx)
  292. kv_length = key_layer.shape[-2]
  293. if alibi is None:
  294. if self.config._attn_implementation == "sdpa" and not output_attentions:
  295. # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
  296. # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
  297. # The query_length > 1 is necessary to match with a bidirectional attention mask we do not have
  298. # a causal pattern in those cases.
  299. is_causal = self.is_causal and attention_mask is None and query_length > 1
  300. attn_output = torch.nn.functional.scaled_dot_product_attention(
  301. query_layer,
  302. key_layer,
  303. value_layer,
  304. attn_mask=attention_mask,
  305. dropout_p=0.0,
  306. is_causal=is_causal,
  307. )
  308. attention_scores = None
  309. else:
  310. attention_scores = query_layer @ key_layer.transpose(-1, -2)
  311. attention_scores /= math.sqrt(self.head_dim)
  312. attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
  313. # It is unclear why dropout is not applied here (while it is with alibi).
  314. attn_output = attention_scores @ value_layer
  315. attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
  316. attn_output = attn_output.permute(0, 2, 1, 3)
  317. attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
  318. attn_output = self.dense(attn_output)
  319. return attn_output, attention_scores
  320. else:
  321. if self.config._attn_implementation == "sdpa" and not output_attentions:
  322. # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
  323. # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
  324. is_causal = self.is_causal and attention_mask is None and query_length > 1
  325. attn_output = torch.nn.functional.scaled_dot_product_attention(
  326. query_layer,
  327. key_layer,
  328. value_layer,
  329. attn_mask=attention_mask,
  330. dropout_p=self.attention_dropout.p if self.training else 0.0,
  331. is_causal=is_causal,
  332. )
  333. attention_probs = None
  334. attn_output = attn_output.transpose(1, 2)
  335. attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
  336. attn_output = self.dense(attn_output)
  337. else:
  338. matmul_result = query_layer @ key_layer.transpose(-1, -2)
  339. # change view to [batch_size, num_heads, q_length, kv_length]
  340. attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
  341. # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
  342. input_dtype = attention_scores.dtype
  343. # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
  344. if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
  345. attention_scores = attention_scores.to(torch.float32)
  346. attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
  347. attention_logits *= self.inv_norm_factor
  348. attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
  349. # [batch_size, num_heads, q_length, kv_length]
  350. attention_probs = self.attention_dropout(attention_probs)
  351. # change view [batch_size, num_heads, q_length, kv_length]
  352. attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
  353. # matmul: [batch_size * num_heads, q_length, head_dim]
  354. attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1)
  355. # change view [batch_size, q_length, num_heads * head_dim]
  356. attn_output = self._merge_heads(attn_output)
  357. attn_output = self.dense(attn_output)
  358. return attn_output, attention_probs
  359. class FalconFlashAttention2(FalconAttention):
  360. """
  361. Falcon flash attention module. This module inherits from `FalconAttention` as the weights of the module stays
  362. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  363. flash attention and deal with padding tokens in case the input contains any of them.
  364. """
  365. def __init__(self, *args, **kwargs):
  366. super().__init__(*args, **kwargs)
  367. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  368. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  369. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  370. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  371. def forward(
  372. self,
  373. hidden_states: torch.Tensor,
  374. alibi: torch.Tensor | None,
  375. attention_mask: torch.Tensor,
  376. position_ids: torch.LongTensor | None = None,
  377. layer_past: Cache | None = None,
  378. use_cache: bool = False,
  379. output_attentions: bool = False,
  380. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  381. **kwargs,
  382. ):
  383. fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
  384. num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
  385. # 3 x [batch_size, seq_length, num_heads, head_dim]
  386. (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
  387. batch_size, query_length, _, _ = query_layer.shape
  388. query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
  389. key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
  390. value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
  391. if alibi is None:
  392. cos, sin = position_embeddings
  393. query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
  394. if layer_past is not None:
  395. key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx)
  396. # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  397. # to be able to avoid many of these transpose/reshape/view.
  398. query_layer = query_layer.transpose(1, 2)
  399. key_layer = key_layer.transpose(1, 2)
  400. value_layer = value_layer.transpose(1, 2)
  401. if alibi is not None:
  402. raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
  403. attn_dropout = self.config.attention_dropout if self.training else 0.0
  404. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  405. # therefore the input hidden states gets silently casted in float32. Hence, we need
  406. # cast them back in float16 just to be sure everything works as expected.
  407. input_dtype = query_layer.dtype
  408. device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
  409. if input_dtype == torch.float32:
  410. if torch.is_autocast_enabled(device_type):
  411. target_dtype = torch.get_autocast_dtype(device_type)
  412. # Handle the case where the model is quantized
  413. elif hasattr(self.config, "_is_quantized"):
  414. target_dtype = self.config.dtype
  415. else:
  416. target_dtype = self.query_key_value.weight.dtype
  417. logger.warning_once(
  418. f"The input hidden states seems to be silently casted in float32, this might be related to"
  419. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  420. f" {target_dtype}."
  421. )
  422. query_layer = query_layer.to(target_dtype)
  423. key_layer = key_layer.to(target_dtype)
  424. value_layer = value_layer.to(target_dtype)
  425. attn_output = _flash_attention_forward(
  426. query_layer,
  427. key_layer,
  428. value_layer,
  429. attention_mask,
  430. query_length,
  431. position_ids=position_ids,
  432. dropout=attn_dropout,
  433. is_causal=self.is_causal,
  434. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  435. )
  436. attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
  437. attn_output = self.dense(attn_weights)
  438. if not output_attentions:
  439. attn_weights = None
  440. return attn_output, attn_weights
  441. class FalconMLP(nn.Module):
  442. def __init__(self, config: FalconConfig):
  443. super().__init__()
  444. hidden_size = config.hidden_size
  445. self.dense_h_to_4h = FalconLinear(hidden_size, config.ffn_hidden_size, bias=config.bias)
  446. self.act = get_activation(config.activation)
  447. self.dense_4h_to_h = FalconLinear(config.ffn_hidden_size, hidden_size, bias=config.bias)
  448. self.hidden_dropout = config.hidden_dropout
  449. def forward(self, x: torch.Tensor) -> torch.Tensor:
  450. x = self.act(self.dense_h_to_4h(x))
  451. x = self.dense_4h_to_h(x)
  452. return x
  453. FALCON_ATTENTION_CLASSES = {
  454. "eager": FalconAttention,
  455. "sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA
  456. "flash_attention_2": FalconFlashAttention2,
  457. }
class FalconDecoderLayer(GradientCheckpointingLayer):
    """
    A single Falcon transformer block: self-attention and an MLP, wired either sequentially or in parallel, with
    residual connections.

    `__init__` creates the layer norms from `config.parallel_attn` and `config.num_ln_in_parallel_attn`:
    - `parallel_attn=False`: `input_layernorm` and `post_attention_layernorm`;
    - `parallel_attn=True` and `num_ln_in_parallel_attn == 2`: `ln_attn` and `ln_mlp`;
    - `parallel_attn=True` otherwise: a single `input_layernorm`.
    """

    def __init__(self, config: FalconConfig, layer_idx: int | None = None):
        super().__init__()
        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        # Attention class is selected by `config._attn_implementation`.
        self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
        self.mlp = FalconMLP(config)
        self.hidden_dropout = config.hidden_dropout
        self.config = config

        # Default to two layer norms for the new decoder architecture. This writes to the config object itself
        # (presumably shared with the model and the other layers — confirm), and `forward` reads it back from there.
        if config.num_ln_in_parallel_attn is None and config.new_decoder_architecture:
            config.num_ln_in_parallel_attn = 2

        if not config.parallel_attn:
            self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
            self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        else:
            if config.num_ln_in_parallel_attn == 2:
                # The layer norm before self-attention
                self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
                # The layer norm before the MLP
                self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
            else:
                self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: torch.Tensor | None,
        attention_mask: torch.Tensor,
        position_ids: torch.LongTensor | None = None,
        layer_past: Cache | tuple[torch.Tensor, torch.Tensor] | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Run the block on `hidden_states`.

        All arguments except `hidden_states` are forwarded unchanged to `self_attention`. `hidden_states` is
        layer-normed first. `kwargs` are accepted but not used.

        Returns:
            `(output, attn_weights)`. `output` has the same shape as `hidden_states`. `attn_weights` is the second
            value returned by the attention module.
        """
        residual = hidden_states

        # Pre-attention normalization: separate attention/MLP norms for the new architecture with two layer
        # norms, otherwise the single `input_layernorm`.
        if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attn_weights = self.self_attention(
            attention_layernorm_out,
            layer_past=layer_past,
            attention_mask=attention_mask,
            position_ids=position_ids,
            alibi=alibi,
            use_cache=use_cache,
            output_attentions=output_attentions,
            position_embeddings=position_embeddings,
        )

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                # Parallel block: the MLP sees the same normalized input as attention.
                mlp_layernorm_out = attention_layernorm_out
            else:
                # Sequential block: add the attention output to the residual first, then normalize that sum for the
                # MLP. The dropout probability used here is `attention_dropout`, not `hidden_dropout`.
                residual = dropout_add(
                    attention_output, residual, self.config.attention_dropout, training=self.training
                )
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        # New architecture with a single layer norm: the MLP reuses the attention input.
        if (
            self.config.new_decoder_architecture
            and self.config.parallel_attn
            and self.config.num_ln_in_parallel_attn == 1
        ):
            mlp_layernorm_out = attention_layernorm_out

        # NOTE(review): `mlp_layernorm_out` is only bound for the flag combinations handled above. For example,
        # `new_decoder_architecture=True` with `parallel_attn=False` and `num_ln_in_parallel_attn != 2` leaves it
        # unassigned, and `num_ln_in_parallel_attn == 2` with `parallel_attn=False` references the never-created
        # `ln_attn`. Presumably such configs are not used — confirm.

        # MLP.
        mlp_output = self.mlp(mlp_layernorm_out)

        # Parallel variants: attention and MLP outputs are summed before the single residual add below
        # (the residual was not updated with the attention output on these paths).
        if self.config.new_decoder_architecture or self.config.parallel_attn:
            mlp_output += attention_output

        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)

        return output, attn_weights
  529. @auto_docstring
  530. class FalconPreTrainedModel(PreTrainedModel):
  531. config: FalconConfig
  532. base_model_prefix = "transformer"
  533. supports_gradient_checkpointing = True
  534. _no_split_modules = ["FalconDecoderLayer"]
  535. _supports_flash_attn = True
  536. _supports_sdpa = True
  537. _can_compile_fullgraph = True
  538. @torch.no_grad()
  539. def _init_weights(self, module: nn.Module):
  540. """Initialize the weights."""
  541. super()._init_weights(module)
  542. if isinstance(module, FalconLinear):
  543. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  544. if module.bias is not None:
  545. init.zeros_(module.bias)
  546. # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
  547. @classmethod
  548. def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
  549. _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
  550. if _is_bettertransformer:
  551. return config
  552. if not hard_check_only:
  553. config._attn_implementation = "sdpa"
  554. return config
  555. @auto_docstring
  556. class FalconModel(FalconPreTrainedModel):
  557. def __init__(self, config: FalconConfig):
  558. super().__init__(config)
  559. self.embed_dim = config.hidden_size
  560. self.num_heads = config.num_attention_heads
  561. self.use_alibi = config.alibi
  562. # Embedding + LN Embedding
  563. self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
  564. # Transformer blocks
  565. self.h = nn.ModuleList([FalconDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  566. # Final Layer Norm
  567. self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  568. self.gradient_checkpointing = False
  569. self.rotary_emb = FalconRotaryEmbedding(config=config)
  570. # Initialize weights and apply final processing
  571. self.post_init()
  572. def get_input_embeddings(self):
  573. return self.word_embeddings
  574. def set_input_embeddings(self, new_embeddings: torch.Tensor):
  575. self.word_embeddings = new_embeddings
  576. @auto_docstring
  577. def forward(
  578. self,
  579. input_ids: torch.LongTensor | None = None,
  580. past_key_values: Cache | None = None,
  581. attention_mask: torch.Tensor | None = None,
  582. position_ids: torch.LongTensor | None = None,
  583. inputs_embeds: torch.LongTensor | None = None,
  584. use_cache: bool | None = None,
  585. output_attentions: bool | None = None,
  586. output_hidden_states: bool | None = None,
  587. return_dict: bool | None = None,
  588. **kwargs,
  589. ) -> tuple[torch.Tensor, ...] | BaseModelOutputWithPastAndCrossAttentions:
  590. r"""
  591. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  592. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  593. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  594. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  595. `input_ids`.
  596. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  597. [`PreTrainedTokenizer.__call__`] for details.
  598. [What are input IDs?](../glossary#input-ids)
  599. """
  600. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  601. output_hidden_states = (
  602. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  603. )
  604. use_cache = use_cache if use_cache is not None else self.config.use_cache
  605. return_dict = return_dict if return_dict is not None else self.config.return_dict
  606. if (input_ids is None) ^ (inputs_embeds is not None):
  607. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  608. if self.gradient_checkpointing and self.training:
  609. if use_cache:
  610. logger.warning_once(
  611. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  612. )
  613. use_cache = False
  614. if inputs_embeds is None:
  615. inputs_embeds = self.word_embeddings(input_ids)
  616. if use_cache and past_key_values is None:
  617. past_key_values = DynamicCache(config=self.config)
  618. # Compute alibi tensor: check build_alibi_tensor documentation
  619. alibi = None
  620. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  621. batch_size, seq_length, _ = inputs_embeds.shape
  622. if self.use_alibi:
  623. mask = (
  624. torch.ones(
  625. (batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long
  626. )
  627. if attention_mask is None
  628. else attention_mask
  629. )
  630. alibi = build_alibi_tensor(mask, self.num_heads, dtype=inputs_embeds.dtype)
  631. if position_ids is None:
  632. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  633. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  634. position_ids = position_ids.unsqueeze(0)
  635. causal_mask = create_causal_mask(
  636. config=self.config,
  637. inputs_embeds=inputs_embeds,
  638. attention_mask=attention_mask,
  639. past_key_values=past_key_values,
  640. # Force mask creation for alibi
  641. and_mask_function=lambda *args: torch.tensor(True, dtype=torch.bool),
  642. )
  643. if alibi is not None and causal_mask is not None and causal_mask.ndim == 4:
  644. min_dtype = torch.finfo(inputs_embeds.dtype).min
  645. # Only using non-bool mask for alibi
  646. if causal_mask.dtype == torch.bool:
  647. causal_mask = torch.where(
  648. causal_mask, torch.tensor(0.0, device=causal_mask.device, dtype=inputs_embeds.dtype), min_dtype
  649. )
  650. # We take care to integrate alibi bias in the causal_mask here
  651. alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
  652. causal_mask = torch.masked_fill(
  653. alibi / math.sqrt(self.config.hidden_size // self.num_heads),
  654. causal_mask < -1,
  655. min_dtype,
  656. )
  657. hidden_states = inputs_embeds
  658. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  659. all_self_attentions = () if output_attentions else None
  660. all_hidden_states = () if output_hidden_states else None
  661. for i, block in enumerate(self.h):
  662. if output_hidden_states:
  663. all_hidden_states = all_hidden_states + (hidden_states,)
  664. outputs = block(
  665. hidden_states,
  666. layer_past=past_key_values,
  667. attention_mask=causal_mask,
  668. position_ids=position_ids,
  669. use_cache=use_cache,
  670. output_attentions=output_attentions,
  671. alibi=alibi,
  672. position_embeddings=position_embeddings,
  673. )
  674. hidden_states = outputs[0]
  675. if output_attentions:
  676. all_self_attentions = all_self_attentions + (outputs[1],)
  677. # Add last hidden state
  678. hidden_states = self.ln_f(hidden_states)
  679. if output_hidden_states:
  680. all_hidden_states = all_hidden_states + (hidden_states,)
  681. if not return_dict:
  682. return tuple(
  683. v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
  684. )
  685. return BaseModelOutputWithPastAndCrossAttentions(
  686. last_hidden_state=hidden_states,
  687. past_key_values=past_key_values,
  688. hidden_states=all_hidden_states,
  689. attentions=all_self_attentions,
  690. )
  691. @auto_docstring(
  692. custom_intro="""
  693. The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
  694. """
  695. )
  696. class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin):
  697. _tied_weights_keys = {"lm_head.weight": "transformer.word_embeddings.weight"}
  698. def __init__(self, config: FalconConfig):
  699. super().__init__(config)
  700. self.transformer = FalconModel(config)
  701. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  702. # Initialize weights and apply final processing
  703. self.post_init()
  704. def set_output_embeddings(self, new_embeddings: torch.Tensor):
  705. self.lm_head = new_embeddings
  706. @auto_docstring
  707. def forward(
  708. self,
  709. input_ids: torch.LongTensor | None = None,
  710. past_key_values: Cache | None = None,
  711. attention_mask: torch.Tensor | None = None,
  712. position_ids: torch.LongTensor | None = None,
  713. inputs_embeds: torch.Tensor | None = None,
  714. labels: torch.Tensor | None = None,
  715. use_cache: bool | None = None,
  716. output_attentions: bool | None = None,
  717. output_hidden_states: bool | None = None,
  718. return_dict: bool | None = None,
  719. logits_to_keep: int | torch.Tensor = 0,
  720. **kwargs,
  721. ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
  722. r"""
  723. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  724. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  725. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  726. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  727. `input_ids`.
  728. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  729. [`PreTrainedTokenizer.__call__`] for details.
  730. [What are input IDs?](../glossary#input-ids)
  731. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  732. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  733. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  734. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  735. """
  736. return_dict = return_dict if return_dict is not None else self.config.return_dict
  737. transformer_outputs = self.transformer(
  738. input_ids,
  739. past_key_values=past_key_values,
  740. attention_mask=attention_mask,
  741. position_ids=position_ids,
  742. inputs_embeds=inputs_embeds,
  743. use_cache=use_cache,
  744. output_attentions=output_attentions,
  745. output_hidden_states=output_hidden_states,
  746. return_dict=return_dict,
  747. )
  748. hidden_states = transformer_outputs[0]
  749. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  750. lm_logits = self.lm_head(hidden_states[:, slice_indices, :])
  751. loss = None
  752. if labels is not None:
  753. loss = self.loss_function(
  754. lm_logits,
  755. labels,
  756. vocab_size=self.config.vocab_size,
  757. **kwargs,
  758. )
  759. if not return_dict:
  760. output = (lm_logits,) + transformer_outputs[1:]
  761. return ((loss,) + output) if loss is not None else output
  762. return CausalLMOutputWithCrossAttentions(
  763. loss=loss,
  764. logits=lm_logits,
  765. past_key_values=transformer_outputs.past_key_values,
  766. hidden_states=transformer_outputs.hidden_states,
  767. attentions=transformer_outputs.attentions,
  768. )
  769. @auto_docstring(
  770. custom_intro="""
  771. The Falcon Model transformer with a sequence classification head on top (linear layer).
  772. [`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  773. (e.g. GPT-1) do.
  774. Since it does classification on the last token, it requires to know the position of the last token. If a
  775. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  776. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  777. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  778. each row of the batch).
  779. """
  780. )
  781. class FalconForSequenceClassification(FalconPreTrainedModel):
  782. def __init__(self, config: FalconConfig):
  783. super().__init__(config)
  784. self.num_labels = config.num_labels
  785. self.transformer = FalconModel(config)
  786. self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
  787. # Initialize weights and apply final processing
  788. self.post_init()
  789. @auto_docstring
  790. def forward(
  791. self,
  792. input_ids: torch.LongTensor | None = None,
  793. past_key_values: Cache | None = None,
  794. attention_mask: torch.Tensor | None = None,
  795. inputs_embeds: torch.Tensor | None = None,
  796. labels: torch.Tensor | None = None,
  797. use_cache: bool | None = None,
  798. output_attentions: bool | None = None,
  799. output_hidden_states: bool | None = None,
  800. return_dict: bool | None = None,
  801. **kwargs,
  802. ) -> tuple[torch.Tensor] | SequenceClassifierOutputWithPast:
  803. r"""
  804. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  805. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  806. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  807. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  808. `input_ids`.
  809. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  810. [`PreTrainedTokenizer.__call__`] for details.
  811. [What are input IDs?](../glossary#input-ids)
  812. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  813. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  814. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  815. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  816. """
  817. return_dict = return_dict if return_dict is not None else self.config.return_dict
  818. transformer_outputs = self.transformer(
  819. input_ids,
  820. past_key_values=past_key_values,
  821. attention_mask=attention_mask,
  822. inputs_embeds=inputs_embeds,
  823. use_cache=use_cache,
  824. output_attentions=output_attentions,
  825. output_hidden_states=output_hidden_states,
  826. return_dict=return_dict,
  827. )
  828. hidden_states = transformer_outputs[0]
  829. logits = self.score(hidden_states)
  830. if input_ids is not None:
  831. batch_size = input_ids.shape[0]
  832. else:
  833. batch_size = inputs_embeds.shape[0]
  834. if self.config.pad_token_id is None and batch_size != 1:
  835. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  836. if self.config.pad_token_id is None:
  837. last_non_pad_token = -1
  838. elif input_ids is not None:
  839. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  840. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  841. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  842. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  843. else:
  844. last_non_pad_token = -1
  845. logger.warning_once(
  846. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  847. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  848. )
  849. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  850. loss = None
  851. if labels is not None:
  852. if self.config.problem_type is None:
  853. if self.num_labels == 1:
  854. self.config.problem_type = "regression"
  855. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  856. self.config.problem_type = "single_label_classification"
  857. else:
  858. self.config.problem_type = "multi_label_classification"
  859. if self.config.problem_type == "regression":
  860. loss_fct = MSELoss()
  861. if self.num_labels == 1:
  862. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  863. else:
  864. loss = loss_fct(pooled_logits, labels)
  865. elif self.config.problem_type == "single_label_classification":
  866. loss_fct = CrossEntropyLoss()
  867. loss = loss_fct(pooled_logits, labels)
  868. elif self.config.problem_type == "multi_label_classification":
  869. loss_fct = BCEWithLogitsLoss()
  870. loss = loss_fct(pooled_logits, labels)
  871. if not return_dict:
  872. output = (pooled_logits,) + transformer_outputs[1:]
  873. return ((loss,) + output) if loss is not None else output
  874. return SequenceClassifierOutputWithPast(
  875. loss=loss,
  876. logits=pooled_logits,
  877. past_key_values=transformer_outputs.past_key_values,
  878. hidden_states=transformer_outputs.hidden_states,
  879. attentions=transformer_outputs.attentions,
  880. )
  881. @auto_docstring
  882. class FalconForTokenClassification(FalconPreTrainedModel):
  883. def __init__(self, config: FalconConfig):
  884. super().__init__(config)
  885. self.num_labels = config.num_labels
  886. self.transformer = FalconModel(config)
  887. if getattr(config, "classifier_dropout", None) is not None:
  888. classifier_dropout = config.classifier_dropout
  889. elif getattr(config, "hidden_dropout", None) is not None:
  890. classifier_dropout = config.hidden_dropout
  891. else:
  892. classifier_dropout = 0.1
  893. self.dropout = nn.Dropout(classifier_dropout)
  894. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  895. # Initialize weights and apply final processing
  896. self.post_init()
  897. @auto_docstring
  898. def forward(
  899. self,
  900. input_ids: torch.LongTensor | None = None,
  901. past_key_values: Cache | None = None,
  902. attention_mask: torch.Tensor | None = None,
  903. inputs_embeds: torch.Tensor | None = None,
  904. labels: torch.Tensor | None = None,
  905. use_cache: bool | None = None,
  906. output_attentions: bool | None = None,
  907. output_hidden_states: bool | None = None,
  908. return_dict: bool | None = None,
  909. **kwargs,
  910. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  911. r"""
  912. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  913. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  914. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  915. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  916. `input_ids`.
  917. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  918. [`PreTrainedTokenizer.__call__`] for details.
  919. [What are input IDs?](../glossary#input-ids)
  920. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  921. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  922. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  923. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  924. """
  925. return_dict = return_dict if return_dict is not None else self.config.return_dict
  926. transformer_outputs = self.transformer(
  927. input_ids,
  928. past_key_values=past_key_values,
  929. attention_mask=attention_mask,
  930. inputs_embeds=inputs_embeds,
  931. use_cache=use_cache,
  932. output_attentions=output_attentions,
  933. output_hidden_states=output_hidden_states,
  934. return_dict=return_dict,
  935. )
  936. hidden_states = transformer_outputs[0]
  937. hidden_states = self.dropout(hidden_states)
  938. logits = self.classifier(hidden_states)
  939. loss = None
  940. if labels is not None:
  941. batch_size, seq_length = labels.shape
  942. loss_fct = CrossEntropyLoss()
  943. loss = loss_fct(
  944. logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
  945. )
  946. if not return_dict:
  947. output = (logits,) + transformer_outputs[2:]
  948. return ((loss,) + output) if loss is not None else output
  949. return TokenClassifierOutput(
  950. loss=loss,
  951. logits=logits,
  952. hidden_states=transformer_outputs.hidden_states,
  953. attentions=transformer_outputs.attentions,
  954. )
  955. @auto_docstring
  956. class FalconForQuestionAnswering(FalconPreTrainedModel):
  957. def __init__(self, config):
  958. super().__init__(config)
  959. self.transformer = FalconModel(config)
  960. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  961. # Initialize weights and apply final processing
  962. self.post_init()
  963. @auto_docstring
  964. def forward(
  965. self,
  966. input_ids: torch.LongTensor | None = None,
  967. attention_mask: torch.FloatTensor | None = None,
  968. inputs_embeds: torch.FloatTensor | None = None,
  969. start_positions: torch.LongTensor | None = None,
  970. end_positions: torch.LongTensor | None = None,
  971. output_attentions: bool | None = None,
  972. output_hidden_states: bool | None = None,
  973. return_dict: bool | None = None,
  974. **kwargs,
  975. ) -> tuple | QuestionAnsweringModelOutput:
  976. r"""
  977. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  978. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  979. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  980. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  981. `input_ids`.
  982. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  983. [`PreTrainedTokenizer.__call__`] for details.
  984. [What are input IDs?](../glossary#input-ids)
  985. """
  986. return_dict = return_dict if return_dict is not None else self.config.return_dict
  987. outputs = self.transformer(
  988. input_ids,
  989. attention_mask=attention_mask,
  990. inputs_embeds=inputs_embeds,
  991. output_attentions=output_attentions,
  992. output_hidden_states=output_hidden_states,
  993. return_dict=return_dict,
  994. )
  995. sequence_output = outputs[0]
  996. logits = self.qa_outputs(sequence_output)
  997. start_logits, end_logits = logits.split(1, dim=-1)
  998. start_logits = start_logits.squeeze(-1).contiguous()
  999. end_logits = end_logits.squeeze(-1).contiguous()
  1000. total_loss = None
  1001. if start_positions is not None and end_positions is not None:
  1002. # If we are on multi-GPU, split add a dimension
  1003. if len(start_positions.size()) > 1:
  1004. start_positions = start_positions.squeeze(-1)
  1005. if len(end_positions.size()) > 1:
  1006. end_positions = end_positions.squeeze(-1)
  1007. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1008. ignored_index = start_logits.size(1)
  1009. start_positions = start_positions.clamp(0, ignored_index)
  1010. end_positions = end_positions.clamp(0, ignored_index)
  1011. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1012. start_loss = loss_fct(start_logits, start_positions)
  1013. end_loss = loss_fct(end_logits, end_positions)
  1014. total_loss = (start_loss + end_loss) / 2
  1015. if not return_dict:
  1016. output = (start_logits, end_logits) + outputs[2:]
  1017. return ((total_loss,) + output) if total_loss is not None else output
  1018. return QuestionAnsweringModelOutput(
  1019. loss=total_loss,
  1020. start_logits=start_logits,
  1021. end_logits=end_logits,
  1022. hidden_states=outputs.hidden_states,
  1023. attentions=outputs.attentions,
  1024. )
# Public classes exported by this module (also what `from ... import *` brings in).
__all__ = [
    "FalconForCausalLM",
    "FalconModel",
    "FalconPreTrainedModel",
    "FalconForSequenceClassification",
    "FalconForTokenClassification",
    "FalconForQuestionAnswering",
]