| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321 |
- # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch BART model."""
- import math
- import warnings
- from collections.abc import Callable
- import torch
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
- from ...generation import GenerationMixin
- from ...masking_utils import create_bidirectional_mask, create_causal_mask
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPastAndCrossAttentions,
- CausalLMOutputWithCrossAttentions,
- Seq2SeqLMOutput,
- Seq2SeqModelOutput,
- Seq2SeqQuestionAnsweringModelOutput,
- Seq2SeqSequenceClassifierOutput,
- )
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import (
- TransformersKwargs,
- auto_docstring,
- can_return_tuple,
- is_torchdynamo_compiling,
- logging,
- torch_compilable_check,
- )
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import OutputRecorder, capture_outputs
- from .configuration_bart import BartConfig
- logger = logging.get_logger(__name__)
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Build decoder inputs by shifting ``input_ids`` one position to the right.

    Column 0 of the result holds ``decoder_start_token_id``; columns 1.. hold ``input_ids[:, :-1]``
    (the last token of every row is dropped). Any ``-100`` value left after the shift is replaced
    with ``pad_token_id``. The input tensor is not modified.

    Raises:
        ValueError: if ``pad_token_id`` is ``None``.
    """
    decoder_ids = input_ids.new_zeros(input_ids.shape)
    decoder_ids[:, 1:] = input_ids[:, :-1].clone()
    decoder_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    # -100 marks ignored label positions; it is not a real token id, so map it to padding.
    ignored = decoder_ids == -100
    decoder_ids.masked_fill_(ignored, pad_token_id)
    return decoder_ids
- class BartLearnedPositionalEmbedding(nn.Embedding):
- """
- This module learns positional embeddings up to a fixed maximum size.
- """
- def __init__(self, num_embeddings: int, embedding_dim: int):
- # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
- # and adjust num_embeddings appropriately. Other models don't have this hack
- self.offset = 2
- super().__init__(num_embeddings + self.offset, embedding_dim)
- def forward(
- self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None
- ):
- """`input_ids' shape is expected to be [bsz x seqlen]."""
- if position_ids is None:
- bsz, seq_len = input_ids.shape[:2]
- position_ids = torch.arange(
- past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
- ).expand(bsz, -1)
- else:
- position_ids = position_ids.unsqueeze(0)
- return super().forward(position_ids + self.offset)
- class BartScaledWordEmbedding(nn.Embedding):
- """
- This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
- """
- def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float | None = 1.0):
- super().__init__(num_embeddings, embedding_dim, padding_idx)
- self.embed_scale = embed_scale
- def forward(self, input_ids: torch.Tensor):
- return super().forward(input_ids) * self.embed_scale
# Plain-PyTorch attention kernel; mirrors transformers.models.bert.modeling_bert.eager_attention_forward.
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Scaled dot-product attention implemented with explicit matmuls.

    ``key`` is transposed on dims 2 and 3 and the output is transposed on dims 1 and 2, so the inputs
    are presumably (batch, heads, seq, head_dim) — confirm against callers. ``attention_mask`` is
    added to the raw scores before the softmax. Dropout is applied to the probabilities only while
    ``module.training`` is true.

    Returns:
        (attention output transposed on dims 1/2 and made contiguous, attention probabilities)
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling

    scores = query @ key.transpose(2, 3) * scale
    if attention_mask is not None:
        scores = scores + attention_mask

    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = (probs @ value).transpose(1, 2).contiguous()
    return context, probs
class BartAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    Acts as self-attention when `forward` gets only `hidden_states`. Acts as cross-attention when
    `key_value_states` is also given; keys/values are then projected from `key_value_states`.
    The attention kernel is looked up from `self.config._attn_implementation` on every forward call,
    with `eager_attention_forward` passed as the default. `config` must therefore not be `None` when
    `forward` runs.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: BartConfig | None = None,
        layer_idx: int | None = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 1/sqrt(head_dim); handed to the attention kernel as `scaling`.
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        # NOTE(review): only stored here; presumably read by the attention kernels through `module` — confirm.
        self.is_causal = is_causal
        # Index of this layer's slot in the KV cache (`cache.layers[layer_idx]`, `update(..., layer_idx)`).
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: query-side input; its last dimension is projected and split into heads.
            key_value_states: if given, this call is cross-attention and keys/values come from here.
            past_key_values: optional cache. With an `EncoderDecoderCache`, self- and cross-attention use
                its two sub-caches; any other cache object is used directly for both.
            attention_mask: passed unchanged to the attention kernel.
            **kwargs: forwarded to the attention kernel.

        Returns:
            (output projected back to the input's leading shape, attention weights as returned by the kernel)
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # determine input shapes
        input_shape = hidden_states.shape[:-1]
        # Split the last (model) dimension into (-1 heads, head_dim).
        hidden_shape = (*input_shape, -1, self.head_dim)

        # get query proj; transpose(1, 2) moves the head axis in front of the sequence axis
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Pick the sub-cache this call reads/writes, and whether cross-attention K/V for this layer
        # are already stored (`is_updated` stays falsy when there is no entry for this layer).
        is_updated = False
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_states from cache
                    curr_past_key_values = past_key_values.cross_attention_cache
                else:
                    curr_past_key_values = past_key_values.self_attention_cache
            else:
                # Non encoder-decoder cache: one cache for everything; `is_updated` stays False, so
                # cross-attention K/V are recomputed and written through `update` on every call.
                curr_past_key_values = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_values.layers[self.layer_idx].keys
            value_states = curr_past_key_values.layers[self.layer_idx].values
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
            key_states = key_states.view(kv_shape).transpose(1, 2)
            value_states = value_states.view(kv_shape).transpose(1, 2)

            if past_key_values is not None:
                # `update` returns the K/V to attend over; for self-attention this presumably includes
                # earlier cached steps — see the Cache implementation.
                key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # Requires a non-None config: the kernel name is read from it on every call.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back into the model dimension before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
class BartEncoderLayer(GradientCheckpointingLayer):
    """One post-LayerNorm BART encoder layer: self-attention followed by a two-layer feed-forward block."""

    def __init__(self, config: BartConfig, layer_idx: int | None = None):
        super().__init__()
        width = config.d_model
        self.embed_dim = width

        # Submodules are registered in a fixed order; keep it stable (weight init and state_dict follow it).
        self.self_attn = BartAttention(
            embed_dim=width,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
            layer_idx=layer_idx,
        )
        self.self_attn_layer_norm = nn.LayerNorm(width)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(width, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, width)
        self.final_layer_norm = nn.LayerNorm(width)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Apply self-attention and the feed-forward block. Each sub-block is followed by dropout,
        a residual add and LayerNorm.

        Args:
            hidden_states: layer input.
            attention_mask: passed unchanged to `self_attn`.
            **kwargs: forwarded to `self_attn`.

        Returns:
            The layer output. If it is fp16 and contains non-finite values, it is clamped to
            +/-(fp16 max - 1000); NaNs are not removed by the clamp.
        """
        drop = nn.functional.dropout

        # Self-attention sub-block.
        attn_out, _ = self.self_attn(hidden_states, attention_mask=attention_mask, **kwargs)
        attn_out = drop(attn_out, p=self.dropout, training=self.training)
        hidden_states = self.self_attn_layer_norm(hidden_states + attn_out)

        # Feed-forward sub-block.
        ffn_out = self.activation_fn(self.fc1(hidden_states))
        ffn_out = drop(ffn_out, p=self.activation_dropout, training=self.training)
        ffn_out = drop(self.fc2(ffn_out), p=self.dropout, training=self.training)
        hidden_states = self.final_layer_norm(hidden_states + ffn_out)

        # fp16 overflow guard: pull infinities back into the representable range.
        overflowed = hidden_states.dtype == torch.float16 and not torch.isfinite(hidden_states).all()
        if overflowed:
            limit = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = hidden_states.clamp(min=-limit, max=limit)

        return hidden_states
class BartDecoderLayer(GradientCheckpointingLayer):
    """One post-LayerNorm BART decoder layer: causal self-attention, optional cross-attention, feed-forward."""

    def __init__(self, config: BartConfig, layer_idx: int | None = None):
        super().__init__()
        width = config.d_model
        self.embed_dim = width

        # Submodules are registered in a fixed order; keep it stable (weight init and state_dict follow it).
        self.self_attn = BartAttention(
            embed_dim=width,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
            layer_idx=layer_idx,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(width)
        self.encoder_attn = BartAttention(
            width,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
            layer_idx=layer_idx,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(width)
        self.fc1 = nn.Linear(width, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, width)
        self.final_layer_norm = nn.LayerNorm(width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = True,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Run self-attention, then cross-attention (only when `encoder_hidden_states` is given), then the
        feed-forward block. Each sub-block is followed by dropout, a residual add and LayerNorm.

        `use_cache` is accepted but not read here; naming it keeps it out of `**kwargs`, which are forwarded
        to both attention modules. `past_key_values` is passed to both attention modules.
        """
        drop = nn.functional.dropout

        # Self-attention sub-block.
        self_out, _ = self.self_attn(
            hidden_states,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            **kwargs,
        )
        self_out = drop(self_out, p=self.dropout, training=self.training)
        hidden_states = self.self_attn_layer_norm(hidden_states + self_out)

        # Cross-attention sub-block over the encoder output.
        if encoder_hidden_states is not None:
            cross_out, _ = self.encoder_attn(
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                past_key_values=past_key_values,
                **kwargs,
            )
            cross_out = drop(cross_out, p=self.dropout, training=self.training)
            hidden_states = self.encoder_attn_layer_norm(hidden_states + cross_out)

        # Feed-forward sub-block.
        ffn_out = self.activation_fn(self.fc1(hidden_states))
        ffn_out = drop(ffn_out, p=self.activation_dropout, training=self.training)
        ffn_out = drop(self.fc2(ffn_out), p=self.dropout, training=self.training)
        return self.final_layer_norm(hidden_states + ffn_out)
class BartClassificationHead(nn.Module):
    """
    Head for sentence-level classification tasks.

    Pipeline: dropout -> dense (input_dim -> inner_dim) -> tanh -> dropout -> projection (inner_dim -> num_classes).
    The same dropout module (probability `pooler_dropout`) is used at both dropout points.
    """

    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        num_classes: int,
        pooler_dropout: float,
    ):
        super().__init__()
        # Registration order (dense, dropout, out_proj) is kept stable for weight init and state_dict.
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        pooled = self.dense(self.dropout(hidden_states))
        pooled = self.dropout(torch.tanh(pooled))
        return self.out_proj(pooled)
@auto_docstring
class BartPreTrainedModel(PreTrainedModel):
    config: BartConfig
    base_model_prefix = "model"

    # Loading / device-placement hints.
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
    _no_split_modules = ["BartEncoderLayer", "BartDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    # Supported attention backends / compilation.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True

    def _init_weights(self, module):
        # Run the generic initialisation first, then zero the LM-head logits bias, which is a buffer
        # registered by BartForConditionalGeneration rather than a parameter.
        super()._init_weights(module)
        if not isinstance(module, BartForConditionalGeneration):
            return
        init.zeros_(module.final_logits_bias)

    @property
    def dummy_inputs(self):
        # Two fixed sequences; the second ends with the pad token, which the attention mask excludes.
        pad = self.config.pad_token_id
        ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad]], device=self.device)
        return {"attention_mask": ids.ne(pad), "input_ids": ids}
class PretrainedBartModel(BartPreTrainedModel):
    """Deprecated alias of `BartPreTrainedModel`; subclassing it emits a `FutureWarning`."""

    def __init_subclass__(cls, **kwargs):
        # Chain to the parent hook so subclasses of the alias get the same subclass initialisation
        # they would get by subclassing `BartPreTrainedModel` directly.
        super().__init_subclass__(**kwargs)
        warnings.warn(
            "The class `PretrainedBartModel` has been deprecated, please use `BartPreTrainedModel` instead.",
            FutureWarning,
            # Attribute the warning to the class statement that subclasses the alias.
            stacklevel=2,
        )
class BartPretrainedModel(BartPreTrainedModel):
    """Deprecated alias of `BartPreTrainedModel`; subclassing it emits a `FutureWarning`."""

    def __init_subclass__(cls, **kwargs):
        # Chain to the parent hook so subclasses of the alias get the same subclass initialisation
        # they would get by subclassing `BartPreTrainedModel` directly.
        super().__init_subclass__(**kwargs)
        warnings.warn(
            # Name this alias (not `PretrainedBartModel`) so users know which import to change.
            "The class `BartPretrainedModel` has been deprecated, please use `BartPreTrainedModel` instead.",
            FutureWarning,
            # Attribute the warning to the class statement that subclasses the alias.
            stacklevel=2,
        )
class BartEncoder(BartPreTrainedModel):
    """
    Transformer encoder: token + learned position embeddings, LayerNorm and dropout, then a stack of
    *config.encoder_layers* [`BartEncoderLayer`] modules.

    Args:
        config: BartConfig
    """

    _can_record_outputs = {
        "hidden_states": BartEncoderLayer,
        "attentions": BartAttention,
    }

    def __init__(self, config: BartConfig):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        width = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        scale = math.sqrt(width) if config.scale_embedding else 1.0

        # Submodule order (tokens, positions, layers, norm) is kept stable for weight init and state_dict.
        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, width, self.padding_idx, embed_scale=scale)
        self.embed_positions = BartLearnedPositionalEmbedding(config.max_position_embeddings, width)
        self.layers = nn.ModuleList(BartEncoderLayer(config, layer_idx=n) for n in range(config.encoder_layers))
        self.layernorm_embedding = nn.LayerNorm(width)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        # Exactly one of the two inputs must be given (both-missing and both-present are rejected).
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # The position table only needs the (batch, seq_len) shape, taken from one feature slice.
        positions = self.embed_positions(inputs_embeds[:, :, -1]).to(inputs_embeds.device)
        hidden_states = self.layernorm_embedding(inputs_embeds + positions)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        for encoder_layer in self.layers:
            # LayerDrop (https://huggingface.co/papers/1909.11556): during training each layer is skipped
            # with probability `self.layerdrop`; one random draw per layer, only in training mode.
            if self.training and torch.rand([]) < self.layerdrop:
                continue
            hidden_states = encoder_layer(hidden_states, attention_mask, **kwargs)

        return BaseModelOutput(last_hidden_state=hidden_states)
class BartDecoder(BartPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`].

    Args:
        config: BartConfig
    """

    _can_record_outputs = {
        "hidden_states": BartDecoderLayer,
        "attentions": OutputRecorder(BartAttention, index=1, layer_name="self_attn"),
        "cross_attentions": OutputRecorder(BartAttention, index=1, layer_name="encoder_attn"),
    }

    def __init__(self, config: BartConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = BartScaledWordEmbedding(
            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
        )

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )
        self.layers = nn.ModuleList([BartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])

        self.layernorm_embedding = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        # Exactly one of the two inputs must be given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of decoder_input_ids or decoder_inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # initialize `past_key_values`: an encoder-decoder cache (separate self-/cross-attention stores)
        # when there is an encoder side, otherwise a single dynamic cache.
        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                else DynamicCache(config=self.config)
            )

        batch_size, seq_length = inputs_embeds.size()[:-1]
        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        # Absolute positions of the new tokens, continuing after what is already cached.
        position_ids = torch.arange(seq_length, device=inputs_embeds.device) + past_key_values_length

        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past cache
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        # The causal mask is built against the self-attention part of the cache only.
        self_attn_cache = (
            past_key_values.self_attention_cache
            if isinstance(past_key_values, EncoderDecoderCache)
            else past_key_values
        )

        attention_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=self_attn_cache,
        )
        encoder_attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=encoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
        )

        # embed positions. `position_ids` is always set above, so the embedding ignores the values of its
        # first argument; pass a real tensor (as the encoder does) rather than an unrelated object.
        positions = self.embed_positions(
            inputs_embeds[:, :, -1], past_key_values_length, position_ids=position_ids
        )
        positions = positions.to(inputs_embeds.device)

        hidden_states = inputs_embeds + positions
        hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        for decoder_layer in self.layers:
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            hidden_states = decoder_layer(
                hidden_states,
                attention_mask,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
@auto_docstring
class BartModel(BartPreTrainedModel):
    # Encoder and decoder token embeddings share one matrix (`shared`).
    _tied_weights_keys = {
        "decoder.embed_tokens.weight": "shared.weight",
        "encoder.embed_tokens.weight": "shared.weight",
    }

    def __init__(self, config: BartConfig):
        super().__init__(config)

        scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.shared = BartScaledWordEmbedding(
            config.vocab_size, config.d_model, config.pad_token_id, embed_scale=scale
        )
        self.encoder = BartEncoder(config)
        self.decoder = BartDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        # Install the new table and point both stacks at the very same module.
        self.shared = value
        for stack in (self.encoder, self.decoder):
            stack.embed_tokens = self.shared

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        encoder_outputs: list[torch.FloatTensor] | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Seq2SeqModelOutput:
        r"""
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        """
        # BART-specific: without decoder ids/embeddings, derive decoder inputs by right-shifting `input_ids`.
        missing_decoder_inputs = decoder_input_ids is None and decoder_inputs_embeds is None
        if missing_decoder_inputs:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        if encoder_outputs is None:
            # Run the encoder only when the caller did not supply precomputed outputs.
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                **kwargs,
            )
        elif not isinstance(encoder_outputs, BaseModelOutput):
            # Accept a plain sequence (last_hidden_state[, hidden_states[, attentions]]) and wrap it.
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # The encoder padding mask doubles as the cross-attention mask.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    The BART Model with a language modeling head. Can be used for summarization.
    """
)
class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
    base_model_prefix = "model"
    # `lm_head.weight` is tied to the shared token-embedding matrix of the base model.
    _tied_weights_keys = {
        "lm_head.weight": "model.shared.weight",
    }
    # Checkpoints without a `final_logits_bias` entry may still load; the buffer then keeps the zeros set in
    # `__init__` (NOTE(review): exact semantics of this attribute live in the pretrained-model base class).
    _keys_to_ignore_on_load_missing = ["final_logits_bias"]

    def __init__(self, config: BartConfig):
        super().__init__(config)
        self.model = BartModel(config)
        # Additive bias over the vocabulary, shape (1, num_embeddings). It is a buffer, not a parameter.
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def resize_token_embeddings(
        self, new_num_tokens: int, pad_to_multiple_of: int | None = None, mean_resizing: bool = True
    ) -> nn.Embedding:
        """
        Resize the token embeddings via the base implementation and keep `final_logits_bias` in sync.

        Returns:
            The resized embedding module returned by the base class.
        """
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
        # The bias must cover exactly as many tokens as the (possibly padded) new embedding matrix.
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        """Truncate or zero-pad `final_logits_bias` so that its last dimension equals `new_num_tokens`."""
        old_bias = self.final_logits_bias
        old_num_tokens = old_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = old_bias[:, :new_num_tokens]
        else:
            # Match the existing buffer's dtype as well as its device. With the default dtype, a half/bfloat16
            # model would make `torch.cat` either reject the mix or promote the bias to float32 (depending on the
            # PyTorch version), and a float32 bias would then upcast every logit in `forward`.
            extra_bias = torch.zeros(
                (1, new_num_tokens - old_num_tokens), dtype=old_bias.dtype, device=old_bias.device
            )
            new_bias = torch.cat([old_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        encoder_outputs: list[torch.FloatTensor] | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Seq2SeqLMOutput:
        r"""
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If neither
            `decoder_input_ids` nor `decoder_inputs_embeds` is provided and `labels` are given, this tensor is created
            by shifting `labels` to the right; otherwise the model shifts the `input_ids` to the right for denoising
            pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example summarization:

        ```python
        >>> from transformers import AutoTokenizer, BartForConditionalGeneration

        >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

        >>> ARTICLE_TO_SUMMARIZE = (
        ...     "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
        ...     "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
        ...     "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
        ... )
        >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")

        >>> # Generate Summary
        >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
        >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
        ```

        Mask filling example:

        ```python
        >>> from transformers import AutoTokenizer, BartForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
        >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

        >>> TXT = "My friends are <mask> but they eat too many carbs."
        >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
        >>> logits = model(input_ids).logits

        >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
        >>> probs = logits[0, masked_index].softmax(dim=0)
        >>> values, predictions = probs.topk(5)

        >>> tokenizer.decode(predictions).split()
        ['not', 'good', 'healthy', 'great', 'very']
        ```
        """
        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            # Teacher forcing: build decoder inputs from the labels when the caller supplied none.
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs: Seq2SeqModelOutput = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        lm_logits = self.lm_head(outputs[0])
        # The bias buffer may live on a different device than the head output (e.g. with model sharding).
        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

        masked_lm_loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Shift `labels` right (inserting `decoder_start_token_id`) to obtain decoder input ids."""
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
@auto_docstring(
    custom_intro="""
    Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
    tasks.
    """
)
class BartForSequenceClassification(BartPreTrainedModel):
    def __init__(self, config: BartConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = BartModel(config)
        self.classification_head = BartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classifier_dropout,
        )

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        encoder_outputs: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Seq2SeqSequenceClassifierOutput:
        r"""
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if labels is not None:
            use_cache = False

        # The sentence representation is pooled at the <eos> positions of `input_ids`, so they are mandatory.
        if input_ids is None:
            if inputs_embeds is not None:
                raise NotImplementedError(
                    f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
                )
            # Fail fast instead of running the whole model and then crashing on `None.eq(...)` below.
            raise ValueError(
                f"{self.__class__.__name__} requires `input_ids` to locate the <eos> tokens used for pooling."
            )

        outputs: Seq2SeqModelOutput = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = outputs[0]  # last hidden state

        eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
        # The reshape below requires every row to contain the same number of <eos> tokens.
        torch_compilable_check(
            torch.unique_consecutive(eos_mask.sum(1)).numel() == 1,
            "All examples must have the same number of <eos> tokens.",
        )
        # Gather the hidden states at <eos> positions, group them per example, and keep the last one.
        sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
            :, -1, :
        ]
        logits = self.classification_head(sentence_representation)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once from `num_labels` and the label dtype; it is stored on the config.
            if self.config.problem_type is None:
                if self.config.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.config.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.config.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return Seq2SeqSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
@auto_docstring
class BartForQuestionAnswering(BartPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # Span extraction predicts two scores per token: one for the span start, one for the span end.
        config.num_labels = 2
        self.num_labels = config.num_labels

        self.model = BartModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        encoder_outputs: list[torch.FloatTensor] | None = None,
        start_positions: torch.LongTensor | None = None,
        end_positions: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Seq2SeqQuestionAnsweringModelOutput:
        r"""
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        """
        # A loss is only computed when both span boundaries are supplied; caching is disabled in that case.
        has_span_targets = start_positions is not None and end_positions is not None
        if has_span_targets:
            use_cache = False

        outputs: Seq2SeqModelOutput = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        # Project every decoder position to two scores, then separate them into start / end logits.
        span_scores = self.qa_outputs(outputs[0])
        start_logits, end_logits = (
            part.squeeze(-1).contiguous() for part in span_scores.split(1, dim=-1)
        )

        total_loss = None
        if has_span_targets:
            # Positions may carry an extra trailing dimension (the original note mentions multi-GPU); drop it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Positions past the end are clamped to `seq_len`, which is the ignore index, so those terms drop
            # out of the loss (negative positions are clamped to 0).
            seq_len = start_logits.size(1)
            loss_fct = CrossEntropyLoss(ignore_index=seq_len)
            start_loss = loss_fct(start_logits, start_positions.clamp(0, seq_len))
            end_loss = loss_fct(end_logits, end_positions.clamp(0, seq_len))
            total_loss = (start_loss + end_loss) / 2

        return Seq2SeqQuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
class BartDecoderWrapper(BartPreTrainedModel):
    """
    Minimal module holding a `BartDecoder` under the attribute name `decoder`.

    It exists so that pretrained checkpoints load correctly when the causal language model is used in
    combination with the [`EncoderDecoderModel`] framework.
    """

    def __init__(self, config):
        super().__init__(config)
        self.decoder = BartDecoder(config)
        self.post_init()

    def forward(self, *args, **kwargs):
        """Forward every positional and keyword argument unchanged to the wrapped decoder."""
        wrapped_decoder = self.decoder
        return wrapped_decoder(*args, **kwargs)
@auto_docstring(
    custom_intro="""
    BART decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
    """
)
class BartForCausalLM(BartPreTrainedModel, GenerationMixin):
    # The LM head shares its weight with the decoder token embeddings.
    _tied_weights_keys = {
        "lm_head.weight": "model.decoder.embed_tokens.weight",
    }

    def __init__(self, config):
        # NOTE: this mutates the caller's config so that it describes a standalone decoder.
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        self.model = BartDecoderWrapper(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token-embedding module."""
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token-embedding module with `value`."""
        self.model.decoder.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithCrossAttentions:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            Labels are not shifted inside this method: `labels[:, i]` is the target for `logits[:, i]`. When
            `logits_to_keep` is non-zero, the same positions are selected from `labels` as from the logits.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, BartForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
        >>> model = BartForCausalLM.from_pretrained("facebook/bart-base")
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
        >>> list(logits.shape) == expected_shape
        True
        ```"""
        outputs: BaseModelOutputWithPastAndCrossAttentions = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits. An int keeps the last `logits_to_keep` positions (0 keeps all,
        # since slice(-0, None) == slice(0, None)); a tensor is used directly as position indices.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # Select the same positions as the logits so shapes stay aligned for a non-zero `logits_to_keep`.
            # The slice may be non-contiguous, hence `reshape` rather than `view` below.
            labels = labels[:, slice_indices].to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.config.vocab_size), labels.reshape(-1))

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
# Public names exported by this module (same order as before).
__all__ = [
    # Task-specific heads.
    "BartForCausalLM",
    "BartForConditionalGeneration",
    "BartForQuestionAnswering",
    "BartForSequenceClassification",
    # Base model and pretrained-model base classes. `BartPretrainedModel` and `PretrainedBartModel` are
    # presumably backward-compatibility aliases defined elsewhere in this file — confirm before removing.
    "BartModel",
    "BartPreTrainedModel",
    "BartPretrainedModel",
    "PretrainedBartModel",
]
|