modeling_xglm.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. # Copyright 2021 The Fairseq Authors The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch XGLM model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  21. from ...generation import GenerationMixin
  22. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
  25. from ...modeling_utils import PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import TransformersKwargs, auto_docstring, logging
  28. from ...utils.generic import merge_with_config_defaults
  29. from ...utils.output_capturing import OutputRecorder, capture_outputs
  30. from .configuration_xglm import XGLMConfig
# Module-level logger, created through the library's own `logging` utilities and keyed on this module's name.
logger = logging.get_logger(__name__)
  32. # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->XGLM
  33. class XGLMScaledWordEmbedding(nn.Embedding):
  34. """
  35. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  36. """
  37. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float | None = 1.0):
  38. super().__init__(num_embeddings, embedding_dim, padding_idx)
  39. self.embed_scale = embed_scale
  40. def forward(self, input_ids: torch.Tensor):
  41. return super().forward(input_ids) * self.embed_scale
  42. class XGLMSinusoidalPositionalEmbedding(nn.Module):
  43. """This module produces sinusoidal positional embeddings of any length."""
  44. def __init__(self, num_positions: int, embedding_dim: int, padding_idx: int | None = None):
  45. super().__init__()
  46. self.offset = 2
  47. self.num_positions = num_positions
  48. self.embedding_dim = embedding_dim
  49. self.padding_idx = padding_idx
  50. self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
  51. def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: int | None = None):
  52. emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
  53. if hasattr(self, "weights"):
  54. # in forward put the weights on the correct dtype and device of the param
  55. emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
  56. self.register_buffer("weights", emb_weights, persistent=False)
  57. @staticmethod
  58. def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None):
  59. """
  60. Build sinusoidal embeddings.
  61. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
  62. "Attention Is All You Need".
  63. """
  64. half_dim = embedding_dim // 2
  65. emb = math.log(10000) / (half_dim - 1)
  66. emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
  67. emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
  68. emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
  69. if embedding_dim % 2 == 1:
  70. # zero pad
  71. emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
  72. if padding_idx is not None:
  73. emb[padding_idx, :] = 0
  74. return emb.to(torch.get_default_dtype())
  75. @torch.no_grad()
  76. def forward(self, position_ids: torch.Tensor | None = None, past_key_values_length: int = 0):
  77. bsz, seq_len = position_ids.size()
  78. position_ids = position_ids + self.offset
  79. max_pos = 2 + seq_len + past_key_values_length
  80. if max_pos > self.weights.size(0):
  81. self.make_weights(max_pos, self.embedding_dim, self.padding_idx)
  82. return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()
  83. class XGLMAttention(nn.Module):
  84. """Multi-headed attention from 'Attention Is All You Need' paper"""
  85. def __init__(
  86. self,
  87. embed_dim: int,
  88. num_heads: int,
  89. dropout: float | None = 0.0,
  90. is_decoder: bool | None = False,
  91. bias: bool | None = True,
  92. layer_idx: bool | None = None,
  93. ):
  94. super().__init__()
  95. self.embed_dim = embed_dim
  96. self.num_heads = num_heads
  97. self.dropout = dropout
  98. self.head_dim = embed_dim // num_heads
  99. if (self.head_dim * num_heads) != self.embed_dim:
  100. raise ValueError(
  101. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  102. f" and `num_heads`: {num_heads})."
  103. )
  104. self.scaling = self.head_dim**-0.5
  105. self.is_decoder = is_decoder
  106. self.layer_idx = layer_idx
  107. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  108. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  109. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  110. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  111. def forward(
  112. self,
  113. hidden_states: torch.Tensor,
  114. key_value_states: torch.Tensor | None = None,
  115. past_key_values: Cache | None = None,
  116. attention_mask: torch.Tensor | None = None,
  117. **kwargs: Unpack[TransformersKwargs],
  118. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  119. """Input shape: Batch x Time x Channel"""
  120. # if key_value_states are provided this layer is used as a cross-attention layer
  121. # for the decoder
  122. is_cross_attention = key_value_states is not None
  123. bsz, tgt_len, _ = hidden_states.size()
  124. src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
  125. # get query proj
  126. query_states = self.q_proj(hidden_states) * self.scaling
  127. is_updated = False
  128. if past_key_values is not None:
  129. if isinstance(past_key_values, EncoderDecoderCache):
  130. is_updated = past_key_values.is_updated.get(self.layer_idx)
  131. if is_cross_attention:
  132. # after the first generated id, we can subsequently re-use all key/value_states from cache
  133. curr_past_key_values = past_key_values.cross_attention_cache
  134. else:
  135. curr_past_key_values = past_key_values.self_attention_cache
  136. else:
  137. curr_past_key_values = past_key_values
  138. current_states = key_value_states if is_cross_attention else hidden_states
  139. if is_cross_attention and past_key_values is not None and is_updated:
  140. # reuse k,v, cross_attentions
  141. key_states = curr_past_key_values.layers[self.layer_idx].keys
  142. value_states = curr_past_key_values.layers[self.layer_idx].values
  143. else:
  144. key_states = self.k_proj(current_states)
  145. value_states = self.v_proj(current_states)
  146. key_states = key_states.view(bsz, src_len, -1, self.head_dim).transpose(1, 2)
  147. value_states = value_states.view(bsz, src_len, -1, self.head_dim).transpose(1, 2)
  148. if past_key_values is not None:
  149. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  150. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  151. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  152. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  153. past_key_values.is_updated[self.layer_idx] = True
  154. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  155. query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
  156. query_states = query_states.reshape(*proj_shape)
  157. key_states = key_states.reshape(*proj_shape)
  158. value_states = value_states.reshape(*proj_shape)
  159. src_len = key_states.size(1)
  160. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  161. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  162. raise ValueError(
  163. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  164. f" {attn_weights.size()}"
  165. )
  166. if attention_mask is not None:
  167. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  168. raise ValueError(
  169. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  170. )
  171. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  172. attn_weights = torch.max(
  173. attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
  174. )
  175. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  176. # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
  177. if attn_weights.dtype == torch.float16:
  178. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
  179. else:
  180. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  181. # this operation is a bit awkward, but it's required to
  182. # make sure that attn_weights keeps its gradient.
  183. # In order to do so, attn_weights have to be reshaped
  184. # twice and have to be reused in the following
  185. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  186. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  187. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  188. attn_output = torch.bmm(attn_probs, value_states)
  189. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  190. raise ValueError(
  191. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  192. f" {attn_output.size()}"
  193. )
  194. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  195. attn_output = attn_output.transpose(1, 2)
  196. # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
  197. # partitioned across GPUs when using tensor-parallelism.
  198. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  199. attn_output = self.out_proj(attn_output)
  200. return attn_output, attn_weights_reshaped
  201. class XGLMDecoderLayer(GradientCheckpointingLayer):
  202. def __init__(self, config: XGLMConfig, layer_idx=None):
  203. super().__init__()
  204. self.embed_dim = config.d_model
  205. self.self_attn = XGLMAttention(
  206. embed_dim=self.embed_dim,
  207. num_heads=config.attention_heads,
  208. dropout=config.attention_dropout,
  209. is_decoder=True,
  210. layer_idx=layer_idx,
  211. )
  212. self.dropout = config.dropout
  213. self.activation_fn = ACT2FN[config.activation_function]
  214. self.activation_dropout = config.activation_dropout
  215. if config.add_cross_attention:
  216. self.encoder_attn = XGLMAttention(
  217. embed_dim=self.embed_dim,
  218. num_heads=config.attention_heads,
  219. dropout=config.attention_dropout,
  220. is_decoder=True,
  221. layer_idx=layer_idx,
  222. )
  223. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  224. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  225. self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
  226. self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
  227. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  228. # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer.forward
  229. def forward(
  230. self,
  231. hidden_states: torch.Tensor,
  232. attention_mask: torch.Tensor | None = None,
  233. encoder_hidden_states: torch.Tensor | None = None,
  234. encoder_attention_mask: torch.Tensor | None = None,
  235. past_key_values: Cache | None = None,
  236. use_cache: bool | None = True,
  237. **kwargs: Unpack[TransformersKwargs],
  238. ) -> torch.Tensor:
  239. """
  240. Args:
  241. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  242. attention_mask (`torch.FloatTensor`): attention mask of size
  243. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  244. encoder_hidden_states (`torch.FloatTensor`):
  245. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  246. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  247. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  248. past_key_values (`Cache`): cached past key and value projection states
  249. """
  250. residual = hidden_states
  251. hidden_states = self.self_attn_layer_norm(hidden_states)
  252. # Self Attention
  253. hidden_states, _ = self.self_attn(
  254. hidden_states,
  255. past_key_values=past_key_values,
  256. attention_mask=attention_mask,
  257. **kwargs,
  258. )
  259. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  260. hidden_states = residual + hidden_states
  261. # Cross-Attention Block
  262. if encoder_hidden_states is not None:
  263. residual = hidden_states
  264. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  265. hidden_states, _ = self.encoder_attn(
  266. hidden_states,
  267. key_value_states=encoder_hidden_states,
  268. attention_mask=encoder_attention_mask,
  269. past_key_values=past_key_values,
  270. **kwargs,
  271. )
  272. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  273. hidden_states = residual + hidden_states
  274. # Fully Connected
  275. residual = hidden_states
  276. hidden_states = self.final_layer_norm(hidden_states)
  277. hidden_states = self.activation_fn(self.fc1(hidden_states))
  278. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  279. hidden_states = self.fc2(hidden_states)
  280. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  281. hidden_states = residual + hidden_states
  282. return hidden_states
  283. @auto_docstring
  284. class XGLMPreTrainedModel(PreTrainedModel):
  285. config: XGLMConfig
  286. base_model_prefix = "model"
  287. supports_gradient_checkpointing = True
  288. _no_split_modules = ["XGLMDecoderLayer"]
  289. def _init_weights(self, module):
  290. super()._init_weights(module)
  291. if isinstance(module, XGLMSinusoidalPositionalEmbedding):
  292. emb_weights = module.get_embedding(
  293. module.num_positions + module.offset, module.embedding_dim, module.padding_idx
  294. )
  295. init.copy_(module.weights, emb_weights)
  296. @auto_docstring
  297. class XGLMModel(XGLMPreTrainedModel):
  298. _can_record_outputs = {
  299. "hidden_states": XGLMDecoderLayer,
  300. "attentions": OutputRecorder(XGLMAttention, index=1, layer_name="self_attn"),
  301. "cross_attentions": OutputRecorder(XGLMAttention, index=1, layer_name="encoder_attn"),
  302. }
  303. def __init__(self, config: XGLMConfig):
  304. super().__init__(config)
  305. self.dropout = config.dropout
  306. self.layerdrop = config.layerdrop
  307. self.padding_idx = config.pad_token_id
  308. self.max_target_positions = config.max_position_embeddings
  309. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  310. self.embed_tokens = XGLMScaledWordEmbedding(
  311. config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
  312. )
  313. self.embed_positions = XGLMSinusoidalPositionalEmbedding(
  314. config.max_position_embeddings,
  315. config.d_model,
  316. config.pad_token_id,
  317. )
  318. self.layers = nn.ModuleList([XGLMDecoderLayer(config, layer_idx=i) for i in range(config.num_layers)])
  319. self.layer_norm = nn.LayerNorm(config.d_model)
  320. self.gradient_checkpointing = False
  321. # Initialize weights and apply final processing
  322. self.post_init()
  323. @merge_with_config_defaults
  324. @capture_outputs
  325. @auto_docstring
  326. def forward(
  327. self,
  328. input_ids: torch.Tensor | None = None,
  329. attention_mask: torch.Tensor | None = None,
  330. position_ids: torch.Tensor | None = None,
  331. encoder_hidden_states: torch.Tensor | None = None,
  332. encoder_attention_mask: torch.Tensor | None = None,
  333. past_key_values: Cache | None = None,
  334. inputs_embeds: torch.Tensor | None = None,
  335. use_cache: bool | None = None,
  336. **kwargs: Unpack[TransformersKwargs],
  337. ) -> tuple[torch.Tensor] | BaseModelOutputWithPastAndCrossAttentions:
  338. r"""
  339. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  340. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
  341. the decoder.
  342. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  343. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  344. selected in `[0, 1]`:
  345. - 1 for tokens that are **not masked**,
  346. - 0 for tokens that are **masked**.
  347. [What are attention masks?](../glossary#attention-mask)
  348. """
  349. if (input_ids is None) ^ (inputs_embeds is not None):
  350. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  351. if inputs_embeds is None:
  352. inputs_embeds = self.embed_tokens(input_ids)
  353. # initialize `past_key_values`
  354. if use_cache and past_key_values is None:
  355. past_key_values = (
  356. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  357. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  358. else DynamicCache(config=self.config)
  359. )
  360. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  361. attention_mask = create_causal_mask(
  362. config=self.config,
  363. inputs_embeds=inputs_embeds,
  364. attention_mask=attention_mask,
  365. past_key_values=past_key_values,
  366. )
  367. if position_ids is None:
  368. position_ids = torch.arange(
  369. past_key_values_length,
  370. inputs_embeds.shape[1] + past_key_values_length,
  371. dtype=torch.long,
  372. device=input_ids.device if input_ids is not None else inputs_embeds.device,
  373. )
  374. position_ids = position_ids.unsqueeze(0)
  375. # expand encoder attention mask
  376. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  377. encoder_attention_mask = create_bidirectional_mask(
  378. config=self.config,
  379. inputs_embeds=inputs_embeds,
  380. attention_mask=encoder_attention_mask,
  381. encoder_hidden_states=encoder_hidden_states,
  382. )
  383. hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length).to(
  384. inputs_embeds.device
  385. )
  386. hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)
  387. for idx, decoder_layer in enumerate(self.layers):
  388. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  389. if self.training:
  390. dropout_probability = torch.rand([])
  391. if dropout_probability < self.layerdrop:
  392. continue
  393. hidden_states = decoder_layer(
  394. hidden_states,
  395. attention_mask,
  396. encoder_hidden_states, # as a positional argument for gradient checkpointing
  397. encoder_attention_mask=encoder_attention_mask,
  398. past_key_values=past_key_values,
  399. use_cache=use_cache,
  400. **kwargs,
  401. )
  402. hidden_states = self.layer_norm(hidden_states)
  403. return BaseModelOutputWithPastAndCrossAttentions(
  404. last_hidden_state=hidden_states,
  405. past_key_values=past_key_values,
  406. )
  407. @auto_docstring(
  408. custom_intro="""
  409. The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
  410. embeddings).
  411. """
  412. )
  413. class XGLMForCausalLM(XGLMPreTrainedModel, GenerationMixin):
  414. base_model_prefix = "model"
  415. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  416. def __init__(self, config):
  417. super().__init__(config)
  418. self.model = XGLMModel(config)
  419. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  420. # Initialize weights and apply final processing
  421. self.post_init()
  422. @merge_with_config_defaults
  423. @capture_outputs
  424. @auto_docstring
  425. def forward(
  426. self,
  427. input_ids: torch.Tensor | None = None,
  428. attention_mask: torch.Tensor | None = None,
  429. position_ids: torch.Tensor | None = None,
  430. encoder_hidden_states: torch.Tensor | None = None,
  431. encoder_attention_mask: torch.Tensor | None = None,
  432. past_key_values: Cache | None = None,
  433. inputs_embeds: torch.Tensor | None = None,
  434. labels: torch.Tensor | None = None,
  435. use_cache: bool | None = None,
  436. logits_to_keep: int | torch.Tensor = 0,
  437. **kwargs: Unpack[TransformersKwargs],
  438. ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
  439. r"""
  440. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  441. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
  442. the decoder.
  443. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  444. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  445. selected in `[0, 1]`:
  446. - 1 for tokens that are **not masked**,
  447. - 0 for tokens that are **masked**.
  448. [What are attention masks?](../glossary#attention-mask)
  449. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  450. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  451. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  452. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  453. """
  454. outputs: BaseModelOutputWithPastAndCrossAttentions = self.model(
  455. input_ids=input_ids,
  456. attention_mask=attention_mask,
  457. position_ids=position_ids,
  458. encoder_hidden_states=encoder_hidden_states,
  459. encoder_attention_mask=encoder_attention_mask,
  460. past_key_values=past_key_values,
  461. inputs_embeds=inputs_embeds,
  462. use_cache=use_cache,
  463. **kwargs,
  464. )
  465. hidden_states = outputs.last_hidden_state
  466. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  467. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  468. logits = self.lm_head(hidden_states[:, slice_indices, :])
  469. loss = None
  470. if labels is not None:
  471. loss = self.loss_function(
  472. logits,
  473. labels,
  474. vocab_size=self.config.vocab_size,
  475. pad_token_id=self.config.pad_token_id,
  476. )
  477. return CausalLMOutputWithCrossAttentions(
  478. loss=loss,
  479. logits=logits,
  480. past_key_values=outputs.past_key_values,
  481. hidden_states=outputs.hidden_states,
  482. attentions=outputs.attentions,
  483. cross_attentions=outputs.cross_attentions,
  484. )
# Public names of this module; this list is what a star-import of the module exposes.
__all__ = ["XGLMForCausalLM", "XGLMModel", "XGLMPreTrainedModel"]