modeling_bloom.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987
  1. # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch BLOOM model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
  19. from torch.nn import functional as F
  20. from ...cache_utils import Cache, DynamicCache, StaticCache
  21. from ...generation import GenerationMixin
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import (
  25. BaseModelOutputWithPastAndCrossAttentions,
  26. CausalLMOutputWithCrossAttentions,
  27. QuestionAnsweringModelOutput,
  28. SequenceClassifierOutputWithPast,
  29. TokenClassifierOutput,
  30. )
  31. from ...modeling_utils import PreTrainedModel
  32. from ...utils import (
  33. auto_docstring,
  34. logging,
  35. )
  36. from .configuration_bloom import BloomConfig
  37. logger = logging.get_logger(__name__)
  38. def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
  39. """
  40. Link to paper: https://huggingface.co/papers/2108.12409 Alibi tensor is not causal as the original paper mentions, it
  41. relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
  42. `softmax(l+a) = softmax(l)`. Based on
  43. https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
  44. TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
  45. Args:
  46. Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
  47. attention_mask (`torch.Tensor`):
  48. Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
  49. num_heads (`int`):
  50. number of heads
  51. dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
  52. dtype of the output tensor
  53. """
  54. batch_size, seq_length = attention_mask.shape
  55. closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
  56. base = torch.tensor(
  57. 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
  58. )
  59. powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
  60. slopes = torch.pow(base, powers)
  61. if closest_power_of_2 != num_heads:
  62. extra_base = torch.tensor(
  63. 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
  64. )
  65. num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
  66. extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
  67. slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
  68. # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
  69. # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
  70. # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
  71. # => the query_length dimension will then be broadcasted correctly
  72. # This is more or less identical to T5's relative position bias:
  73. # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
  74. arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
  75. alibi = slopes[..., None] * arange_tensor
  76. return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
  77. def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
  78. """
  79. Dropout add function
  80. Args:
  81. x (`torch.tensor`):
  82. input tensor
  83. residual (`torch.tensor`):
  84. residual tensor
  85. prob (`float`):
  86. dropout probability
  87. training (`bool`):
  88. training mode
  89. """
  90. out = F.dropout(x, p=prob, training=training)
  91. out = residual + out
  92. return out
  93. def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
  94. """
  95. Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
  96. make the model jitable.
  97. Args:
  98. x (`torch.tensor`):
  99. input hidden states
  100. """
  101. return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
  102. def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
  103. """
  104. gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
  105. 0.3989423 * x * torch.exp(-0.5 * x * x)
  106. Args:
  107. g (`torch.tensor`):
  108. gradient output tensor
  109. x (`torch.tensor`):
  110. input tensor
  111. """
  112. x = x[0] # x is a tuple of 1 element, needs to unpack it first
  113. tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
  114. # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
  115. ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
  116. return ff * g
  117. class GeLUFunction(torch.autograd.Function):
  118. @staticmethod
  119. def forward(ctx, input: torch.Tensor) -> torch.Tensor:
  120. ctx.save_for_backward(input)
  121. return bloom_gelu_forward(input)
  122. @staticmethod
  123. def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
  124. input = ctx.saved_tensors
  125. tmp = bloom_gelu_back(grad_output, input)
  126. return tmp
  127. class BloomGelu(nn.Module):
  128. """
  129. Partly copied from Megatron-DeepSpeed code and adapted for our needs
  130. """
  131. def __init__(self):
  132. super().__init__()
  133. def forward(self, x: torch.Tensor) -> torch.Tensor:
  134. return GeLUFunction.apply(x)
  135. class BloomAttention(nn.Module):
  136. def __init__(self, config: BloomConfig, layer_idx: int | None = None):
  137. super().__init__()
  138. self.pretraining_tp = config.pretraining_tp
  139. self.slow_but_exact = config.slow_but_exact
  140. self.hidden_size = config.hidden_size
  141. self.num_heads = config.n_head
  142. self.head_dim = self.hidden_size // self.num_heads
  143. self.split_size = self.hidden_size
  144. self.hidden_dropout = config.hidden_dropout
  145. if self.head_dim * self.num_heads != self.hidden_size:
  146. raise ValueError(
  147. f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
  148. f" {self.num_heads})."
  149. )
  150. # Layer-wise attention scaling
  151. self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
  152. self.beta = 1.0
  153. self.layer_idx = layer_idx
  154. if layer_idx is None:
  155. logger.warning_once(
  156. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  157. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  158. "when creating this class."
  159. )
  160. self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
  161. self.dense = nn.Linear(self.hidden_size, self.hidden_size)
  162. self.attention_dropout = nn.Dropout(config.attention_dropout)
  163. def _reshape(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  164. """
  165. Split the last dimension into (num_heads, head_dim) and reshapes to (bs, heads, len, dim) shape
  166. without making any copies, results share same memory storage as `fused_qkv`
  167. Args:
  168. fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]
  169. Returns:
  170. query: [batch_size, num_heads, seq_length, head_dim]
  171. key: [batch_size, num_heads, seq_length, head_dim]
  172. value: [batch_size, num_heads, seq_length, head_dim]
  173. """
  174. batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
  175. fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
  176. query_layer = fused_qkv[..., 0, :].transpose(1, 2)
  177. key_layer = fused_qkv[..., 1, :].transpose(1, 2)
  178. value_layer = fused_qkv[..., 2, :].transpose(1, 2)
  179. return query_layer, key_layer, value_layer
  180. def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
  181. """
  182. Merge heads together over the last dimension
  183. Args:
  184. x (`torch.tensor`): [batch_size * num_heads, seq_length, head_dim]
  185. Returns:
  186. torch.tensor: [batch_size, seq_length, num_heads * head_dim]
  187. """
  188. # What we want to achieve is:
  189. # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
  190. batch_size_and_num_heads, seq_length, _ = x.shape
  191. batch_size = batch_size_and_num_heads // self.num_heads
  192. # First view to decompose the batch size
  193. # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
  194. x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
  195. # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
  196. x = x.permute(0, 2, 1, 3)
  197. # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
  198. return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
  199. def forward(
  200. self,
  201. hidden_states: torch.Tensor,
  202. residual: torch.Tensor,
  203. alibi: torch.Tensor,
  204. attention_mask: torch.Tensor,
  205. layer_past: Cache | None = None,
  206. use_cache: bool = False,
  207. output_attentions: bool = False,
  208. **kwargs,
  209. ):
  210. batch_size, q_length, _ = hidden_states.shape
  211. fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
  212. # 3 x [batch_size, num_heads, seq_length, head_dim]
  213. query_layer, key_layer, value_layer = self._reshape(fused_qkv)
  214. if layer_past is not None:
  215. key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx)
  216. # reshape qkv for further computations
  217. query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
  218. key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
  219. value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
  220. # [batch_size * num_heads, q_length, kv_length]
  221. attention_scores = alibi.baddbmm(
  222. batch1=query_layer,
  223. batch2=key_layer,
  224. beta=self.beta,
  225. alpha=self.inv_norm_factor,
  226. )
  227. # change view to [batch_size, num_heads, q_length, kv_length]
  228. attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
  229. if attention_mask is not None:
  230. attn_weights = attn_weights + attention_mask
  231. # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
  232. attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype)
  233. # [batch_size, num_heads, q_length, kv_length]
  234. attention_probs = self.attention_dropout(attention_probs)
  235. # change view [batch_size x num_heads, q_length, kv_length]
  236. attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)
  237. # matmul: [batch_size * num_heads, q_length, head_dim]
  238. context_layer = torch.bmm(attention_probs_reshaped, value_layer)
  239. # change view [batch_size, q_length, num_heads * head_dim]
  240. context_layer = self._merge_heads(context_layer)
  241. # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
  242. if self.pretraining_tp > 1 and self.slow_but_exact:
  243. slices = self.hidden_size / self.pretraining_tp
  244. output_tensor = torch.zeros_like(context_layer)
  245. for i in range(self.pretraining_tp):
  246. output_tensor = output_tensor + F.linear(
  247. context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
  248. self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
  249. )
  250. else:
  251. output_tensor = self.dense(context_layer)
  252. output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
  253. return output_tensor, attention_probs
  254. class BloomMLP(nn.Module):
  255. def __init__(self, config: BloomConfig):
  256. super().__init__()
  257. hidden_size = config.hidden_size
  258. self.pretraining_tp = config.pretraining_tp
  259. self.slow_but_exact = config.slow_but_exact
  260. self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
  261. self.gelu_impl = BloomGelu()
  262. self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
  263. self.hidden_dropout = config.hidden_dropout
  264. def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
  265. hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
  266. if self.pretraining_tp > 1 and self.slow_but_exact:
  267. intermediate_output = torch.zeros_like(residual)
  268. slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
  269. for i in range(self.pretraining_tp):
  270. intermediate_output = intermediate_output + F.linear(
  271. hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
  272. self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
  273. )
  274. else:
  275. intermediate_output = self.dense_4h_to_h(hidden_states)
  276. output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
  277. return output
  278. class BloomBlock(GradientCheckpointingLayer):
  279. def __init__(self, config: BloomConfig, layer_idx: int | None = None):
  280. super().__init__()
  281. hidden_size = config.hidden_size
  282. self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  283. self.num_heads = config.n_head
  284. self.self_attention = BloomAttention(config, layer_idx)
  285. self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  286. self.mlp = BloomMLP(config)
  287. self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
  288. self.hidden_dropout = config.hidden_dropout
  289. def forward(
  290. self,
  291. hidden_states: torch.Tensor,
  292. alibi: torch.Tensor,
  293. attention_mask: torch.Tensor,
  294. layer_past: Cache | None = None,
  295. use_cache: bool = False,
  296. output_attentions: bool = False,
  297. **kwargs,
  298. ):
  299. # hidden_states: [batch_size, seq_length, hidden_size]
  300. # Layer norm at the beginning of the transformer layer.
  301. layernorm_output = self.input_layernorm(hidden_states)
  302. # Layer norm post the self attention.
  303. if self.apply_residual_connection_post_layernorm:
  304. residual = layernorm_output
  305. else:
  306. residual = hidden_states
  307. # Self attention.
  308. attention_output, attn_weights = self.self_attention(
  309. layernorm_output,
  310. residual,
  311. layer_past=layer_past,
  312. attention_mask=attention_mask,
  313. alibi=alibi,
  314. use_cache=use_cache,
  315. output_attentions=output_attentions,
  316. )
  317. layernorm_output = self.post_attention_layernorm(attention_output)
  318. # Get residual
  319. if self.apply_residual_connection_post_layernorm:
  320. residual = layernorm_output
  321. else:
  322. residual = attention_output
  323. # MLP.
  324. output = self.mlp(layernorm_output, residual)
  325. return output, attn_weights # hidden_states, attentions
  326. @auto_docstring
  327. class BloomPreTrainedModel(PreTrainedModel):
  328. config: BloomConfig
  329. base_model_prefix = "transformer"
  330. supports_gradient_checkpointing = True
  331. _no_split_modules = ["BloomBlock"]
  332. _skip_keys_device_placement = "past_key_values"
  333. _can_compile_fullgraph = True
  334. @auto_docstring
  335. class BloomModel(BloomPreTrainedModel):
  336. def __init__(self, config: BloomConfig):
  337. super().__init__(config)
  338. self.embed_dim = config.hidden_size
  339. self.num_heads = config.n_head
  340. # Embedding + LN Embedding
  341. self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
  342. self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  343. # Transformer blocks
  344. self.h = nn.ModuleList([BloomBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  345. # Final Layer Norm
  346. self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  347. self.gradient_checkpointing = False
  348. # Initialize weights and apply final processing
  349. self.post_init()
  350. def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
  351. return build_alibi_tensor(attention_mask, num_heads, dtype)
  352. def get_input_embeddings(self):
  353. return self.word_embeddings
  354. def set_input_embeddings(self, new_embeddings: torch.Tensor):
  355. self.word_embeddings = new_embeddings
  356. @auto_docstring
  357. def forward(
  358. self,
  359. input_ids: torch.LongTensor | None = None,
  360. past_key_values: Cache | None = None,
  361. attention_mask: torch.Tensor | None = None,
  362. inputs_embeds: torch.LongTensor | None = None,
  363. use_cache: bool | None = None,
  364. output_attentions: bool | None = None,
  365. output_hidden_states: bool | None = None,
  366. return_dict: bool | None = None,
  367. **kwargs,
  368. ) -> tuple[torch.Tensor, ...] | BaseModelOutputWithPastAndCrossAttentions:
  369. r"""
  370. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  371. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  372. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  373. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  374. `input_ids`.
  375. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  376. [`PreTrainedTokenizer.__call__`] for details.
  377. [What are input IDs?](../glossary#input-ids)
  378. """
  379. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  380. output_hidden_states = (
  381. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  382. )
  383. use_cache = use_cache if use_cache is not None else self.config.use_cache
  384. return_dict = return_dict if return_dict is not None else self.config.return_dict
  385. if (input_ids is None) ^ (inputs_embeds is not None):
  386. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  387. if self.gradient_checkpointing and self.training and use_cache:
  388. logger.warning_once(
  389. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  390. )
  391. use_cache = False
  392. if inputs_embeds is None:
  393. inputs_embeds = self.word_embeddings(input_ids)
  394. if use_cache and past_key_values is None:
  395. past_key_values = DynamicCache(config=self.config)
  396. batch_size, seq_length, _ = inputs_embeds.shape
  397. past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  398. seq_length_with_past = seq_length + past_length
  399. hidden_states = self.word_embeddings_layernorm(inputs_embeds)
  400. all_self_attentions = () if output_attentions else None
  401. all_hidden_states = () if output_hidden_states else None
  402. # Compute alibi tensor: check build_alibi_tensor documentation
  403. if attention_mask is None:
  404. attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
  405. else:
  406. attention_mask = attention_mask.to(hidden_states.device)
  407. alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
  408. causal_mask = create_causal_mask(
  409. config=self.config,
  410. inputs_embeds=inputs_embeds,
  411. attention_mask=attention_mask,
  412. past_key_values=past_key_values,
  413. )
  414. for i, block in enumerate(self.h):
  415. if output_hidden_states:
  416. all_hidden_states = all_hidden_states + (hidden_states,)
  417. outputs = block(
  418. hidden_states,
  419. layer_past=past_key_values,
  420. attention_mask=causal_mask,
  421. use_cache=use_cache,
  422. output_attentions=output_attentions,
  423. alibi=alibi,
  424. )
  425. hidden_states = outputs[0]
  426. if output_attentions:
  427. all_self_attentions = all_self_attentions + (outputs[1],)
  428. # Add last hidden state
  429. hidden_states = self.ln_f(hidden_states)
  430. if output_hidden_states:
  431. all_hidden_states = all_hidden_states + (hidden_states,)
  432. if not return_dict:
  433. return tuple(
  434. v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
  435. )
  436. return BaseModelOutputWithPastAndCrossAttentions(
  437. last_hidden_state=hidden_states,
  438. past_key_values=past_key_values,
  439. hidden_states=all_hidden_states,
  440. attentions=all_self_attentions,
  441. )
  442. @auto_docstring(
  443. custom_intro="""
  444. The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
  445. embeddings).
  446. """
  447. )
  448. class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
  449. _tied_weights_keys = {"lm_head.weight": "transformer.word_embeddings.weight"}
  450. def __init__(self, config: BloomConfig):
  451. super().__init__(config)
  452. self.transformer = BloomModel(config)
  453. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  454. # Initialize weights and apply final processing
  455. self.post_init()
  456. def set_output_embeddings(self, new_embeddings: torch.Tensor):
  457. self.lm_head = new_embeddings
  458. def prepare_inputs_for_generation(
  459. self,
  460. input_ids,
  461. past_key_values=None,
  462. attention_mask=None,
  463. inputs_embeds=None,
  464. use_cache=True,
  465. is_first_iteration=False,
  466. **kwargs,
  467. ):
  468. # Overwritten because of the fixed-shape attention mask creation
  469. model_inputs = super().prepare_inputs_for_generation(
  470. input_ids,
  471. past_key_values=past_key_values,
  472. attention_mask=attention_mask,
  473. inputs_embeds=inputs_embeds,
  474. use_cache=use_cache,
  475. is_first_iteration=is_first_iteration,
  476. **kwargs,
  477. )
  478. # This part differs from other models because BLOOM needs a 2D mask to construct alibi tensor
  479. # The only difference is the usage of 2D instead of 4D mask, but the shape will be static
  480. if isinstance(past_key_values, StaticCache) and attention_mask is not None:
  481. target_length = past_key_values.get_max_cache_shape()
  482. batch_size, seq_length = attention_mask.shape
  483. diff = target_length - seq_length
  484. new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype)
  485. attention_mask = torch.cat([attention_mask, new_attn_mask], dim=-1)
  486. model_inputs["attention_mask"] = attention_mask
  487. return model_inputs
  488. @auto_docstring
  489. def forward(
  490. self,
  491. input_ids: torch.LongTensor | None = None,
  492. past_key_values: Cache | None = None,
  493. attention_mask: torch.Tensor | None = None,
  494. inputs_embeds: torch.Tensor | None = None,
  495. labels: torch.Tensor | None = None,
  496. use_cache: bool | None = None,
  497. output_attentions: bool | None = None,
  498. output_hidden_states: bool | None = None,
  499. return_dict: bool | None = None,
  500. logits_to_keep: int | torch.Tensor = 0,
  501. **kwargs,
  502. ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
  503. r"""
  504. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  505. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  506. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  507. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  508. `input_ids`.
  509. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  510. [`PreTrainedTokenizer.__call__`] for details.
  511. [What are input IDs?](../glossary#input-ids)
  512. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  513. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  514. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  515. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  516. """
  517. return_dict = return_dict if return_dict is not None else self.config.return_dict
  518. transformer_outputs = self.transformer(
  519. input_ids,
  520. past_key_values=past_key_values,
  521. attention_mask=attention_mask,
  522. inputs_embeds=inputs_embeds,
  523. use_cache=use_cache,
  524. output_attentions=output_attentions,
  525. output_hidden_states=output_hidden_states,
  526. return_dict=return_dict,
  527. )
  528. hidden_states = transformer_outputs[0]
  529. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  530. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  531. logits = self.lm_head(hidden_states[:, slice_indices, :])
  532. loss = None
  533. if labels is not None:
  534. loss = self.loss_function(
  535. logits,
  536. labels,
  537. vocab_size=self.config.vocab_size,
  538. num_items_in_batch=kwargs.get("num_items_in_batch"),
  539. )
  540. if not return_dict:
  541. output = (logits,) + transformer_outputs[1:]
  542. return ((loss,) + output) if loss is not None else output
  543. return CausalLMOutputWithCrossAttentions(
  544. loss=loss,
  545. logits=logits,
  546. past_key_values=transformer_outputs.past_key_values,
  547. hidden_states=transformer_outputs.hidden_states,
  548. attentions=transformer_outputs.attentions,
  549. )
  550. @auto_docstring(
  551. custom_intro="""
  552. The Bloom Model transformer with a sequence classification head on top (linear layer).
  553. [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  554. (e.g. GPT-1) do.
  555. Since it does classification on the last token, it requires to know the position of the last token. If a
  556. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  557. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  558. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  559. each row of the batch).
  560. """
  561. )
  562. class BloomForSequenceClassification(BloomPreTrainedModel):
  563. def __init__(self, config: BloomConfig):
  564. super().__init__(config)
  565. self.num_labels = config.num_labels
  566. self.transformer = BloomModel(config)
  567. self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
  568. # Initialize weights and apply final processing
  569. self.post_init()
  570. @auto_docstring
  571. def forward(
  572. self,
  573. input_ids: torch.LongTensor | None = None,
  574. past_key_values: Cache | None = None,
  575. attention_mask: torch.Tensor | None = None,
  576. inputs_embeds: torch.Tensor | None = None,
  577. labels: torch.Tensor | None = None,
  578. use_cache: bool | None = None,
  579. output_attentions: bool | None = None,
  580. output_hidden_states: bool | None = None,
  581. return_dict: bool | None = None,
  582. **kwargs,
  583. ) -> tuple[torch.Tensor] | SequenceClassifierOutputWithPast:
  584. r"""
  585. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  586. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  587. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  588. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  589. `input_ids`.
  590. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  591. [`PreTrainedTokenizer.__call__`] for details.
  592. [What are input IDs?](../glossary#input-ids)
  593. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  594. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  595. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  596. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  597. """
  598. return_dict = return_dict if return_dict is not None else self.config.return_dict
  599. transformer_outputs = self.transformer(
  600. input_ids,
  601. past_key_values=past_key_values,
  602. attention_mask=attention_mask,
  603. inputs_embeds=inputs_embeds,
  604. use_cache=use_cache,
  605. output_attentions=output_attentions,
  606. output_hidden_states=output_hidden_states,
  607. return_dict=return_dict,
  608. )
  609. hidden_states = transformer_outputs[0]
  610. logits = self.score(hidden_states)
  611. if input_ids is not None:
  612. batch_size = input_ids.shape[0]
  613. else:
  614. batch_size = inputs_embeds.shape[0]
  615. if self.config.pad_token_id is None and batch_size != 1:
  616. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  617. if self.config.pad_token_id is None:
  618. last_non_pad_token = -1
  619. elif input_ids is not None:
  620. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  621. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  622. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  623. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  624. else:
  625. last_non_pad_token = -1
  626. logger.warning_once(
  627. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  628. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  629. )
  630. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  631. loss = None
  632. if labels is not None:
  633. if self.config.problem_type is None:
  634. if self.num_labels == 1:
  635. self.config.problem_type = "regression"
  636. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  637. self.config.problem_type = "single_label_classification"
  638. else:
  639. self.config.problem_type = "multi_label_classification"
  640. if self.config.problem_type == "regression":
  641. loss_fct = MSELoss()
  642. if self.num_labels == 1:
  643. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  644. else:
  645. loss = loss_fct(pooled_logits, labels)
  646. elif self.config.problem_type == "single_label_classification":
  647. loss_fct = CrossEntropyLoss()
  648. loss = loss_fct(pooled_logits, labels)
  649. elif self.config.problem_type == "multi_label_classification":
  650. loss_fct = BCEWithLogitsLoss()
  651. loss = loss_fct(pooled_logits, labels)
  652. if not return_dict:
  653. output = (pooled_logits,) + transformer_outputs[1:]
  654. return ((loss,) + output) if loss is not None else output
  655. return SequenceClassifierOutputWithPast(
  656. loss=loss,
  657. logits=pooled_logits,
  658. past_key_values=transformer_outputs.past_key_values,
  659. hidden_states=transformer_outputs.hidden_states,
  660. attentions=transformer_outputs.attentions,
  661. )
  662. @auto_docstring
  663. class BloomForTokenClassification(BloomPreTrainedModel):
  664. def __init__(self, config: BloomConfig):
  665. super().__init__(config)
  666. self.num_labels = config.num_labels
  667. self.transformer = BloomModel(config)
  668. if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
  669. classifier_dropout = config.classifier_dropout
  670. elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
  671. classifier_dropout = config.hidden_dropout
  672. else:
  673. classifier_dropout = 0.1
  674. self.dropout = nn.Dropout(classifier_dropout)
  675. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  676. # Initialize weights and apply final processing
  677. self.post_init()
  678. @auto_docstring
  679. def forward(
  680. self,
  681. input_ids: torch.LongTensor | None = None,
  682. past_key_values: Cache | None = None,
  683. attention_mask: torch.Tensor | None = None,
  684. inputs_embeds: torch.Tensor | None = None,
  685. labels: torch.Tensor | None = None,
  686. use_cache: bool | None = None,
  687. output_attentions: bool | None = None,
  688. output_hidden_states: bool | None = None,
  689. return_dict: bool | None = None,
  690. **kwargs,
  691. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  692. r"""
  693. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  694. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  695. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  696. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  697. `input_ids`.
  698. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  699. [`PreTrainedTokenizer.__call__`] for details.
  700. [What are input IDs?](../glossary#input-ids)
  701. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  702. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  703. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  704. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  705. """
  706. return_dict = return_dict if return_dict is not None else self.config.return_dict
  707. transformer_outputs = self.transformer(
  708. input_ids,
  709. past_key_values=past_key_values,
  710. attention_mask=attention_mask,
  711. inputs_embeds=inputs_embeds,
  712. use_cache=use_cache,
  713. output_attentions=output_attentions,
  714. output_hidden_states=output_hidden_states,
  715. return_dict=return_dict,
  716. )
  717. hidden_states = transformer_outputs[0]
  718. hidden_states = self.dropout(hidden_states)
  719. logits = self.classifier(hidden_states)
  720. loss = None
  721. if labels is not None:
  722. # move labels to correct device
  723. labels = labels.to(logits.device)
  724. batch_size, seq_length = labels.shape
  725. loss_fct = CrossEntropyLoss()
  726. loss = loss_fct(
  727. logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
  728. )
  729. if not return_dict:
  730. output = (logits,) + transformer_outputs[2:]
  731. return ((loss,) + output) if loss is not None else output
  732. return TokenClassifierOutput(
  733. loss=loss,
  734. logits=logits,
  735. hidden_states=transformer_outputs.hidden_states,
  736. attentions=transformer_outputs.attentions,
  737. )
  738. @auto_docstring
  739. class BloomForQuestionAnswering(BloomPreTrainedModel):
  740. def __init__(self, config):
  741. super().__init__(config)
  742. self.transformer = BloomModel(config)
  743. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  744. # Initialize weights and apply final processing
  745. self.post_init()
  746. @auto_docstring
  747. def forward(
  748. self,
  749. input_ids: torch.LongTensor | None = None,
  750. attention_mask: torch.FloatTensor | None = None,
  751. inputs_embeds: torch.FloatTensor | None = None,
  752. start_positions: torch.LongTensor | None = None,
  753. end_positions: torch.LongTensor | None = None,
  754. output_attentions: bool | None = None,
  755. output_hidden_states: bool | None = None,
  756. return_dict: bool | None = None,
  757. **kwargs,
  758. ) -> tuple | QuestionAnsweringModelOutput:
  759. r"""
  760. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  761. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  762. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  763. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  764. `input_ids`.
  765. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  766. [`PreTrainedTokenizer.__call__`] for details.
  767. [What are input IDs?](../glossary#input-ids)
  768. """
  769. return_dict = return_dict if return_dict is not None else self.config.return_dict
  770. outputs = self.transformer(
  771. input_ids,
  772. attention_mask=attention_mask,
  773. inputs_embeds=inputs_embeds,
  774. output_attentions=output_attentions,
  775. output_hidden_states=output_hidden_states,
  776. return_dict=return_dict,
  777. )
  778. sequence_output = outputs[0]
  779. logits = self.qa_outputs(sequence_output)
  780. start_logits, end_logits = logits.split(1, dim=-1)
  781. start_logits = start_logits.squeeze(-1).contiguous()
  782. end_logits = end_logits.squeeze(-1).contiguous()
  783. total_loss = None
  784. if start_positions is not None and end_positions is not None:
  785. # If we are on multi-GPU, split add a dimension
  786. if len(start_positions.size()) > 1:
  787. start_positions = start_positions.squeeze(-1)
  788. if len(end_positions.size()) > 1:
  789. end_positions = end_positions.squeeze(-1)
  790. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  791. ignored_index = start_logits.size(1)
  792. start_positions = start_positions.clamp(0, ignored_index)
  793. end_positions = end_positions.clamp(0, ignored_index)
  794. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  795. start_loss = loss_fct(start_logits, start_positions)
  796. end_loss = loss_fct(end_logits, end_positions)
  797. total_loss = (start_loss + end_loss) / 2
  798. if not return_dict:
  799. output = (start_logits, end_logits) + outputs[2:]
  800. return ((total_loss,) + output) if total_loss is not None else output
  801. return QuestionAnsweringModelOutput(
  802. loss=total_loss,
  803. start_logits=start_logits,
  804. end_logits=end_logits,
  805. hidden_states=outputs.hidden_states,
  806. attentions=outputs.attentions,
  807. )
# Names exported by `from <this module> import *`.
__all__ = [
    "BloomForCausalLM",
    "BloomModel",
    "BloomPreTrainedModel",
    "BloomForSequenceClassification",
    "BloomForTokenClassification",
    "BloomForQuestionAnswering",
]