modeling_bart.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321
  1. # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch BART model."""
  15. import math
  16. import warnings
  17. from collections.abc import Callable
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  24. from ...generation import GenerationMixin
  25. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  26. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import (
  29. BaseModelOutput,
  30. BaseModelOutputWithPastAndCrossAttentions,
  31. CausalLMOutputWithCrossAttentions,
  32. Seq2SeqLMOutput,
  33. Seq2SeqModelOutput,
  34. Seq2SeqQuestionAnsweringModelOutput,
  35. Seq2SeqSequenceClassifierOutput,
  36. )
  37. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  38. from ...processing_utils import Unpack
  39. from ...utils import (
  40. TransformersKwargs,
  41. auto_docstring,
  42. can_return_tuple,
  43. is_torchdynamo_compiling,
  44. logging,
  45. torch_compilable_check,
  46. )
  47. from ...utils.generic import merge_with_config_defaults
  48. from ...utils.output_capturing import OutputRecorder, capture_outputs
  49. from .configuration_bart import BartConfig
# Module-level logger; BartAttention uses it to emit a one-time warning when a decoder attention lacks `layer_idx`.
logger = logging.get_logger(__name__)
  51. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  52. """
  53. Shift input ids one token to the right.
  54. """
  55. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  56. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  57. shifted_input_ids[:, 0] = decoder_start_token_id
  58. if pad_token_id is None:
  59. raise ValueError("self.model.config.pad_token_id has to be defined.")
  60. # replace possible -100 values in labels by `pad_token_id`
  61. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  62. return shifted_input_ids
  63. class BartLearnedPositionalEmbedding(nn.Embedding):
  64. """
  65. This module learns positional embeddings up to a fixed maximum size.
  66. """
  67. def __init__(self, num_embeddings: int, embedding_dim: int):
  68. # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
  69. # and adjust num_embeddings appropriately. Other models don't have this hack
  70. self.offset = 2
  71. super().__init__(num_embeddings + self.offset, embedding_dim)
  72. def forward(
  73. self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None
  74. ):
  75. """`input_ids' shape is expected to be [bsz x seqlen]."""
  76. if position_ids is None:
  77. bsz, seq_len = input_ids.shape[:2]
  78. position_ids = torch.arange(
  79. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  80. ).expand(bsz, -1)
  81. else:
  82. position_ids = position_ids.unsqueeze(0)
  83. return super().forward(position_ids + self.offset)
  84. class BartScaledWordEmbedding(nn.Embedding):
  85. """
  86. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  87. """
  88. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float | None = 1.0):
  89. super().__init__(num_embeddings, embedding_dim, padding_idx)
  90. self.embed_scale = embed_scale
  91. def forward(self, input_ids: torch.Tensor):
  92. return super().forward(input_ids) * self.embed_scale
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain-PyTorch scaled dot-product attention, used as the fallback ("eager") attention implementation.

    Args:
        module: the calling module; only its ``training`` flag is read, to decide whether dropout is active.
        query, key, value: per-head projections. ``key`` is transposed on dims 2 and 3 and the output is transposed
            on dims 1 and 2, so these are presumably laid out as (batch, heads, seq, head_dim) — NOTE(review):
            confirm against callers.
        attention_mask: optional mask that is *added* to the raw scores before the softmax.
        scaling: multiplier for the raw scores; defaults to ``query.size(-1) ** -0.5``.
        dropout: dropout probability applied to the attention probabilities.
        **kwargs: accepted for interface compatibility with other attention backends; not used here.

    Returns:
        ``(attn_output, attn_weights)``: the attention-weighted values with dims 1 and 2 swapped (made contiguous),
        and the post-softmax, post-dropout attention weights.
    """
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

    if attention_mask is not None:
        # Additive mask: positions to suppress presumably hold large negative values — confirm with the mask builder.
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    # Move the head axis back next to the feature axis so callers can merge heads with a reshape.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
class BartAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    The same module serves as self-attention and, when ``key_value_states`` is passed to ``forward``, as
    cross-attention: keys/values are then projected from ``key_value_states`` instead of ``hidden_states``.
    The attention kernel is looked up in ``ALL_ATTENTION_FUNCTIONS`` by ``config._attn_implementation``, falling
    back to ``eager_attention_forward``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: BartConfig | None = None,
        layer_idx: int | None = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        # NOTE(review): `forward` reads `self.config._attn_implementation`, so a config is effectively required
        # despite the `None` default.
        self.config = config

        # Heads must tile the embedding exactly; the integer division above would otherwise drop dimensions.
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 1/sqrt(head_dim) score scaling, forwarded to the attention kernel.
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        # Not read inside this class; presumably consumed by the attention backend through `module` — TODO confirm.
        self.is_causal = is_causal
        # Position of this layer in its stack; used to address this layer's entries in the KV cache.
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: queries are always projected from this tensor (and keys/values too, for self-attention).
            key_value_states: if given, the layer acts as cross-attention and keys/values come from here.
            past_key_values: optional cache. With an ``EncoderDecoderCache`` the self- or cross-attention sub-cache
                is selected; any other cache object is used directly.
            attention_mask: passed unchanged to the attention kernel.
            **kwargs: forwarded to the attention kernel.

        Returns:
            ``(attn_output, attn_weights)``: the output projection with the same leading dims as ``hidden_states``,
            and whatever weights the attention kernel returns.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # determine input shapes
        input_shape = hidden_states.shape[:-1]
        # (..., heads, head_dim); the head count is inferred via -1.
        hidden_shape = (*input_shape, -1, self.head_dim)

        # get query proj
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Select which cache to read/write. `is_updated` records whether this layer's cross-attention K/V
        # were already stored by an earlier call (it is set to True further below).
        is_updated = False
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_states from cache
                    curr_past_key_values = past_key_values.cross_attention_cache
                else:
                    curr_past_key_values = past_key_values.self_attention_cache
            else:
                curr_past_key_values = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_values.layers[self.layer_idx].keys
            value_states = curr_past_key_values.layers[self.layer_idx].values
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
            key_states = key_states.view(kv_shape).transpose(1, 2)
            value_states = value_states.view(kv_shape).transpose(1, 2)

            if past_key_values is not None:
                # Store the new K/V; the cache returns the tensors the kernel should attend over.
                key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back into the channel dimension before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
  215. class BartEncoderLayer(GradientCheckpointingLayer):
  216. def __init__(self, config: BartConfig, layer_idx: int | None = None):
  217. super().__init__()
  218. self.embed_dim = config.d_model
  219. self.self_attn = BartAttention(
  220. embed_dim=self.embed_dim,
  221. num_heads=config.encoder_attention_heads,
  222. dropout=config.attention_dropout,
  223. config=config,
  224. layer_idx=layer_idx,
  225. )
  226. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  227. self.dropout = config.dropout
  228. self.activation_fn = ACT2FN[config.activation_function]
  229. self.activation_dropout = config.activation_dropout
  230. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  231. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  232. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  233. def forward(
  234. self,
  235. hidden_states: torch.FloatTensor,
  236. attention_mask: torch.FloatTensor,
  237. **kwargs: Unpack[TransformersKwargs],
  238. ) -> torch.Tensor:
  239. residual = hidden_states
  240. hidden_states, _ = self.self_attn(
  241. hidden_states,
  242. attention_mask=attention_mask,
  243. **kwargs,
  244. )
  245. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  246. hidden_states = residual + hidden_states
  247. hidden_states = self.self_attn_layer_norm(hidden_states)
  248. residual = hidden_states
  249. hidden_states = self.activation_fn(self.fc1(hidden_states))
  250. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  251. hidden_states = self.fc2(hidden_states)
  252. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  253. hidden_states = residual + hidden_states
  254. hidden_states = self.final_layer_norm(hidden_states)
  255. if hidden_states.dtype == torch.float16 and not torch.isfinite(hidden_states).all():
  256. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  257. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  258. return hidden_states
  259. class BartDecoderLayer(GradientCheckpointingLayer):
  260. def __init__(self, config: BartConfig, layer_idx: int | None = None):
  261. super().__init__()
  262. self.embed_dim = config.d_model
  263. self.self_attn = BartAttention(
  264. embed_dim=self.embed_dim,
  265. num_heads=config.decoder_attention_heads,
  266. dropout=config.attention_dropout,
  267. is_decoder=True,
  268. is_causal=True,
  269. config=config,
  270. layer_idx=layer_idx,
  271. )
  272. self.dropout = config.dropout
  273. self.activation_fn = ACT2FN[config.activation_function]
  274. self.activation_dropout = config.activation_dropout
  275. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  276. self.encoder_attn = BartAttention(
  277. self.embed_dim,
  278. config.decoder_attention_heads,
  279. dropout=config.attention_dropout,
  280. is_decoder=True,
  281. config=config,
  282. layer_idx=layer_idx,
  283. )
  284. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  285. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  286. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  287. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  288. def forward(
  289. self,
  290. hidden_states: torch.Tensor,
  291. attention_mask: torch.Tensor | None = None,
  292. encoder_hidden_states: torch.Tensor | None = None,
  293. encoder_attention_mask: torch.Tensor | None = None,
  294. past_key_values: Cache | None = None,
  295. use_cache: bool | None = True,
  296. **kwargs: Unpack[TransformersKwargs],
  297. ) -> torch.Tensor:
  298. residual = hidden_states
  299. # Self Attention
  300. hidden_states, _ = self.self_attn(
  301. hidden_states,
  302. past_key_values=past_key_values,
  303. attention_mask=attention_mask,
  304. **kwargs,
  305. )
  306. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  307. hidden_states = residual + hidden_states
  308. hidden_states = self.self_attn_layer_norm(hidden_states)
  309. # Cross-Attention Block
  310. if encoder_hidden_states is not None:
  311. residual = hidden_states
  312. hidden_states, _ = self.encoder_attn(
  313. hidden_states,
  314. key_value_states=encoder_hidden_states,
  315. attention_mask=encoder_attention_mask,
  316. past_key_values=past_key_values,
  317. **kwargs,
  318. )
  319. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  320. hidden_states = residual + hidden_states
  321. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  322. # Fully Connected
  323. residual = hidden_states
  324. hidden_states = self.activation_fn(self.fc1(hidden_states))
  325. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  326. hidden_states = self.fc2(hidden_states)
  327. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  328. hidden_states = residual + hidden_states
  329. hidden_states = self.final_layer_norm(hidden_states)
  330. return hidden_states
  331. class BartClassificationHead(nn.Module):
  332. """Head for sentence-level classification tasks."""
  333. def __init__(
  334. self,
  335. input_dim: int,
  336. inner_dim: int,
  337. num_classes: int,
  338. pooler_dropout: float,
  339. ):
  340. super().__init__()
  341. self.dense = nn.Linear(input_dim, inner_dim)
  342. self.dropout = nn.Dropout(p=pooler_dropout)
  343. self.out_proj = nn.Linear(inner_dim, num_classes)
  344. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  345. hidden_states = self.dropout(hidden_states)
  346. hidden_states = self.dense(hidden_states)
  347. hidden_states = torch.tanh(hidden_states)
  348. hidden_states = self.dropout(hidden_states)
  349. hidden_states = self.out_proj(hidden_states)
  350. return hidden_states
  351. @auto_docstring
  352. class BartPreTrainedModel(PreTrainedModel):
  353. config: BartConfig
  354. base_model_prefix = "model"
  355. supports_gradient_checkpointing = True
  356. _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
  357. _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
  358. _skip_keys_device_placement = "past_key_values"
  359. _supports_flash_attn = True
  360. _supports_sdpa = True
  361. _supports_flex_attn = True
  362. _can_compile_fullgraph = True
  363. def _init_weights(self, module):
  364. super()._init_weights(module)
  365. if isinstance(module, BartForConditionalGeneration):
  366. init.zeros_(module.final_logits_bias)
  367. @property
  368. def dummy_inputs(self):
  369. pad_token = self.config.pad_token_id
  370. input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
  371. dummy_inputs = {
  372. "attention_mask": input_ids.ne(pad_token),
  373. "input_ids": input_ids,
  374. }
  375. return dummy_inputs
  376. class PretrainedBartModel(BartPreTrainedModel):
  377. def __init_subclass__(self):
  378. warnings.warn(
  379. "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.",
  380. FutureWarning,
  381. )
  382. class BartPretrainedModel(BartPreTrainedModel):
  383. def __init_subclass__(self):
  384. warnings.warn(
  385. "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.",
  386. FutureWarning,
  387. )
  388. class BartEncoder(BartPreTrainedModel):
  389. """
  390. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  391. [`BartEncoderLayer`].
  392. Args:
  393. config: BartConfig
  394. embed_tokens (nn.Embedding): output embedding
  395. """
  396. _can_record_outputs = {
  397. "hidden_states": BartEncoderLayer,
  398. "attentions": BartAttention,
  399. }
  400. def __init__(self, config: BartConfig):
  401. super().__init__(config)
  402. self.dropout = config.dropout
  403. self.layerdrop = config.encoder_layerdrop
  404. embed_dim = config.d_model
  405. self.padding_idx = config.pad_token_id
  406. self.max_source_positions = config.max_position_embeddings
  407. embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  408. self.embed_tokens = BartScaledWordEmbedding(
  409. config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
  410. )
  411. self.embed_positions = BartLearnedPositionalEmbedding(
  412. config.max_position_embeddings,
  413. embed_dim,
  414. )
  415. self.layers = nn.ModuleList([BartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)])
  416. self.layernorm_embedding = nn.LayerNorm(embed_dim)
  417. self.gradient_checkpointing = False
  418. # Initialize weights and apply final processing
  419. self.post_init()
  420. @merge_with_config_defaults
  421. @capture_outputs
  422. @auto_docstring
  423. def forward(
  424. self,
  425. input_ids: torch.LongTensor | None = None,
  426. attention_mask: torch.Tensor | None = None,
  427. inputs_embeds: torch.FloatTensor | None = None,
  428. **kwargs: Unpack[TransformersKwargs],
  429. ) -> BaseModelOutput:
  430. if (input_ids is None) ^ (inputs_embeds is not None):
  431. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  432. if inputs_embeds is None:
  433. inputs_embeds = self.embed_tokens(input_ids)
  434. embed_pos = self.embed_positions(inputs_embeds[:, :, -1]) # needed for the shape only
  435. embed_pos = embed_pos.to(inputs_embeds.device)
  436. hidden_states = inputs_embeds + embed_pos
  437. hidden_states = self.layernorm_embedding(hidden_states)
  438. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  439. attention_mask = create_bidirectional_mask(
  440. config=self.config,
  441. inputs_embeds=inputs_embeds,
  442. attention_mask=attention_mask,
  443. )
  444. for idx, encoder_layer in enumerate(self.layers):
  445. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  446. to_drop = False
  447. if self.training:
  448. dropout_probability = torch.rand([])
  449. if dropout_probability < self.layerdrop: # skip the layer
  450. to_drop = True
  451. if not to_drop:
  452. hidden_states = encoder_layer(
  453. hidden_states,
  454. attention_mask,
  455. **kwargs,
  456. )
  457. return BaseModelOutput(
  458. last_hidden_state=hidden_states,
  459. )
  460. class BartDecoder(BartPreTrainedModel):
  461. """
  462. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`]
  463. Args:
  464. config: BartConfig
  465. embed_tokens (nn.Embedding): output embedding
  466. """
  467. _can_record_outputs = {
  468. "hidden_states": BartDecoderLayer,
  469. "attentions": OutputRecorder(BartAttention, index=1, layer_name="self_attn"),
  470. "cross_attentions": OutputRecorder(BartAttention, index=1, layer_name="encoder_attn"),
  471. }
  472. def __init__(self, config: BartConfig):
  473. super().__init__(config)
  474. self.dropout = config.dropout
  475. self.layerdrop = config.decoder_layerdrop
  476. self.padding_idx = config.pad_token_id
  477. self.max_target_positions = config.max_position_embeddings
  478. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  479. self.embed_tokens = BartScaledWordEmbedding(
  480. config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
  481. )
  482. self.embed_positions = BartLearnedPositionalEmbedding(
  483. config.max_position_embeddings,
  484. config.d_model,
  485. )
  486. self.layers = nn.ModuleList([BartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  487. self.layernorm_embedding = nn.LayerNorm(config.d_model)
  488. self.gradient_checkpointing = False
  489. # Initialize weights and apply final processing
  490. self.post_init()
  491. @merge_with_config_defaults
  492. @capture_outputs
  493. @auto_docstring
  494. def forward(
  495. self,
  496. input_ids: torch.LongTensor | None = None,
  497. attention_mask: torch.Tensor | None = None,
  498. encoder_hidden_states: torch.FloatTensor | None = None,
  499. encoder_attention_mask: torch.LongTensor | None = None,
  500. past_key_values: Cache | None = None,
  501. inputs_embeds: torch.FloatTensor | None = None,
  502. use_cache: bool | None = None,
  503. **kwargs: Unpack[TransformersKwargs],
  504. ) -> BaseModelOutputWithPastAndCrossAttentions:
  505. if (input_ids is None) ^ (inputs_embeds is not None):
  506. raise ValueError("You must specify exactly one of decoder_input_ids or decoder_inputs_embeds")
  507. if inputs_embeds is None:
  508. inputs_embeds = self.embed_tokens(input_ids)
  509. # initialize `past_key_values`
  510. if use_cache and past_key_values is None:
  511. past_key_values = (
  512. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  513. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  514. else DynamicCache(config=self.config)
  515. )
  516. batch_size, seq_length = inputs_embeds.size()[:-1]
  517. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  518. position_ids = torch.arange(seq_length, device=inputs_embeds.device) + past_key_values_length
  519. if attention_mask is None and not is_torchdynamo_compiling():
  520. # required mask seq length can be calculated via length of past cache
  521. mask_seq_length = past_key_values_length + seq_length
  522. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  523. self_attn_cache = (
  524. past_key_values.self_attention_cache
  525. if isinstance(past_key_values, EncoderDecoderCache)
  526. else past_key_values
  527. )
  528. attention_mask = create_causal_mask(
  529. config=self.config,
  530. inputs_embeds=inputs_embeds,
  531. attention_mask=attention_mask,
  532. past_key_values=self_attn_cache,
  533. )
  534. encoder_attention_mask = create_bidirectional_mask(
  535. config=self.config,
  536. inputs_embeds=inputs_embeds,
  537. attention_mask=encoder_attention_mask,
  538. encoder_hidden_states=encoder_hidden_states,
  539. )
  540. # embed positions
  541. positions = self.embed_positions(input, past_key_values_length, position_ids=position_ids)
  542. positions = positions.to(inputs_embeds.device)
  543. hidden_states = inputs_embeds + positions
  544. hidden_states = self.layernorm_embedding(hidden_states)
  545. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  546. for idx, decoder_layer in enumerate(self.layers):
  547. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  548. if self.training:
  549. dropout_probability = torch.rand([])
  550. if dropout_probability < self.layerdrop:
  551. continue
  552. hidden_states = decoder_layer(
  553. hidden_states,
  554. attention_mask,
  555. encoder_hidden_states, # as a positional argument for gradient checkpointing
  556. encoder_attention_mask=encoder_attention_mask,
  557. past_key_values=past_key_values,
  558. use_cache=use_cache,
  559. **kwargs,
  560. )
  561. return BaseModelOutputWithPastAndCrossAttentions(
  562. last_hidden_state=hidden_states,
  563. past_key_values=past_key_values,
  564. )
  565. @auto_docstring
  566. class BartModel(BartPreTrainedModel):
  567. _tied_weights_keys = {
  568. "decoder.embed_tokens.weight": "shared.weight",
  569. "encoder.embed_tokens.weight": "shared.weight",
  570. }
  571. def __init__(self, config: BartConfig):
  572. super().__init__(config)
  573. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  574. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  575. self.shared = BartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
  576. self.encoder = BartEncoder(config)
  577. self.decoder = BartDecoder(config)
  578. # Initialize weights and apply final processing
  579. self.post_init()
  580. def get_input_embeddings(self):
  581. return self.shared
  582. def set_input_embeddings(self, value):
  583. self.shared = value
  584. self.encoder.embed_tokens = self.shared
  585. self.decoder.embed_tokens = self.shared
  586. @can_return_tuple
  587. @auto_docstring
  588. def forward(
  589. self,
  590. input_ids: torch.LongTensor | None = None,
  591. attention_mask: torch.Tensor | None = None,
  592. decoder_input_ids: torch.LongTensor | None = None,
  593. decoder_attention_mask: torch.LongTensor | None = None,
  594. encoder_outputs: list[torch.FloatTensor] | None = None,
  595. past_key_values: Cache | None = None,
  596. inputs_embeds: torch.FloatTensor | None = None,
  597. decoder_inputs_embeds: torch.FloatTensor | None = None,
  598. use_cache: bool | None = None,
  599. **kwargs: Unpack[TransformersKwargs],
  600. ) -> tuple | Seq2SeqModelOutput:
  601. r"""
  602. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  603. Indices of decoder input sequence tokens in the vocabulary.
  604. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  605. [`PreTrainedTokenizer.__call__`] for details.
  606. [What are decoder input IDs?](../glossary#decoder-input-ids)
  607. Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  608. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  609. For translation and summarization training, `decoder_input_ids` should be provided. If no
  610. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  611. for denoising pre-training following the paper.
  612. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  613. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  614. be used by default.
  615. If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
  616. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  617. information on the default strategy.
  618. """
  619. # different to other models, Bart automatically creates decoder_input_ids from
  620. # input_ids if no decoder_input_ids are provided
  621. if decoder_input_ids is None and decoder_inputs_embeds is None:
  622. if input_ids is None:
  623. raise ValueError(
  624. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  625. "passed, `input_ids` cannot be `None`. Please pass either "
  626. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  627. )
  628. decoder_input_ids = shift_tokens_right(
  629. input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
  630. )
  631. if encoder_outputs is None:
  632. encoder_outputs: BaseModelOutput = self.encoder(
  633. input_ids=input_ids,
  634. attention_mask=attention_mask,
  635. inputs_embeds=inputs_embeds,
  636. **kwargs,
  637. )
  638. elif not isinstance(encoder_outputs, BaseModelOutput):
  639. encoder_outputs = BaseModelOutput(
  640. last_hidden_state=encoder_outputs[0],
  641. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  642. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  643. )
  644. decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder(
  645. input_ids=decoder_input_ids,
  646. attention_mask=decoder_attention_mask,
  647. encoder_hidden_states=encoder_outputs[0],
  648. encoder_attention_mask=attention_mask,
  649. past_key_values=past_key_values,
  650. inputs_embeds=decoder_inputs_embeds,
  651. use_cache=use_cache,
  652. **kwargs,
  653. )
  654. return Seq2SeqModelOutput(
  655. last_hidden_state=decoder_outputs.last_hidden_state,
  656. past_key_values=decoder_outputs.past_key_values,
  657. decoder_hidden_states=decoder_outputs.hidden_states,
  658. decoder_attentions=decoder_outputs.attentions,
  659. cross_attentions=decoder_outputs.cross_attentions,
  660. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  661. encoder_hidden_states=encoder_outputs.hidden_states,
  662. encoder_attentions=encoder_outputs.attentions,
  663. )
  664. @auto_docstring(
  665. custom_intro="""
  666. The BART Model with a language modeling head. Can be used for summarization.
  667. """
  668. )
  669. class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
  670. base_model_prefix = "model"
  671. _tied_weights_keys = {
  672. "lm_head.weight": "model.shared.weight",
  673. }
  674. _keys_to_ignore_on_load_missing = ["final_logits_bias"]
  675. def __init__(self, config: BartConfig):
  676. super().__init__(config)
  677. self.model = BartModel(config)
  678. self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
  679. self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
  680. # Initialize weights and apply final processing
  681. self.post_init()
  682. def resize_token_embeddings(
  683. self, new_num_tokens: int, pad_to_multiple_of: int | None = None, mean_resizing: bool = True
  684. ) -> nn.Embedding:
  685. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  686. self._resize_final_logits_bias(new_embeddings.weight.shape[0])
  687. return new_embeddings
  688. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  689. old_num_tokens = self.final_logits_bias.shape[-1]
  690. if new_num_tokens <= old_num_tokens:
  691. new_bias = self.final_logits_bias[:, :new_num_tokens]
  692. else:
  693. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  694. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  695. self.register_buffer("final_logits_bias", new_bias)
  696. @can_return_tuple
  697. @auto_docstring
  698. def forward(
  699. self,
  700. input_ids: torch.LongTensor | None = None,
  701. attention_mask: torch.Tensor | None = None,
  702. decoder_input_ids: torch.LongTensor | None = None,
  703. decoder_attention_mask: torch.LongTensor | None = None,
  704. encoder_outputs: list[torch.FloatTensor] | None = None,
  705. past_key_values: Cache | None = None,
  706. inputs_embeds: torch.FloatTensor | None = None,
  707. decoder_inputs_embeds: torch.FloatTensor | None = None,
  708. labels: torch.LongTensor | None = None,
  709. use_cache: bool | None = None,
  710. **kwargs: Unpack[TransformersKwargs],
  711. ) -> tuple | Seq2SeqLMOutput:
  712. r"""
  713. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  714. Indices of decoder input sequence tokens in the vocabulary.
  715. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  716. [`PreTrainedTokenizer.__call__`] for details.
  717. [What are decoder input IDs?](../glossary#decoder-input-ids)
  718. Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  719. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  720. For translation and summarization training, `decoder_input_ids` should be provided. If no
  721. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  722. for denoising pre-training following the paper.
  723. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  724. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  725. be used by default.
  726. If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
  727. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  728. information on the default strategy.
  729. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  730. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  731. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  732. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  733. Example summarization:
  734. ```python
  735. >>> from transformers import AutoTokenizer, BartForConditionalGeneration
  736. >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
  737. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
  738. >>> ARTICLE_TO_SUMMARIZE = (
  739. ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
  740. ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
  741. ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
  742. ... )
  743. >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")
  744. >>> # Generate Summary
  745. >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
  746. >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  747. 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
  748. ```
  749. Mask filling example:
  750. ```python
  751. >>> from transformers import AutoTokenizer, BartForConditionalGeneration
  752. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
  753. >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
  754. >>> TXT = "My friends are <mask> but they eat too many carbs."
  755. >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
  756. >>> logits = model(input_ids).logits
  757. >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
  758. >>> probs = logits[0, masked_index].softmax(dim=0)
  759. >>> values, predictions = probs.topk(5)
  760. >>> tokenizer.decode(predictions).split()
  761. ['not', 'good', 'healthy', 'great', 'very']
  762. ```
  763. """
  764. if labels is not None:
  765. if use_cache:
  766. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  767. use_cache = False
  768. if decoder_input_ids is None and decoder_inputs_embeds is None:
  769. decoder_input_ids = shift_tokens_right(
  770. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  771. )
  772. outputs: Seq2SeqModelOutput = self.model(
  773. input_ids,
  774. attention_mask=attention_mask,
  775. decoder_input_ids=decoder_input_ids,
  776. encoder_outputs=encoder_outputs,
  777. decoder_attention_mask=decoder_attention_mask,
  778. past_key_values=past_key_values,
  779. inputs_embeds=inputs_embeds,
  780. decoder_inputs_embeds=decoder_inputs_embeds,
  781. use_cache=use_cache,
  782. **kwargs,
  783. )
  784. lm_logits = self.lm_head(outputs[0])
  785. lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
  786. masked_lm_loss = None
  787. if labels is not None:
  788. labels = labels.to(lm_logits.device)
  789. loss_fct = CrossEntropyLoss()
  790. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  791. return Seq2SeqLMOutput(
  792. loss=masked_lm_loss,
  793. logits=lm_logits,
  794. past_key_values=outputs.past_key_values,
  795. decoder_hidden_states=outputs.decoder_hidden_states,
  796. decoder_attentions=outputs.decoder_attentions,
  797. cross_attentions=outputs.cross_attentions,
  798. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  799. encoder_hidden_states=outputs.encoder_hidden_states,
  800. encoder_attentions=outputs.encoder_attentions,
  801. )
  802. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  803. return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
  804. @auto_docstring(
  805. custom_intro="""
  806. Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
  807. tasks.
  808. """
  809. )
  810. class BartForSequenceClassification(BartPreTrainedModel):
  811. def __init__(self, config: BartConfig, **kwargs):
  812. super().__init__(config, **kwargs)
  813. self.model = BartModel(config)
  814. self.classification_head = BartClassificationHead(
  815. config.d_model,
  816. config.d_model,
  817. config.num_labels,
  818. config.classifier_dropout,
  819. )
  820. # Initialize weights and apply final processing
  821. self.post_init()
  822. @can_return_tuple
  823. @auto_docstring
  824. def forward(
  825. self,
  826. input_ids: torch.LongTensor | None = None,
  827. attention_mask: torch.Tensor | None = None,
  828. decoder_input_ids: torch.LongTensor | None = None,
  829. decoder_attention_mask: torch.LongTensor | None = None,
  830. encoder_outputs: list[torch.FloatTensor] | None = None,
  831. inputs_embeds: torch.FloatTensor | None = None,
  832. decoder_inputs_embeds: torch.FloatTensor | None = None,
  833. labels: torch.LongTensor | None = None,
  834. use_cache: bool | None = None,
  835. **kwargs: Unpack[TransformersKwargs],
  836. ) -> tuple | Seq2SeqSequenceClassifierOutput:
  837. r"""
  838. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  839. Indices of decoder input sequence tokens in the vocabulary.
  840. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  841. [`PreTrainedTokenizer.__call__`] for details.
  842. [What are decoder input IDs?](../glossary#decoder-input-ids)
  843. Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  844. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  845. For translation and summarization training, `decoder_input_ids` should be provided. If no
  846. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  847. for denoising pre-training following the paper.
  848. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  849. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  850. be used by default.
  851. If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
  852. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  853. information on the default strategy.
  854. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  855. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  856. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  857. """
  858. if labels is not None:
  859. use_cache = False
  860. if input_ids is None and inputs_embeds is not None:
  861. raise NotImplementedError(
  862. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  863. )
  864. outputs: Seq2SeqModelOutput = self.model(
  865. input_ids,
  866. attention_mask=attention_mask,
  867. decoder_input_ids=decoder_input_ids,
  868. decoder_attention_mask=decoder_attention_mask,
  869. encoder_outputs=encoder_outputs,
  870. inputs_embeds=inputs_embeds,
  871. decoder_inputs_embeds=decoder_inputs_embeds,
  872. use_cache=use_cache,
  873. **kwargs,
  874. )
  875. hidden_states = outputs[0] # last hidden state
  876. eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
  877. torch_compilable_check(
  878. torch.unique_consecutive(eos_mask.sum(1)).numel() == 1,
  879. "All examples must have the same number of <eos> tokens.",
  880. )
  881. sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
  882. :, -1, :
  883. ]
  884. logits = self.classification_head(sentence_representation)
  885. loss = None
  886. if labels is not None:
  887. labels = labels.to(logits.device)
  888. if self.config.problem_type is None:
  889. if self.config.num_labels == 1:
  890. self.config.problem_type = "regression"
  891. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  892. self.config.problem_type = "single_label_classification"
  893. else:
  894. self.config.problem_type = "multi_label_classification"
  895. if self.config.problem_type == "regression":
  896. loss_fct = MSELoss()
  897. if self.config.num_labels == 1:
  898. loss = loss_fct(logits.squeeze(), labels.squeeze())
  899. else:
  900. loss = loss_fct(logits, labels)
  901. elif self.config.problem_type == "single_label_classification":
  902. loss_fct = CrossEntropyLoss()
  903. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  904. elif self.config.problem_type == "multi_label_classification":
  905. loss_fct = BCEWithLogitsLoss()
  906. loss = loss_fct(logits, labels)
  907. return Seq2SeqSequenceClassifierOutput(
  908. loss=loss,
  909. logits=logits,
  910. past_key_values=outputs.past_key_values,
  911. decoder_hidden_states=outputs.decoder_hidden_states,
  912. decoder_attentions=outputs.decoder_attentions,
  913. cross_attentions=outputs.cross_attentions,
  914. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  915. encoder_hidden_states=outputs.encoder_hidden_states,
  916. encoder_attentions=outputs.encoder_attentions,
  917. )
  918. @auto_docstring
  919. class BartForQuestionAnswering(BartPreTrainedModel):
  920. def __init__(self, config):
  921. super().__init__(config)
  922. config.num_labels = 2
  923. self.num_labels = config.num_labels
  924. self.model = BartModel(config)
  925. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  926. # Initialize weights and apply final processing
  927. self.post_init()
  928. @can_return_tuple
  929. @auto_docstring
  930. def forward(
  931. self,
  932. input_ids: torch.Tensor | None = None,
  933. attention_mask: torch.Tensor | None = None,
  934. decoder_input_ids: torch.LongTensor | None = None,
  935. decoder_attention_mask: torch.LongTensor | None = None,
  936. encoder_outputs: list[torch.FloatTensor] | None = None,
  937. start_positions: torch.LongTensor | None = None,
  938. end_positions: torch.LongTensor | None = None,
  939. inputs_embeds: torch.FloatTensor | None = None,
  940. decoder_inputs_embeds: torch.FloatTensor | None = None,
  941. use_cache: bool | None = None,
  942. **kwargs: Unpack[TransformersKwargs],
  943. ) -> tuple | Seq2SeqQuestionAnsweringModelOutput:
  944. r"""
  945. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  946. Indices of decoder input sequence tokens in the vocabulary.
  947. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  948. [`PreTrainedTokenizer.__call__`] for details.
  949. [What are decoder input IDs?](../glossary#decoder-input-ids)
  950. Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  951. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  952. For translation and summarization training, `decoder_input_ids` should be provided. If no
  953. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  954. for denoising pre-training following the paper.
  955. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  956. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  957. be used by default.
  958. If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
  959. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  960. information on the default strategy.
  961. """
  962. if start_positions is not None and end_positions is not None:
  963. use_cache = False
  964. outputs: Seq2SeqModelOutput = self.model(
  965. input_ids,
  966. attention_mask=attention_mask,
  967. decoder_input_ids=decoder_input_ids,
  968. decoder_attention_mask=decoder_attention_mask,
  969. encoder_outputs=encoder_outputs,
  970. inputs_embeds=inputs_embeds,
  971. decoder_inputs_embeds=decoder_inputs_embeds,
  972. use_cache=use_cache,
  973. **kwargs,
  974. )
  975. sequence_output = outputs[0]
  976. logits = self.qa_outputs(sequence_output)
  977. start_logits, end_logits = logits.split(1, dim=-1)
  978. start_logits = start_logits.squeeze(-1).contiguous()
  979. end_logits = end_logits.squeeze(-1).contiguous()
  980. total_loss = None
  981. if start_positions is not None and end_positions is not None:
  982. # If we are on multi-GPU, split add a dimension
  983. if len(start_positions.size()) > 1:
  984. start_positions = start_positions.squeeze(-1)
  985. if len(end_positions.size()) > 1:
  986. end_positions = end_positions.squeeze(-1)
  987. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  988. ignored_index = start_logits.size(1)
  989. start_positions = start_positions.clamp(0, ignored_index)
  990. end_positions = end_positions.clamp(0, ignored_index)
  991. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  992. start_loss = loss_fct(start_logits, start_positions)
  993. end_loss = loss_fct(end_logits, end_positions)
  994. total_loss = (start_loss + end_loss) / 2
  995. return Seq2SeqQuestionAnsweringModelOutput(
  996. loss=total_loss,
  997. start_logits=start_logits,
  998. end_logits=end_logits,
  999. past_key_values=outputs.past_key_values,
  1000. decoder_hidden_states=outputs.decoder_hidden_states,
  1001. decoder_attentions=outputs.decoder_attentions,
  1002. cross_attentions=outputs.cross_attentions,
  1003. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1004. encoder_hidden_states=outputs.encoder_hidden_states,
  1005. encoder_attentions=outputs.encoder_attentions,
  1006. )
  1007. class BartDecoderWrapper(BartPreTrainedModel):
  1008. """
  1009. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  1010. used in combination with the [`EncoderDecoderModel`] framework.
  1011. """
  1012. def __init__(self, config):
  1013. super().__init__(config)
  1014. self.decoder = BartDecoder(config)
  1015. self.post_init()
  1016. def forward(self, *args, **kwargs):
  1017. return self.decoder(*args, **kwargs)
  1018. @auto_docstring(
  1019. custom_intro="""
  1020. BART decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
  1021. """
  1022. )
  1023. class BartForCausalLM(BartPreTrainedModel, GenerationMixin):
  1024. _tied_weights_keys = {
  1025. "lm_head.weight": "model.decoder.embed_tokens.weight",
  1026. }
  1027. def __init__(self, config):
  1028. config.is_decoder = True
  1029. config.is_encoder_decoder = False
  1030. super().__init__(config)
  1031. self.model = BartDecoderWrapper(config)
  1032. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1033. # Initialize weights and apply final processing
  1034. self.post_init()
  1035. def get_input_embeddings(self):
  1036. return self.model.decoder.embed_tokens
  1037. def set_input_embeddings(self, value):
  1038. self.model.decoder.embed_tokens = value
  1039. @can_return_tuple
  1040. @auto_docstring
  1041. def forward(
  1042. self,
  1043. input_ids: torch.LongTensor | None = None,
  1044. attention_mask: torch.Tensor | None = None,
  1045. encoder_hidden_states: torch.FloatTensor | None = None,
  1046. encoder_attention_mask: torch.FloatTensor | None = None,
  1047. past_key_values: Cache | None = None,
  1048. inputs_embeds: torch.FloatTensor | None = None,
  1049. labels: torch.LongTensor | None = None,
  1050. use_cache: bool | None = None,
  1051. logits_to_keep: int | torch.Tensor = 0,
  1052. **kwargs: Unpack[TransformersKwargs],
  1053. ) -> tuple | CausalLMOutputWithCrossAttentions:
  1054. r"""
  1055. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1056. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1057. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1058. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1059. Example:
  1060. ```python
  1061. >>> from transformers import AutoTokenizer, BartForCausalLM
  1062. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
  1063. >>> model = BartForCausalLM.from_pretrained("facebook/bart-base")
  1064. >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
  1065. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1066. >>> outputs = model(**inputs)
  1067. >>> logits = outputs.logits
  1068. >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
  1069. >>> list(logits.shape) == expected_shape
  1070. True
  1071. ```"""
  1072. outputs: BaseModelOutputWithPastAndCrossAttentions = self.model.decoder(
  1073. input_ids=input_ids,
  1074. attention_mask=attention_mask,
  1075. encoder_hidden_states=encoder_hidden_states,
  1076. encoder_attention_mask=encoder_attention_mask,
  1077. past_key_values=past_key_values,
  1078. inputs_embeds=inputs_embeds,
  1079. use_cache=use_cache,
  1080. **kwargs,
  1081. )
  1082. hidden_states = outputs[0]
  1083. # Only compute necessary logits
  1084. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1085. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1086. loss = None
  1087. if labels is not None:
  1088. labels = labels.to(logits.device)
  1089. loss_fct = CrossEntropyLoss()
  1090. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  1091. return CausalLMOutputWithCrossAttentions(
  1092. loss=loss,
  1093. logits=logits,
  1094. past_key_values=outputs.past_key_values,
  1095. hidden_states=outputs.hidden_states,
  1096. attentions=outputs.attentions,
  1097. cross_attentions=outputs.cross_attentions,
  1098. )
  1099. __all__ = [
  1100. "BartForCausalLM",
  1101. "BartForConditionalGeneration",
  1102. "BartForQuestionAnswering",
  1103. "BartForSequenceClassification",
  1104. "BartModel",
  1105. "BartPreTrainedModel",
  1106. "BartPretrainedModel",
  1107. "PretrainedBartModel",
  1108. ]