modeling_pegasus.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132
  1. # Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch PEGASUS model."""
  15. import copy
  16. import math
  17. from collections.abc import Callable
  18. import numpy as np
  19. import torch
  20. from torch import nn
  21. from torch.nn import CrossEntropyLoss
  22. from ... import initialization as init
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  25. from ...generation import GenerationMixin
  26. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  27. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import (
  30. BaseModelOutput,
  31. BaseModelOutputWithPastAndCrossAttentions,
  32. CausalLMOutputWithCrossAttentions,
  33. Seq2SeqLMOutput,
  34. Seq2SeqModelOutput,
  35. )
  36. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  37. from ...processing_utils import Unpack
  38. from ...utils import (
  39. TransformersKwargs,
  40. auto_docstring,
  41. can_return_tuple,
  42. is_torchdynamo_compiling,
  43. logging,
  44. )
  45. from ...utils.generic import merge_with_config_defaults
  46. from ...utils.output_capturing import OutputRecorder, capture_outputs
  47. from .configuration_pegasus import PegasusConfig
# Module-level logger; used below for the missing-`layer_idx` warning in PegasusAttention and the
# `resize_position_embeddings` info messages.
logger = logging.get_logger(__name__)
  49. # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
  50. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  51. """
  52. Shift input ids one token to the right.
  53. """
  54. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  55. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  56. shifted_input_ids[:, 0] = decoder_start_token_id
  57. if pad_token_id is None:
  58. raise ValueError("self.model.config.pad_token_id has to be defined.")
  59. # replace possible -100 values in labels by `pad_token_id`
  60. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  61. return shifted_input_ids
  62. # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
  63. class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
  64. """This module produces sinusoidal positional embeddings of any length."""
  65. def __init__(self, num_positions: int, embedding_dim: int, padding_idx: int | None = None) -> None:
  66. super().__init__(num_positions, embedding_dim, _freeze=True)
  67. def create_weight(self):
  68. """
  69. Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
  70. the 2nd half of the vector. [dim // 2:]
  71. """
  72. n_pos, dim = self.weight.shape
  73. position_enc = np.array(
  74. [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
  75. )
  76. out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False)
  77. sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
  78. out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
  79. out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
  80. return out
  81. @torch.no_grad()
  82. def forward(
  83. self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None
  84. ) -> torch.Tensor:
  85. """`input_ids_shape` is expected to be [bsz x seqlen]."""
  86. if position_ids is None:
  87. bsz, seq_len = input_ids_shape[:2]
  88. position_ids = torch.arange(
  89. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  90. )
  91. return super().forward(position_ids)
  92. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  93. def eager_attention_forward(
  94. module: nn.Module,
  95. query: torch.Tensor,
  96. key: torch.Tensor,
  97. value: torch.Tensor,
  98. attention_mask: torch.Tensor | None,
  99. scaling: float | None = None,
  100. dropout: float = 0.0,
  101. **kwargs: Unpack[TransformersKwargs],
  102. ):
  103. if scaling is None:
  104. scaling = query.size(-1) ** -0.5
  105. # Take the dot product between "query" and "key" to get the raw attention scores.
  106. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  107. if attention_mask is not None:
  108. attn_weights = attn_weights + attention_mask
  109. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  110. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  111. attn_output = torch.matmul(attn_weights, value)
  112. attn_output = attn_output.transpose(1, 2).contiguous()
  113. return attn_output, attn_weights
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Pegasus
class PegasusAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Acts as self-attention when `forward` is called without `key_value_states`. When `key_value_states` is given, it
    acts as cross-attention, with keys and values projected from `key_value_states`. The attention kernel is picked
    from `ALL_ATTENTION_FUNCTIONS` by `config._attn_implementation`, falling back to `eager_attention_forward`.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: PegasusConfig | None = None,
        layer_idx: int | None = None,
    ):
        """
        Args:
            embed_dim: model width; must be divisible by `num_heads`.
            num_heads: number of heads; each head gets `embed_dim // num_heads` channels.
            dropout: dropout on attention probabilities, only passed to the kernel while `self.training` is True.
            is_decoder: only used here to warn when `layer_idx` is missing.
            bias: whether the q/k/v/out projections carry a bias.
            is_causal: stored on the module but not read by `forward` below. Presumably the selected attention
                kernel reads it through `module`; TODO confirm.
            config: `forward` reads `config._attn_implementation`, so a config is required in practice despite
                the `None` default.
            layer_idx: index of this layer inside the cache objects; needed whenever `past_key_values` is used.

        Raises:
            ValueError: if `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 1 / sqrt(head_dim), forwarded to the attention kernel as `scaling`.
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: query-side input; its last dimension is split into heads of size `head_dim`.
            key_value_states: if given, keys and values are projected from it instead of `hidden_states`
                (cross-attention).
            past_key_values: optional cache.
                - For an `EncoderDecoderCache`, the cross- or self-attention sub-cache is used, matching the
                  attention type.
                - Cross-attention keys and values are computed once, then re-read from the cache on later calls.
            attention_mask: forwarded unchanged to the attention kernel.

        Returns:
            `(attn_output, attn_weights)`:
            - `attn_output` keeps the leading shape of `hidden_states`, after `out_proj`;
            - `attn_weights` is the kernel's second return value.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # determine input shapes
        input_shape = hidden_states.shape[:-1]
        # -1 lets `view` infer the number of heads from the projected width.
        hidden_shape = (*input_shape, -1, self.head_dim)

        # get query proj
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        is_updated = False
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                # Whether this layer's cross-attention entry was already filled on an earlier call.
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_states from cache
                    curr_past_key_values = past_key_values.cross_attention_cache
                else:
                    curr_past_key_values = past_key_values.self_attention_cache
            else:
                curr_past_key_values = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_values.layers[self.layer_idx].keys
            value_states = curr_past_key_values.layers[self.layer_idx].values
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
            key_states = key_states.view(kv_shape).transpose(1, 2)
            value_states = value_states.view(kv_shape).transpose(1, 2)

            if past_key_values is not None:
                # `update` returns the keys/values to attend over (for self-attention this includes cached steps).
                key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the head dimensions back into a single channel dimension before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
  215. # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS
  216. class PegasusEncoderLayer(GradientCheckpointingLayer):
  217. def __init__(self, config: PegasusConfig):
  218. super().__init__()
  219. self.embed_dim = config.d_model
  220. self.self_attn = PegasusAttention(
  221. embed_dim=self.embed_dim,
  222. num_heads=config.encoder_attention_heads,
  223. dropout=config.attention_dropout,
  224. config=config,
  225. )
  226. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  227. self.dropout = config.dropout
  228. self.activation_fn = ACT2FN[config.activation_function]
  229. self.activation_dropout = config.activation_dropout
  230. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  231. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  232. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  233. def forward(
  234. self,
  235. hidden_states: torch.Tensor,
  236. attention_mask: torch.Tensor,
  237. **kwargs: Unpack[TransformersKwargs],
  238. ) -> torch.Tensor:
  239. """
  240. Args:
  241. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  242. attention_mask (`torch.FloatTensor`): attention mask of size
  243. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  244. """
  245. residual = hidden_states
  246. hidden_states = self.self_attn_layer_norm(hidden_states)
  247. hidden_states, _ = self.self_attn(
  248. hidden_states=hidden_states,
  249. attention_mask=attention_mask,
  250. **kwargs,
  251. )
  252. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  253. hidden_states = residual + hidden_states
  254. residual = hidden_states
  255. hidden_states = self.final_layer_norm(hidden_states)
  256. hidden_states = self.activation_fn(self.fc1(hidden_states))
  257. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  258. hidden_states = self.fc2(hidden_states)
  259. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  260. hidden_states = residual + hidden_states
  261. if hidden_states.dtype == torch.float16:
  262. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  263. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  264. return hidden_states
  265. # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS
  266. class PegasusDecoderLayer(GradientCheckpointingLayer):
  267. def __init__(self, config: PegasusConfig, layer_idx: int | None = None):
  268. super().__init__()
  269. self.embed_dim = config.d_model
  270. self.self_attn = PegasusAttention(
  271. embed_dim=self.embed_dim,
  272. num_heads=config.decoder_attention_heads,
  273. dropout=config.attention_dropout,
  274. is_decoder=True,
  275. is_causal=True,
  276. config=config,
  277. layer_idx=layer_idx,
  278. )
  279. self.dropout = config.dropout
  280. self.activation_fn = ACT2FN[config.activation_function]
  281. self.activation_dropout = config.activation_dropout
  282. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  283. self.encoder_attn = PegasusAttention(
  284. self.embed_dim,
  285. config.decoder_attention_heads,
  286. dropout=config.attention_dropout,
  287. is_decoder=True,
  288. config=config,
  289. layer_idx=layer_idx,
  290. )
  291. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  292. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  293. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  294. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  295. def forward(
  296. self,
  297. hidden_states: torch.Tensor,
  298. attention_mask: torch.Tensor | None = None,
  299. encoder_hidden_states: torch.Tensor | None = None,
  300. encoder_attention_mask: torch.Tensor | None = None,
  301. past_key_values: Cache | None = None,
  302. use_cache: bool | None = True,
  303. **kwargs: Unpack[TransformersKwargs],
  304. ) -> torch.Tensor:
  305. """
  306. Args:
  307. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  308. attention_mask (`torch.FloatTensor`): attention mask of size
  309. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  310. encoder_hidden_states (`torch.FloatTensor`):
  311. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  312. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  313. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  314. past_key_values (`Cache`): cached past key and value projection states
  315. """
  316. residual = hidden_states
  317. hidden_states = self.self_attn_layer_norm(hidden_states)
  318. # Self Attention
  319. hidden_states, _ = self.self_attn(
  320. hidden_states=hidden_states,
  321. past_key_values=past_key_values,
  322. attention_mask=attention_mask,
  323. **kwargs,
  324. )
  325. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  326. hidden_states = residual + hidden_states
  327. # Cross-Attention Block
  328. if encoder_hidden_states is not None:
  329. residual = hidden_states
  330. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  331. hidden_states, _ = self.encoder_attn(
  332. hidden_states=hidden_states,
  333. key_value_states=encoder_hidden_states,
  334. attention_mask=encoder_attention_mask,
  335. past_key_values=past_key_values,
  336. **kwargs,
  337. )
  338. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  339. hidden_states = residual + hidden_states
  340. # Fully Connected
  341. residual = hidden_states
  342. hidden_states = self.final_layer_norm(hidden_states)
  343. hidden_states = self.activation_fn(self.fc1(hidden_states))
  344. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  345. hidden_states = self.fc2(hidden_states)
  346. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  347. hidden_states = residual + hidden_states
  348. return hidden_states
  349. @auto_docstring
  350. class PegasusPreTrainedModel(PreTrainedModel):
  351. config: PegasusConfig
  352. base_model_prefix = "model"
  353. supports_gradient_checkpointing = True
  354. _supports_flash_attn = True
  355. _supports_sdpa = True
  356. _supports_flex_attn = True
  357. _can_compile_fullgraph = True
  358. @torch.no_grad()
  359. def _init_weights(self, module):
  360. super()._init_weights(module)
  361. if isinstance(module, PegasusSinusoidalPositionalEmbedding):
  362. init.copy_(module.weight, module.create_weight())
  363. elif isinstance(module, PegasusForConditionalGeneration):
  364. init.zeros_(module.final_logits_bias)
  365. class PegasusEncoder(PegasusPreTrainedModel):
  366. """
  367. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  368. [`PegasusEncoderLayer`].
  369. Args:
  370. config: PegasusConfig
  371. embed_tokens (nn.Embedding): output embedding
  372. """
  373. _can_record_outputs = {
  374. "hidden_states": PegasusEncoderLayer,
  375. "attentions": PegasusAttention,
  376. }
  377. def __init__(self, config: PegasusConfig):
  378. super().__init__(config)
  379. self.dropout = config.dropout
  380. self.layerdrop = config.encoder_layerdrop
  381. embed_dim = config.d_model
  382. self.padding_idx = config.pad_token_id
  383. self.max_source_positions = config.max_position_embeddings
  384. self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  385. self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
  386. self.embed_positions = PegasusSinusoidalPositionalEmbedding(
  387. config.max_position_embeddings,
  388. embed_dim,
  389. self.padding_idx,
  390. )
  391. self.layers = nn.ModuleList([PegasusEncoderLayer(config) for _ in range(config.encoder_layers)])
  392. self.layer_norm = nn.LayerNorm(config.d_model)
  393. self.gradient_checkpointing = False
  394. # Initialize weights and apply final processing
  395. self.post_init()
  396. def resize_position_embeddings(self, new_num_position_embeddings: int):
  397. """
  398. Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
  399. config.max_position_embeddings`.
  400. Arguments:
  401. new_num_position_embeddings (`int`):
  402. The number of new position embeddings. If position embeddings are learned, increasing the size will add
  403. newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
  404. position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
  405. add correct vectors at the end following the position encoding algorithm, whereas reducing the size
  406. will remove vectors from the end.
  407. """
  408. logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
  409. self.config.max_position_embeddings = new_num_position_embeddings
  410. self.embed_positions = PegasusSinusoidalPositionalEmbedding(
  411. self.config.max_position_embeddings,
  412. self.config.d_model,
  413. self.padding_idx,
  414. )
  415. init.copy_(self.embed_positions.weight, self.embed_positions.create_weight())
  416. self.embed_positions.to(self.device)
  417. def get_position_embeddings(self) -> nn.Embedding:
  418. """
  419. Returns the position embeddings matrix
  420. """
  421. return self.embed_positions
  422. @merge_with_config_defaults
  423. @capture_outputs
  424. @auto_docstring
  425. def forward(
  426. self,
  427. input_ids=None,
  428. attention_mask=None,
  429. inputs_embeds=None,
  430. **kwargs: Unpack[TransformersKwargs],
  431. ) -> BaseModelOutput:
  432. if (input_ids is None) ^ (inputs_embeds is not None):
  433. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  434. if inputs_embeds is None:
  435. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  436. input_shape = inputs_embeds.shape[:-1]
  437. embed_pos = self.embed_positions(input_shape)
  438. hidden_states = inputs_embeds + embed_pos
  439. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  440. attention_mask = create_bidirectional_mask(
  441. config=self.config,
  442. inputs_embeds=inputs_embeds,
  443. attention_mask=attention_mask,
  444. )
  445. for idx, encoder_layer in enumerate(self.layers):
  446. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  447. to_drop = False
  448. if self.training:
  449. dropout_probability = torch.rand([])
  450. if dropout_probability < self.layerdrop: # skip the layer
  451. to_drop = True
  452. if not to_drop:
  453. hidden_states = encoder_layer(
  454. hidden_states,
  455. attention_mask,
  456. **kwargs,
  457. )
  458. hidden_states = self.layer_norm(hidden_states)
  459. return BaseModelOutput(
  460. last_hidden_state=hidden_states,
  461. )
  462. class PegasusDecoder(PegasusPreTrainedModel):
  463. """
  464. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PegasusDecoderLayer`]
  465. Args:
  466. config: PegasusConfig
  467. embed_tokens (nn.Embedding): output embedding
  468. """
  469. _can_record_outputs = {
  470. "hidden_states": PegasusDecoderLayer,
  471. "attentions": OutputRecorder(PegasusAttention, index=1, layer_name="self_attn"),
  472. "cross_attentions": OutputRecorder(PegasusAttention, index=1, layer_name="encoder_attn"),
  473. }
  474. def __init__(self, config: PegasusConfig):
  475. super().__init__(config)
  476. self.dropout = config.dropout
  477. self.layerdrop = config.decoder_layerdrop
  478. self.padding_idx = config.pad_token_id
  479. self.max_target_positions = config.max_position_embeddings
  480. self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  481. self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
  482. self.embed_positions = PegasusSinusoidalPositionalEmbedding(
  483. config.max_position_embeddings,
  484. config.d_model,
  485. self.padding_idx,
  486. )
  487. self.layers = nn.ModuleList([PegasusDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  488. self.layer_norm = nn.LayerNorm(config.d_model)
  489. self.gradient_checkpointing = False
  490. # Initialize weights and apply final processing
  491. self.post_init()
  492. def resize_position_embeddings(self, new_num_position_embeddings: int):
  493. """
  494. Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
  495. config.max_position_embeddings`.
  496. Arguments:
  497. new_num_position_embeddings (`int`):
  498. The number of new position embeddings. If position embeddings are learned, increasing the size will add
  499. newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
  500. position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
  501. add correct vectors at the end following the position encoding algorithm, whereas reducing the size
  502. will remove vectors from the end.
  503. """
  504. logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
  505. self.config.max_position_embeddings = new_num_position_embeddings
  506. self.embed_positions = PegasusSinusoidalPositionalEmbedding(
  507. self.config.max_position_embeddings,
  508. self.config.d_model,
  509. self.padding_idx,
  510. )
  511. init.copy_(self.embed_positions.weight, self.embed_positions.create_weight())
  512. self.embed_positions.to(self.device)
  513. def get_position_embeddings(self) -> nn.Embedding:
  514. """
  515. Returns the position embeddings matrix
  516. """
  517. return self.embed_positions
  518. @merge_with_config_defaults
  519. @capture_outputs
  520. @auto_docstring
  521. def forward(
  522. self,
  523. input_ids=None,
  524. attention_mask=None,
  525. encoder_hidden_states=None,
  526. encoder_attention_mask=None,
  527. past_key_values=None,
  528. inputs_embeds=None,
  529. use_cache=None,
  530. **kwargs: Unpack[TransformersKwargs],
  531. ) -> BaseModelOutputWithPastAndCrossAttentions:
  532. if (input_ids is None) ^ (inputs_embeds is not None):
  533. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  534. if inputs_embeds is None:
  535. inputs_embeds = self.embed_tokens(input_ids)
  536. # important to apply scale outside of `if` in case users pass `embeds`
  537. inputs_embeds = inputs_embeds * self.embed_scale
  538. # initialize `past_key_values`
  539. if use_cache and past_key_values is None:
  540. past_key_values = (
  541. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  542. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  543. else DynamicCache(config=self.config)
  544. )
  545. batch_size, seq_length = inputs_embeds.size()[:-1]
  546. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  547. position_ids = torch.arange(seq_length, device=inputs_embeds.device) + past_key_values_length
  548. if attention_mask is None and not is_torchdynamo_compiling():
  549. # required mask seq length can be calculated via length of past cache
  550. mask_seq_length = past_key_values_length + seq_length
  551. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  552. self_attn_cache = (
  553. past_key_values.self_attention_cache
  554. if isinstance(past_key_values, EncoderDecoderCache)
  555. else past_key_values
  556. )
  557. causal_mask = create_causal_mask(
  558. config=self.config,
  559. inputs_embeds=inputs_embeds,
  560. attention_mask=attention_mask,
  561. past_key_values=self_attn_cache,
  562. )
  563. encoder_attention_mask = create_bidirectional_mask(
  564. config=self.config,
  565. inputs_embeds=inputs_embeds,
  566. attention_mask=encoder_attention_mask,
  567. encoder_hidden_states=encoder_hidden_states,
  568. )
  569. # embed positions
  570. positions = self.embed_positions((batch_size, seq_length), past_key_values_length, position_ids=position_ids)
  571. hidden_states = inputs_embeds + positions
  572. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  573. for idx, decoder_layer in enumerate(self.layers):
  574. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  575. if self.training:
  576. dropout_probability = torch.rand([])
  577. if dropout_probability < self.layerdrop:
  578. continue
  579. hidden_states = decoder_layer(
  580. hidden_states,
  581. causal_mask,
  582. encoder_hidden_states, # as a positional argument for gradient checkpointing
  583. encoder_attention_mask=encoder_attention_mask,
  584. past_key_values=past_key_values,
  585. use_cache=use_cache,
  586. **kwargs,
  587. )
  588. hidden_states = self.layer_norm(hidden_states)
  589. return BaseModelOutputWithPastAndCrossAttentions(
  590. last_hidden_state=hidden_states,
  591. past_key_values=past_key_values,
  592. )
  593. @auto_docstring
  594. class PegasusModel(PegasusPreTrainedModel):
  595. _tied_weights_keys = {
  596. "decoder.embed_tokens.weight": "shared.weight",
  597. "encoder.embed_tokens.weight": "shared.weight",
  598. }
  599. def __init__(self, config: PegasusConfig):
  600. super().__init__(config)
  601. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  602. self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
  603. self.encoder = PegasusEncoder(config)
  604. self.decoder = PegasusDecoder(config)
  605. # Initialize weights and apply final processing
  606. self.post_init()
    def get_input_embeddings(self):
        """Return the shared token embedding module (`self.shared`)."""
        return self.shared
  609. def set_input_embeddings(self, value):
  610. self.shared = value
  611. self.encoder.embed_tokens = self.shared
  612. self.decoder.embed_tokens = self.shared
  613. def resize_position_embeddings(self, new_num_position_embeddings: int):
  614. """
  615. Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
  616. config.max_position_embeddings`.
  617. Arguments:
  618. new_num_position_embeddings (`int`):
  619. The number of new position embeddings. If position embeddings are learned, increasing the size will add
  620. newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
  621. position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
  622. add correct vectors at the end following the position encoding algorithm, whereas reducing the size
  623. will remove vectors from the end.
  624. """
  625. self.config.max_position_embeddings = new_num_position_embeddings
  626. self.encoder.resize_position_embeddings(new_num_position_embeddings)
  627. self.decoder.resize_position_embeddings(new_num_position_embeddings)
  628. def get_position_embeddings(self) -> tuple[nn.Embedding]:
  629. """
  630. Returns the position embeddings matrix
  631. """
  632. return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings())
  633. @can_return_tuple
  634. @auto_docstring
  635. def forward(
  636. self,
  637. input_ids: torch.Tensor | None = None,
  638. attention_mask: torch.Tensor | None = None,
  639. decoder_input_ids: torch.Tensor | None = None,
  640. decoder_attention_mask: torch.Tensor | None = None,
  641. encoder_outputs: tuple[torch.FloatTensor] | None = None,
  642. past_key_values: Cache | None = None,
  643. inputs_embeds: torch.Tensor | None = None,
  644. decoder_inputs_embeds: torch.Tensor | None = None,
  645. use_cache: bool | None = None,
  646. **kwargs: Unpack[TransformersKwargs],
  647. ) -> tuple | Seq2SeqModelOutput:
  648. r"""
  649. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  650. Indices of decoder input sequence tokens in the vocabulary.
  651. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  652. [`PreTrainedTokenizer.__call__`] for details.
  653. [What are decoder input IDs?](../glossary#decoder-input-ids)
  654. Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  655. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  656. `past_key_values`).
  657. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  658. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  659. be used by default.
  660. Example:
  661. ```python
  662. >>> from transformers import AutoTokenizer, PegasusModel
  663. >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
  664. >>> model = PegasusModel.from_pretrained("google/pegasus-large")
  665. >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
  666. >>> decoder_inputs = tokenizer("Studies show that", return_tensors="pt")
  667. >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids)
  668. >>> last_hidden_states = outputs.last_hidden_state
  669. >>> list(last_hidden_states.shape)
  670. [1, 4, 1024]
  671. ```"""
  672. if encoder_outputs is None:
  673. encoder_outputs = self.encoder(
  674. input_ids=input_ids,
  675. attention_mask=attention_mask,
  676. inputs_embeds=inputs_embeds,
  677. **kwargs,
  678. )
  679. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
  680. elif not isinstance(encoder_outputs, BaseModelOutput):
  681. encoder_outputs = BaseModelOutput(
  682. last_hidden_state=encoder_outputs[0],
  683. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  684. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  685. )
  686. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  687. decoder_outputs = self.decoder(
  688. input_ids=decoder_input_ids,
  689. attention_mask=decoder_attention_mask,
  690. encoder_hidden_states=encoder_outputs[0],
  691. encoder_attention_mask=attention_mask,
  692. past_key_values=past_key_values,
  693. inputs_embeds=decoder_inputs_embeds,
  694. use_cache=use_cache,
  695. **kwargs,
  696. )
  697. return Seq2SeqModelOutput(
  698. last_hidden_state=decoder_outputs.last_hidden_state,
  699. past_key_values=decoder_outputs.past_key_values,
  700. decoder_hidden_states=decoder_outputs.hidden_states,
  701. decoder_attentions=decoder_outputs.attentions,
  702. cross_attentions=decoder_outputs.cross_attentions,
  703. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  704. encoder_hidden_states=encoder_outputs.hidden_states,
  705. encoder_attentions=encoder_outputs.attentions,
  706. )
  707. @auto_docstring(
  708. custom_intro="""
  709. The PEGASUS Model with a language modeling head. Can be used for summarization.
  710. """
  711. )
  712. class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin):
  713. base_model_prefix = "model"
  714. _keys_to_ignore_on_load_missing = ["final_logits_bias"]
  715. _tied_weights_keys = {
  716. "lm_head.weight": "model.shared.weight",
  717. }
  718. def __init__(self, config: PegasusConfig):
  719. super().__init__(config)
  720. self.model = PegasusModel(config)
  721. self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
  722. self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
  723. # Initialize weights and apply final processing
  724. self.post_init()
  725. def resize_token_embeddings(
  726. self, new_num_tokens: int, pad_to_multiple_of: int | None = None, mean_resizing: bool = True
  727. ) -> nn.Embedding:
  728. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  729. self._resize_final_logits_bias(new_embeddings.weight.shape[0])
  730. return new_embeddings
  731. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  732. old_num_tokens = self.final_logits_bias.shape[-1]
  733. if new_num_tokens <= old_num_tokens:
  734. new_bias = self.final_logits_bias[:, :new_num_tokens]
  735. else:
  736. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  737. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  738. self.register_buffer("final_logits_bias", new_bias)
  739. def resize_position_embeddings(self, new_num_position_embeddings: int):
  740. """
  741. Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
  742. config.max_position_embeddings`.
  743. Arguments:
  744. new_num_position_embeddings (`int`):
  745. The number of new position embeddings. If position embeddings are learned, increasing the size will add
  746. newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
  747. position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
  748. add correct vectors at the end following the position encoding algorithm, whereas reducing the size
  749. will remove vectors from the end.
  750. """
  751. self.config.max_position_embeddings = new_num_position_embeddings
  752. self.model.encoder.resize_position_embeddings(new_num_position_embeddings)
  753. self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
  754. def get_position_embeddings(self) -> tuple[nn.Embedding]:
  755. """
  756. Returns the position embeddings matrix
  757. """
  758. return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings())
  759. @can_return_tuple
  760. @auto_docstring
  761. def forward(
  762. self,
  763. input_ids: torch.Tensor | None = None,
  764. attention_mask: torch.Tensor | None = None,
  765. decoder_input_ids: torch.Tensor | None = None,
  766. decoder_attention_mask: torch.Tensor | None = None,
  767. encoder_outputs: tuple[torch.FloatTensor] | None = None,
  768. past_key_values: Cache | None = None,
  769. inputs_embeds: torch.Tensor | None = None,
  770. decoder_inputs_embeds: torch.Tensor | None = None,
  771. labels: torch.Tensor | None = None,
  772. use_cache: bool | None = None,
  773. **kwargs: Unpack[TransformersKwargs],
  774. ) -> tuple | Seq2SeqLMOutput:
  775. r"""
  776. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  777. Indices of decoder input sequence tokens in the vocabulary.
  778. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  779. [`PreTrainedTokenizer.__call__`] for details.
  780. [What are decoder input IDs?](../glossary#decoder-input-ids)
  781. Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  782. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  783. `past_key_values`).
  784. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  785. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  786. be used by default.
  787. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  788. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  789. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  790. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  791. Example Summarization:
  792. ```python
  793. >>> from transformers import AutoTokenizer, PegasusForConditionalGeneration
  794. >>> model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
  795. >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
  796. >>> ARTICLE_TO_SUMMARIZE = (
  797. ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
  798. ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
  799. ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
  800. ... )
  801. >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="pt")
  802. >>> # Generate Summary
  803. >>> summary_ids = model.generate(inputs["input_ids"])
  804. >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  805. "California's largest electricity provider has turned off power to hundreds of thousands of customers."
  806. ```
  807. """
  808. if labels is not None:
  809. if use_cache:
  810. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  811. use_cache = False
  812. if decoder_input_ids is None and decoder_inputs_embeds is None:
  813. decoder_input_ids = shift_tokens_right(
  814. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  815. )
  816. outputs: Seq2SeqModelOutput = self.model(
  817. input_ids,
  818. attention_mask=attention_mask,
  819. decoder_input_ids=decoder_input_ids,
  820. encoder_outputs=encoder_outputs,
  821. decoder_attention_mask=decoder_attention_mask,
  822. past_key_values=past_key_values,
  823. inputs_embeds=inputs_embeds,
  824. decoder_inputs_embeds=decoder_inputs_embeds,
  825. use_cache=use_cache,
  826. **kwargs,
  827. )
  828. lm_logits = self.lm_head(outputs.last_hidden_state) + self.final_logits_bias
  829. masked_lm_loss = None
  830. if labels is not None:
  831. loss_fct = CrossEntropyLoss()
  832. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  833. return Seq2SeqLMOutput(
  834. loss=masked_lm_loss,
  835. logits=lm_logits,
  836. past_key_values=outputs.past_key_values,
  837. decoder_hidden_states=outputs.decoder_hidden_states,
  838. decoder_attentions=outputs.decoder_attentions,
  839. cross_attentions=outputs.cross_attentions,
  840. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  841. encoder_hidden_states=outputs.encoder_hidden_states,
  842. encoder_attentions=outputs.encoder_attentions,
  843. )
  844. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  845. return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
# Thin module around `PegasusDecoder`: it holds the decoder under a `decoder` attribute and
# `forward` passes every positional/keyword argument straight through to it.
# `PegasusForCausalLM` stores an instance of this wrapper as `self.model`.
# NOTE(review): comments live above the `# Copied from` marker on purpose; the class body is
# expected to stay identical to Bart's `BartDecoderWrapper` (repo copy-consistency check).
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus
class PegasusDecoderWrapper(PegasusPreTrainedModel):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    def __init__(self, config):
        super().__init__(config)
        self.decoder = PegasusDecoder(config)
        self.post_init()

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)
  858. class PegasusForCausalLM(PegasusPreTrainedModel, GenerationMixin):
  859. _tied_weights_keys = {
  860. "lm_head.weight": "model.decoder.embed_tokens.weight",
  861. }
  862. def __init__(self, config):
  863. config = copy.deepcopy(config)
  864. config.is_decoder = True
  865. config.is_encoder_decoder = False
  866. super().__init__(config)
  867. self.model = PegasusDecoderWrapper(config)
  868. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  869. # Initialize weights and apply final processing
  870. self.post_init()
  871. def get_input_embeddings(self):
  872. return self.model.decoder.embed_tokens
  873. def set_input_embeddings(self, value):
  874. self.model.decoder.embed_tokens = value
  875. def get_position_embeddings(self) -> nn.Embedding:
  876. """
  877. Returns the position embeddings matrix
  878. """
  879. return self.model.decoder.get_position_embeddings()
  880. def resize_position_embeddings(self, new_num_position_embeddings: int):
  881. """
  882. Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
  883. config.max_position_embeddings`.
  884. Arguments:
  885. new_num_position_embeddings (`int`):
  886. The number of new position embeddings. If position embeddings are learned, increasing the size will add
  887. newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
  888. position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
  889. add correct vectors at the end following the position encoding algorithm, whereas reducing the size
  890. will remove vectors from the end.
  891. """
  892. self.config.max_position_embeddings = new_num_position_embeddings
  893. self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
  894. @can_return_tuple
  895. @auto_docstring
  896. # Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large
  897. def forward(
  898. self,
  899. input_ids: torch.LongTensor | None = None,
  900. attention_mask: torch.Tensor | None = None,
  901. encoder_hidden_states: torch.FloatTensor | None = None,
  902. encoder_attention_mask: torch.FloatTensor | None = None,
  903. past_key_values: Cache | None = None,
  904. inputs_embeds: torch.FloatTensor | None = None,
  905. labels: torch.LongTensor | None = None,
  906. use_cache: bool | None = None,
  907. logits_to_keep: int | torch.Tensor = 0,
  908. **kwargs: Unpack[TransformersKwargs],
  909. ) -> tuple | CausalLMOutputWithCrossAttentions:
  910. r"""
  911. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  912. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  913. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  914. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  915. Example:
  916. ```python
  917. >>> from transformers import AutoTokenizer, PegasusForCausalLM
  918. >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
  919. >>> model = PegasusForCausalLM.from_pretrained("google/pegasus-large")
  920. >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
  921. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  922. >>> outputs = model(**inputs)
  923. >>> logits = outputs.logits
  924. >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
  925. >>> list(logits.shape) == expected_shape
  926. True
  927. ```"""
  928. outputs: BaseModelOutputWithPastAndCrossAttentions = self.model.decoder(
  929. input_ids=input_ids,
  930. attention_mask=attention_mask,
  931. encoder_hidden_states=encoder_hidden_states,
  932. encoder_attention_mask=encoder_attention_mask,
  933. past_key_values=past_key_values,
  934. inputs_embeds=inputs_embeds,
  935. use_cache=use_cache,
  936. **kwargs,
  937. )
  938. hidden_states = outputs[0]
  939. # Only compute necessary logits
  940. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  941. logits = self.lm_head(hidden_states[:, slice_indices, :])
  942. loss = None
  943. if labels is not None:
  944. labels = labels.to(logits.device)
  945. loss_fct = CrossEntropyLoss()
  946. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  947. return CausalLMOutputWithCrossAttentions(
  948. loss=loss,
  949. logits=logits,
  950. past_key_values=outputs.past_key_values,
  951. hidden_states=outputs.hidden_states,
  952. attentions=outputs.attentions,
  953. cross_attentions=outputs.cross_attentions,
  954. )
  955. __all__ = ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]