modeling_mvp.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630
  1. # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch MVP model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  22. from ...generation import GenerationMixin
  23. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BaseModelOutput,
  27. BaseModelOutputWithPastAndCrossAttentions,
  28. CausalLMOutputWithCrossAttentions,
  29. Seq2SeqLMOutput,
  30. Seq2SeqModelOutput,
  31. Seq2SeqQuestionAnsweringModelOutput,
  32. Seq2SeqSequenceClassifierOutput,
  33. )
  34. from ...modeling_utils import PreTrainedModel
  35. from ...utils import auto_docstring, logging, torch_compilable_check
  36. from .configuration_mvp import MvpConfig
  37. logger = logging.get_logger(__name__)
  38. # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
  39. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  40. """
  41. Shift input ids one token to the right.
  42. """
  43. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  44. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  45. shifted_input_ids[:, 0] = decoder_start_token_id
  46. if pad_token_id is None:
  47. raise ValueError("self.model.config.pad_token_id has to be defined.")
  48. # replace possible -100 values in labels by `pad_token_id`
  49. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  50. return shifted_input_ids
  51. # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->Mvp
  52. class MvpLearnedPositionalEmbedding(nn.Embedding):
  53. """
  54. This module learns positional embeddings up to a fixed maximum size.
  55. """
  56. def __init__(self, num_embeddings: int, embedding_dim: int):
  57. # Mvp is set up so that if padding_idx is specified then offset the embedding ids by 2
  58. # and adjust num_embeddings appropriately. Other models don't have this hack
  59. self.offset = 2
  60. super().__init__(num_embeddings + self.offset, embedding_dim)
  61. def forward(
  62. self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None
  63. ):
  64. """`input_ids' shape is expected to be [bsz x seqlen]."""
  65. if position_ids is None:
  66. bsz, seq_len = input_ids.shape[:2]
  67. position_ids = torch.arange(
  68. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  69. ).expand(bsz, -1)
  70. else:
  71. position_ids = position_ids.unsqueeze(0)
  72. return super().forward(position_ids + self.offset)
  73. class MvpAttention(nn.Module):
  74. """Multi-headed attention from 'Attention Is All You Need' paper"""
  75. def __init__(
  76. self,
  77. embed_dim: int,
  78. num_heads: int,
  79. dropout: float | None = 0.0,
  80. is_decoder: bool | None = False,
  81. bias: bool | None = True,
  82. layer_idx: bool | None = None,
  83. ):
  84. super().__init__()
  85. self.embed_dim = embed_dim
  86. self.num_heads = num_heads
  87. self.dropout = dropout
  88. self.head_dim = embed_dim // num_heads
  89. if (self.head_dim * num_heads) != self.embed_dim:
  90. raise ValueError(
  91. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  92. f" and `num_heads`: {num_heads})."
  93. )
  94. self.scaling = self.head_dim**-0.5
  95. self.is_decoder = is_decoder
  96. self.layer_idx = layer_idx
  97. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  98. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  99. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  100. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  101. def forward(
  102. self,
  103. hidden_states: torch.Tensor,
  104. key_value_states: torch.Tensor | None = None,
  105. past_key_values: Cache | None = None,
  106. attention_mask: torch.Tensor | None = None,
  107. attn_prompt: torch.Tensor | None = None,
  108. output_attentions: bool = False,
  109. **kwargs,
  110. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  111. """Input shape: Batch x Time x Channel"""
  112. # if key_value_states are provided this layer is used as a cross-attention layer
  113. # for the decoder
  114. is_cross_attention = key_value_states is not None
  115. bsz, tgt_len, _ = hidden_states.size()
  116. # get query proj
  117. query_states = self.q_proj(hidden_states) * self.scaling
  118. is_updated = False
  119. if past_key_values is not None:
  120. if isinstance(past_key_values, EncoderDecoderCache):
  121. is_updated = past_key_values.is_updated.get(self.layer_idx)
  122. if is_cross_attention:
  123. # after the first generated id, we can subsequently re-use all key/value_states from cache
  124. curr_past_key_values = past_key_values.cross_attention_cache
  125. else:
  126. curr_past_key_values = past_key_values.self_attention_cache
  127. else:
  128. curr_past_key_values = past_key_values
  129. current_states = key_value_states if is_cross_attention else hidden_states
  130. if is_cross_attention and past_key_values is not None and is_updated:
  131. # reuse k,v, cross_attentions
  132. key_states = curr_past_key_values.layers[self.layer_idx].keys
  133. value_states = curr_past_key_values.layers[self.layer_idx].values
  134. else:
  135. key_states = self.k_proj(current_states)
  136. value_states = self.v_proj(current_states)
  137. key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  138. value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  139. if past_key_values is not None:
  140. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  141. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  142. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  143. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  144. past_key_values.is_updated[self.layer_idx] = True
  145. if attn_prompt is not None:
  146. key_states = torch.cat([attn_prompt[0].expand(bsz, -1, -1, -1), key_states], dim=2)
  147. value_states = torch.cat([attn_prompt[1].expand(bsz, -1, -1, -1), value_states], dim=2)
  148. if attention_mask is not None:
  149. prompt_mask = torch.zeros(bsz, 1, tgt_len, attn_prompt[0].size(1)).to(attention_mask.device)
  150. attention_mask = torch.cat([prompt_mask, attention_mask], dim=(-1))
  151. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  152. query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
  153. query_states = query_states.reshape(*proj_shape)
  154. key_states = key_states.reshape(*proj_shape)
  155. value_states = value_states.reshape(*proj_shape)
  156. src_len = key_states.size(1)
  157. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  158. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  159. raise ValueError(
  160. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  161. f" {attn_weights.size()}"
  162. )
  163. if attention_mask is not None:
  164. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  165. raise ValueError(
  166. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  167. )
  168. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  169. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  170. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  171. if output_attentions:
  172. # this operation is a bit awkward, but it's required to
  173. # make sure that attn_weights keeps its gradient.
  174. # In order to do so, attn_weights have to be reshaped
  175. # twice and have to be reused in the following
  176. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  177. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  178. else:
  179. attn_weights_reshaped = None
  180. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  181. attn_output = torch.bmm(attn_probs, value_states)
  182. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  183. raise ValueError(
  184. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  185. f" {attn_output.size()}"
  186. )
  187. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  188. attn_output = attn_output.transpose(1, 2)
  189. # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
  190. # partitioned across GPUs when using tensor-parallelism.
  191. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  192. attn_output = self.out_proj(attn_output)
  193. return attn_output, attn_weights_reshaped
  194. class MvpEncoderLayer(GradientCheckpointingLayer):
  195. def __init__(self, config: MvpConfig):
  196. super().__init__()
  197. self.embed_dim = config.d_model
  198. self.self_attn = MvpAttention(
  199. embed_dim=self.embed_dim,
  200. num_heads=config.encoder_attention_heads,
  201. dropout=config.attention_dropout,
  202. )
  203. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  204. self.dropout = config.dropout
  205. self.activation_fn = ACT2FN[config.activation_function]
  206. self.activation_dropout = config.activation_dropout
  207. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  208. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  209. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  210. def forward(
  211. self,
  212. hidden_states: torch.FloatTensor,
  213. attention_mask: torch.FloatTensor,
  214. self_attn_prompt: torch.FloatTensor,
  215. output_attentions: bool | None = False,
  216. ) -> tuple[torch.FloatTensor, torch.FloatTensor | None]:
  217. """
  218. Args:
  219. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  220. attention_mask (`torch.FloatTensor`): attention mask of size
  221. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  222. self_attn_prompt (`torch.FloatTensor`): prompt of self attention of shape
  223. `(2, encoder_attention_heads, pro_len, head_dim)`.
  224. output_attentions (`bool`, *optional*):
  225. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  226. returned tensors for more detail.
  227. """
  228. residual = hidden_states
  229. hidden_states, attn_weights = self.self_attn(
  230. hidden_states=hidden_states,
  231. attention_mask=attention_mask,
  232. attn_prompt=self_attn_prompt,
  233. output_attentions=output_attentions,
  234. )
  235. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  236. hidden_states = residual + hidden_states
  237. hidden_states = self.self_attn_layer_norm(hidden_states)
  238. residual = hidden_states
  239. hidden_states = self.activation_fn(self.fc1(hidden_states))
  240. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  241. hidden_states = self.fc2(hidden_states)
  242. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  243. hidden_states = residual + hidden_states
  244. hidden_states = self.final_layer_norm(hidden_states)
  245. if hidden_states.dtype == torch.float16 and not torch.isfinite(hidden_states).all():
  246. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  247. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  248. return hidden_states, attn_weights
  249. class MvpDecoderLayer(GradientCheckpointingLayer):
  250. def __init__(self, config: MvpConfig, layer_idx=None):
  251. super().__init__()
  252. self.embed_dim = config.d_model
  253. self.self_attn = MvpAttention(
  254. embed_dim=self.embed_dim,
  255. num_heads=config.decoder_attention_heads,
  256. dropout=config.attention_dropout,
  257. is_decoder=True,
  258. layer_idx=layer_idx,
  259. )
  260. self.dropout = config.dropout
  261. self.activation_fn = ACT2FN[config.activation_function]
  262. self.activation_dropout = config.activation_dropout
  263. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  264. self.encoder_attn = MvpAttention(
  265. self.embed_dim,
  266. config.decoder_attention_heads,
  267. dropout=config.attention_dropout,
  268. is_decoder=True,
  269. layer_idx=layer_idx,
  270. )
  271. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  272. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  273. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  274. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  275. def forward(
  276. self,
  277. hidden_states: torch.Tensor,
  278. attention_mask: torch.Tensor | None = None,
  279. encoder_hidden_states: torch.Tensor | None = None,
  280. encoder_attention_mask: torch.Tensor | None = None,
  281. self_attn_prompt: torch.Tensor | None = None,
  282. cross_attn_prompt: torch.Tensor | None = None,
  283. past_key_values: Cache | None = None,
  284. output_attentions: bool | None = False,
  285. use_cache: bool | None = True,
  286. **kwargs,
  287. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  288. """
  289. Args:
  290. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  291. attention_mask (`torch.FloatTensor`): attention mask of size
  292. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  293. encoder_hidden_states (`torch.FloatTensor`):
  294. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  295. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  296. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  297. self_attn_prompt (`torch.FloatTensor`): prompt of self attention of shape
  298. `(2, decoder_attention_heads, pro_len, head_dim)`.
  299. cross_attn_prompt (`torch.FloatTensor`): prompt of cross attention of shape
  300. `(2, decoder_attention_heads, pro_len, head_dim)`.
  301. past_key_values (`Cache`): cached past key and value projection states
  302. output_attentions (`bool`, *optional*):
  303. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  304. returned tensors for more detail.
  305. """
  306. residual = hidden_states
  307. # Self Attention
  308. hidden_states, self_attn_weights = self.self_attn(
  309. hidden_states=hidden_states,
  310. past_key_values=past_key_values,
  311. attention_mask=attention_mask,
  312. attn_prompt=self_attn_prompt,
  313. output_attentions=output_attentions,
  314. )
  315. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  316. hidden_states = residual + hidden_states
  317. hidden_states = self.self_attn_layer_norm(hidden_states)
  318. # Cross-Attention Block
  319. cross_attn_weights = None
  320. if encoder_hidden_states is not None:
  321. residual = hidden_states
  322. hidden_states, cross_attn_weights = self.encoder_attn(
  323. hidden_states=hidden_states,
  324. key_value_states=encoder_hidden_states,
  325. attention_mask=encoder_attention_mask,
  326. attn_prompt=cross_attn_prompt,
  327. past_key_values=past_key_values,
  328. output_attentions=output_attentions,
  329. )
  330. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  331. hidden_states = residual + hidden_states
  332. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  333. # Fully Connected
  334. residual = hidden_states
  335. hidden_states = self.activation_fn(self.fc1(hidden_states))
  336. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  337. hidden_states = self.fc2(hidden_states)
  338. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  339. hidden_states = residual + hidden_states
  340. hidden_states = self.final_layer_norm(hidden_states)
  341. outputs = (hidden_states,)
  342. if output_attentions:
  343. outputs += (self_attn_weights, cross_attn_weights)
  344. return outputs
  345. # Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->MVP
  346. class MvpClassificationHead(nn.Module):
  347. """Head for sentence-level classification tasks."""
  348. def __init__(
  349. self,
  350. input_dim: int,
  351. inner_dim: int,
  352. num_classes: int,
  353. pooler_dropout: float,
  354. ):
  355. super().__init__()
  356. self.dense = nn.Linear(input_dim, inner_dim)
  357. self.dropout = nn.Dropout(p=pooler_dropout)
  358. self.out_proj = nn.Linear(inner_dim, num_classes)
  359. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  360. hidden_states = self.dropout(hidden_states)
  361. hidden_states = self.dense(hidden_states)
  362. hidden_states = torch.tanh(hidden_states)
  363. hidden_states = self.dropout(hidden_states)
  364. hidden_states = self.out_proj(hidden_states)
  365. return hidden_states
  366. class MvpPrompt(nn.Module):
  367. """Layer-wise prompt for encoder or decoder."""
  368. def __init__(self, config, num_layers, num_heads):
  369. super().__init__()
  370. self.prompt_length = config.prompt_length
  371. self.num_layers = num_layers
  372. self.num_heads = num_heads
  373. self.head_dim = config.d_model // num_heads
  374. self.dropout = nn.Dropout(p=config.dropout)
  375. self.prompt_embedding = nn.Embedding(config.prompt_length, config.d_model)
  376. self.prompt_trans = nn.Sequential(
  377. nn.Linear(config.d_model, config.prompt_mid_dim),
  378. nn.GELU(),
  379. nn.Linear(config.prompt_mid_dim, num_layers * 2 * config.d_model),
  380. )
  381. def forward(self, prompt_ids: torch.Tensor) -> tuple[torch.Tensor]:
  382. prompt = self.prompt_trans(self.prompt_embedding(prompt_ids))
  383. prompt = prompt.view(self.prompt_length, self.num_layers * 2, self.num_heads, self.head_dim)
  384. prompt = self.dropout(prompt)
  385. prompt = prompt.permute([1, 2, 0, 3]).split(2)
  386. return prompt
  387. @auto_docstring
  388. class MvpPreTrainedModel(PreTrainedModel):
  389. config: MvpConfig
  390. base_model_prefix = "model"
  391. supports_gradient_checkpointing = True
  392. def _init_weights(self, module):
  393. super()._init_weights(module)
  394. if isinstance(module, MvpForConditionalGeneration):
  395. init.zeros_(module.final_logits_bias)
  396. @property
  397. def dummy_inputs(self):
  398. pad_token = self.config.pad_token_id
  399. input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
  400. dummy_inputs = {
  401. "attention_mask": input_ids.ne(pad_token),
  402. "input_ids": input_ids,
  403. }
  404. return dummy_inputs
  405. class MvpEncoder(MvpPreTrainedModel):
  406. """
  407. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  408. [`MvpEncoderLayer`].
  409. Args:
  410. config: MvpConfig
  411. embed_tokens (nn.Embedding): output embedding
  412. use_prompt (bool): whether to use prompt
  413. """
  414. def __init__(self, config: MvpConfig, embed_tokens: nn.Embedding | None = None, use_prompt: bool | None = False):
  415. super().__init__(config)
  416. self.dropout = config.dropout
  417. self.layerdrop = config.encoder_layerdrop
  418. embed_dim = config.d_model
  419. self.padding_idx = config.pad_token_id
  420. self.max_source_positions = config.max_position_embeddings
  421. self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  422. self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
  423. self.embed_positions = MvpLearnedPositionalEmbedding(
  424. config.max_position_embeddings,
  425. embed_dim,
  426. )
  427. self.layers = nn.ModuleList([MvpEncoderLayer(config) for _ in range(config.encoder_layers)])
  428. self.layernorm_embedding = nn.LayerNorm(embed_dim)
  429. self.use_prompt = use_prompt
  430. if use_prompt:
  431. self.prompt_length = config.prompt_length
  432. self.self_attn_prompt = MvpPrompt(
  433. config,
  434. config.encoder_layers,
  435. config.encoder_attention_heads,
  436. )
  437. self.gradient_checkpointing = False
  438. # Initialize weights and apply final processing
  439. self.post_init()
  440. def forward(
  441. self,
  442. input_ids: torch.LongTensor | None = None,
  443. attention_mask: torch.Tensor | None = None,
  444. inputs_embeds: torch.FloatTensor | None = None,
  445. output_attentions: bool | None = None,
  446. output_hidden_states: bool | None = None,
  447. return_dict: bool | None = None,
  448. **kwargs,
  449. ) -> tuple | BaseModelOutput:
  450. r"""
  451. Args:
  452. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  453. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  454. provide it.
  455. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  456. [`PreTrainedTokenizer.__call__`] for details.
  457. [What are input IDs?](../glossary#input-ids)
  458. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  459. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  460. - 1 for tokens that are **not masked**,
  461. - 0 for tokens that are **masked**.
  462. [What are attention masks?](../glossary#attention-mask)
  463. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  464. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  465. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  466. than the model's internal embedding lookup matrix.
  467. output_attentions (`bool`, *optional*):
  468. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  469. returned tensors for more detail.
  470. output_hidden_states (`bool`, *optional*):
  471. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  472. for more detail.
  473. return_dict (`bool`, *optional*):
  474. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  475. """
  476. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  477. output_hidden_states = (
  478. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  479. )
  480. return_dict = return_dict if return_dict is not None else self.config.return_dict
  481. # retrieve input_ids and inputs_embeds
  482. if input_ids is not None and inputs_embeds is not None:
  483. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  484. elif input_ids is not None:
  485. input = input_ids
  486. input_shape = input.shape
  487. input_ids = input_ids.view(-1, input_shape[-1])
  488. elif inputs_embeds is not None:
  489. input_shape = inputs_embeds.size()[:-1]
  490. input = inputs_embeds[:, :, -1]
  491. else:
  492. raise ValueError("You have to specify either input_ids or inputs_embeds")
  493. if inputs_embeds is None:
  494. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  495. embed_pos = self.embed_positions(input)
  496. hidden_states = inputs_embeds + embed_pos
  497. hidden_states = self.layernorm_embedding(hidden_states)
  498. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  499. # layer-wise prompt
  500. if self.use_prompt:
  501. prompt_ids = torch.arange(self.prompt_length).to(self.device)
  502. self_attn_prompt = self.self_attn_prompt(prompt_ids)
  503. # expand attention_mask
  504. if attention_mask is not None:
  505. attention_mask = create_bidirectional_mask(
  506. config=self.config,
  507. inputs_embeds=hidden_states,
  508. attention_mask=attention_mask,
  509. )
  510. encoder_states = () if output_hidden_states else None
  511. all_attentions = () if output_attentions else None
  512. for idx, encoder_layer in enumerate(self.layers):
  513. if output_hidden_states:
  514. encoder_states = encoder_states + (hidden_states,)
  515. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  516. to_drop = False
  517. if self.training:
  518. dropout_probability = torch.rand([])
  519. if dropout_probability < self.layerdrop: # skip the layer
  520. to_drop = True
  521. if to_drop:
  522. layer_outputs = (None, None)
  523. else:
  524. layer_outputs = encoder_layer(
  525. hidden_states,
  526. attention_mask,
  527. self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None),
  528. output_attentions=output_attentions,
  529. )
  530. hidden_states = layer_outputs[0]
  531. if output_attentions:
  532. all_attentions = all_attentions + (layer_outputs[1],)
  533. if output_hidden_states:
  534. encoder_states = encoder_states + (hidden_states,)
  535. if not return_dict:
  536. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  537. return BaseModelOutput(
  538. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  539. )
  540. class MvpDecoder(MvpPreTrainedModel):
  541. """
  542. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MvpDecoderLayer`]
  543. Args:
  544. config: MvpConfig
  545. embed_tokens (nn.Embedding): output embedding
  546. use_prompt (bool): whether to use prompt
  547. """
  548. def __init__(self, config: MvpConfig, use_prompt: bool | None = False):
  549. super().__init__(config)
  550. self.dropout = config.dropout
  551. self.layerdrop = config.decoder_layerdrop
  552. self.padding_idx = config.pad_token_id
  553. self.max_target_positions = config.max_position_embeddings
  554. self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  555. self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
  556. self.embed_positions = MvpLearnedPositionalEmbedding(
  557. config.max_position_embeddings,
  558. config.d_model,
  559. )
  560. self.layers = nn.ModuleList([MvpDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  561. self.layernorm_embedding = nn.LayerNorm(config.d_model)
  562. self.use_prompt = use_prompt
  563. if use_prompt:
  564. self.prompt_length = config.prompt_length
  565. self.self_attn_prompt = MvpPrompt(
  566. config,
  567. config.decoder_layers,
  568. config.decoder_attention_heads,
  569. )
  570. self.cross_attn_prompt = MvpPrompt(
  571. config,
  572. config.decoder_layers,
  573. config.decoder_attention_heads,
  574. )
  575. self.gradient_checkpointing = False
  576. # Initialize weights and apply final processing
  577. self.post_init()
  578. def forward(
  579. self,
  580. input_ids: torch.LongTensor | None = None,
  581. attention_mask: torch.Tensor | None = None,
  582. encoder_hidden_states: torch.FloatTensor | None = None,
  583. encoder_attention_mask: torch.LongTensor | None = None,
  584. past_key_values: Cache | None = None,
  585. inputs_embeds: torch.FloatTensor | None = None,
  586. use_cache: bool | None = None,
  587. output_attentions: bool | None = None,
  588. output_hidden_states: bool | None = None,
  589. return_dict: bool | None = None,
  590. **kwargs,
  591. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  592. r"""
  593. Args:
  594. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  595. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  596. provide it.
  597. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  598. [`PreTrainedTokenizer.__call__`] for details.
  599. [What are input IDs?](../glossary#input-ids)
  600. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  601. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  602. - 1 for tokens that are **not masked**,
  603. - 0 for tokens that are **masked**.
  604. [What are attention masks?](../glossary#attention-mask)
  605. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  606. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  607. of the decoder.
  608. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  609. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  610. selected in `[0, 1]`:
  611. - 1 for tokens that are **not masked**,
  612. - 0 for tokens that are **masked**.
  613. [What are attention masks?](../glossary#attention-mask)
  614. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  615. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  616. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  617. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  618. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  619. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  620. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  621. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  622. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  623. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  624. than the model's internal embedding lookup matrix.
  625. output_attentions (`bool`, *optional*):
  626. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  627. returned tensors for more detail.
  628. output_hidden_states (`bool`, *optional*):
  629. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  630. for more detail.
  631. return_dict (`bool`, *optional*):
  632. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  633. """
  634. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  635. output_hidden_states = (
  636. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  637. )
  638. use_cache = use_cache if use_cache is not None else self.config.use_cache
  639. return_dict = return_dict if return_dict is not None else self.config.return_dict
  640. # retrieve input_ids and inputs_embeds
  641. if input_ids is not None and inputs_embeds is not None:
  642. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  643. elif input_ids is not None:
  644. input = input_ids
  645. input_shape = input_ids.shape
  646. input_ids = input_ids.view(-1, input_shape[-1])
  647. elif inputs_embeds is not None:
  648. input_shape = inputs_embeds.size()[:-1]
  649. input = inputs_embeds[:, :, -1]
  650. else:
  651. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  652. if inputs_embeds is None:
  653. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  654. if self.gradient_checkpointing and self.training:
  655. if use_cache:
  656. logger.warning_once(
  657. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  658. )
  659. use_cache = False
  660. if use_cache and past_key_values is None:
  661. past_key_values = (
  662. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  663. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  664. else DynamicCache(config=self.config)
  665. )
  666. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  667. attention_mask = create_causal_mask(
  668. config=self.config,
  669. inputs_embeds=inputs_embeds,
  670. attention_mask=attention_mask,
  671. past_key_values=past_key_values,
  672. )
  673. # expand encoder attention mask
  674. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  675. encoder_attention_mask = create_bidirectional_mask(
  676. config=self.config,
  677. inputs_embeds=inputs_embeds,
  678. attention_mask=encoder_attention_mask,
  679. encoder_hidden_states=encoder_hidden_states,
  680. )
  681. # embed positions
  682. positions = self.embed_positions(input, past_key_values_length)
  683. hidden_states = inputs_embeds + positions
  684. hidden_states = self.layernorm_embedding(hidden_states)
  685. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  686. # layer-wise prompt
  687. if self.use_prompt:
  688. prompt_ids = torch.arange(self.prompt_length).to(self.device)
  689. self_attn_prompt = self.self_attn_prompt(prompt_ids)
  690. cross_attn_prompt = self.cross_attn_prompt(prompt_ids)
  691. # decoder layers
  692. all_hidden_states = () if output_hidden_states else None
  693. all_self_attns = () if output_attentions else None
  694. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  695. for idx, decoder_layer in enumerate(self.layers):
  696. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  697. if output_hidden_states:
  698. all_hidden_states += (hidden_states,)
  699. if self.training:
  700. dropout_probability = torch.rand([])
  701. if dropout_probability < self.layerdrop:
  702. continue
  703. layer_outputs = decoder_layer(
  704. hidden_states,
  705. attention_mask,
  706. encoder_hidden_states, # as positional argument for gradient checkpointing
  707. encoder_attention_mask=encoder_attention_mask,
  708. self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None),
  709. cross_attn_prompt=(cross_attn_prompt[idx] if self.use_prompt else None),
  710. past_key_values=past_key_values,
  711. output_attentions=output_attentions,
  712. use_cache=use_cache,
  713. )
  714. hidden_states = layer_outputs[0]
  715. if output_attentions:
  716. all_self_attns += (layer_outputs[1],)
  717. if encoder_hidden_states is not None:
  718. all_cross_attentions += (layer_outputs[2],)
  719. # add hidden states from the last decoder layer
  720. if output_hidden_states:
  721. all_hidden_states += (hidden_states,)
  722. if not return_dict:
  723. return tuple(
  724. v
  725. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
  726. if v is not None
  727. )
  728. return BaseModelOutputWithPastAndCrossAttentions(
  729. last_hidden_state=hidden_states,
  730. past_key_values=past_key_values,
  731. hidden_states=all_hidden_states,
  732. attentions=all_self_attns,
  733. cross_attentions=all_cross_attentions,
  734. )
  735. @auto_docstring
  736. class MvpModel(MvpPreTrainedModel):
  737. _keys_to_ignore_on_load_unexpected = ["final_logits_bias"]
  738. _tied_weights_keys = {
  739. "encoder.embed_tokens.weight": "shared.weight",
  740. "decoder.embed_tokens.weight": "shared.weight",
  741. }
  742. def __init__(self, config: MvpConfig):
  743. super().__init__(config)
  744. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  745. self.use_prompt = config.use_prompt
  746. self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
  747. self.encoder = MvpEncoder(config, config.use_prompt)
  748. self.decoder = MvpDecoder(config, config.use_prompt)
  749. # Initialize weights and apply final processing
  750. self.post_init()
  751. def get_input_embeddings(self):
  752. return self.shared
  753. def set_input_embeddings(self, value):
  754. self.shared = value
  755. self.encoder.embed_tokens = self.shared
  756. self.decoder.embed_tokens = self.shared
  757. def set_lightweight_tuning(self):
  758. assert self.use_prompt, "If you want to use lightweight tuning, make sure that `use_prompt=True`."
  759. self.requires_grad_(False)
  760. self.encoder.self_attn_prompt.requires_grad_(True)
  761. self.decoder.self_attn_prompt.requires_grad_(True)
  762. self.decoder.cross_attn_prompt.requires_grad_(True)
  763. @auto_docstring
  764. def forward(
  765. self,
  766. input_ids: torch.LongTensor | None = None,
  767. attention_mask: torch.Tensor | None = None,
  768. decoder_input_ids: torch.LongTensor | None = None,
  769. decoder_attention_mask: torch.LongTensor | None = None,
  770. encoder_outputs: list[torch.FloatTensor] | None = None,
  771. past_key_values: Cache | None = None,
  772. inputs_embeds: torch.FloatTensor | None = None,
  773. decoder_inputs_embeds: torch.FloatTensor | None = None,
  774. use_cache: bool | None = None,
  775. output_attentions: bool | None = None,
  776. output_hidden_states: bool | None = None,
  777. return_dict: bool | None = None,
  778. **kwargs,
  779. ) -> tuple | Seq2SeqModelOutput:
  780. r"""
  781. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  782. Indices of decoder input sequence tokens in the vocabulary.
  783. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  784. [`PreTrainedTokenizer.__call__`] for details.
  785. [What are decoder input IDs?](../glossary#decoder-input-ids)
  786. Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  787. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  788. For translation and summarization training, `decoder_input_ids` should be provided. If no
  789. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  790. for denoising pre-training following the paper.
  791. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  792. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  793. be used by default.
  794. If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
  795. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  796. information on the default strategy.
  797. """
  798. # different to other models, Mvp automatically creates decoder_input_ids from
  799. # input_ids if no decoder_input_ids are provided
  800. if decoder_input_ids is None and decoder_inputs_embeds is None:
  801. if input_ids is None:
  802. raise ValueError(
  803. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  804. "passed, `input_ids` cannot be `None`. Please pass either "
  805. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  806. )
  807. decoder_input_ids = shift_tokens_right(
  808. input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
  809. )
  810. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  811. output_hidden_states = (
  812. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  813. )
  814. use_cache = use_cache if use_cache is not None else self.config.use_cache
  815. return_dict = return_dict if return_dict is not None else self.config.return_dict
  816. if encoder_outputs is None:
  817. encoder_outputs = self.encoder(
  818. input_ids=input_ids,
  819. attention_mask=attention_mask,
  820. inputs_embeds=inputs_embeds,
  821. output_attentions=output_attentions,
  822. output_hidden_states=output_hidden_states,
  823. return_dict=return_dict,
  824. )
  825. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
  826. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  827. encoder_outputs = BaseModelOutput(
  828. last_hidden_state=encoder_outputs[0],
  829. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  830. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  831. )
  832. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  833. decoder_outputs = self.decoder(
  834. input_ids=decoder_input_ids,
  835. attention_mask=decoder_attention_mask,
  836. encoder_hidden_states=encoder_outputs[0],
  837. encoder_attention_mask=attention_mask,
  838. past_key_values=past_key_values,
  839. inputs_embeds=decoder_inputs_embeds,
  840. use_cache=use_cache,
  841. output_attentions=output_attentions,
  842. output_hidden_states=output_hidden_states,
  843. return_dict=return_dict,
  844. )
  845. if not return_dict:
  846. return decoder_outputs + encoder_outputs
  847. return Seq2SeqModelOutput(
  848. last_hidden_state=decoder_outputs.last_hidden_state,
  849. past_key_values=decoder_outputs.past_key_values,
  850. decoder_hidden_states=decoder_outputs.hidden_states,
  851. decoder_attentions=decoder_outputs.attentions,
  852. cross_attentions=decoder_outputs.cross_attentions,
  853. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  854. encoder_hidden_states=encoder_outputs.hidden_states,
  855. encoder_attentions=encoder_outputs.attentions,
  856. )
  857. @auto_docstring(
  858. custom_intro="""
  859. The MVP Model with a language modeling head. Can be used for various text generation tasks.
  860. """
  861. )
  862. class MvpForConditionalGeneration(MvpPreTrainedModel, GenerationMixin):
  863. _tied_weights_keys = {
  864. "lm_head.weight": "model.shared.weight",
  865. }
  866. def __init__(self, config: MvpConfig):
  867. super().__init__(config)
  868. self.model = MvpModel(config)
  869. self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
  870. self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
  871. # Initialize weights and apply final processing
  872. self.post_init()
  873. def resize_token_embeddings(
  874. self, new_num_tokens: int, pad_to_multiple_of: int | None = None, mean_resizing: bool = True
  875. ) -> nn.Embedding:
  876. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  877. self._resize_final_logits_bias(new_num_tokens)
  878. return new_embeddings
  879. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  880. old_num_tokens = self.final_logits_bias.shape[-1]
  881. if new_num_tokens <= old_num_tokens:
  882. new_bias = self.final_logits_bias[:, :new_num_tokens]
  883. else:
  884. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  885. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  886. self.register_buffer("final_logits_bias", new_bias)
  887. def set_lightweight_tuning(self):
  888. self.model.set_lightweight_tuning()
  889. self.lm_head.requires_grad_(False)
  890. @auto_docstring
  891. def forward(
  892. self,
  893. input_ids: torch.LongTensor | None = None,
  894. attention_mask: torch.Tensor | None = None,
  895. decoder_input_ids: torch.LongTensor | None = None,
  896. decoder_attention_mask: torch.LongTensor | None = None,
  897. encoder_outputs: list[torch.FloatTensor] | None = None,
  898. past_key_values: Cache | None = None,
  899. inputs_embeds: torch.FloatTensor | None = None,
  900. decoder_inputs_embeds: torch.FloatTensor | None = None,
  901. labels: torch.LongTensor | None = None,
  902. use_cache: bool | None = None,
  903. output_attentions: bool | None = None,
  904. output_hidden_states: bool | None = None,
  905. return_dict: bool | None = None,
  906. **kwargs,
  907. ) -> tuple | Seq2SeqLMOutput:
  908. r"""
  909. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  910. Indices of decoder input sequence tokens in the vocabulary.
  911. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  912. [`PreTrainedTokenizer.__call__`] for details.
  913. [What are decoder input IDs?](../glossary#decoder-input-ids)
  914. Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  915. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  916. For translation and summarization training, `decoder_input_ids` should be provided. If no
  917. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  918. for denoising pre-training following the paper.
  919. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  920. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  921. be used by default.
  922. If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
  923. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  924. information on the default strategy.
  925. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  926. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  927. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  928. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  929. Example of summarization:
  930. Fine-tuning a model
  931. ```python
  932. >>> import torch
  933. >>> from transformers import AutoTokenizer, MvpForConditionalGeneration
  934. >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
  935. >>> model = MvpForConditionalGeneration.from_pretrained("RUCAIBox/mvp")
  936. >>> inputs = tokenizer(
  937. ... "Summarize: You may want to stick it to your boss and leave your job, but don't do it if these are your reasons.",
  938. ... return_tensors="pt",
  939. ... )
  940. >>> labels = tokenizer("Bad Reasons To Quit Your Job", return_tensors="pt")["input_ids"]
  941. >>> loss = model(**inputs, labels=labels).loss
  942. >>> loss.backward()
  943. ```
  944. Inference after the model fine-tuned
  945. ```python
  946. >>> with torch.no_grad():
  947. ... generated_ids = model.generate(**inputs)
  948. >>> generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
  949. ```
  950. """
  951. return_dict = return_dict if return_dict is not None else self.config.return_dict
  952. if labels is not None:
  953. if use_cache:
  954. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  955. use_cache = False
  956. if decoder_input_ids is None and decoder_inputs_embeds is None:
  957. decoder_input_ids = shift_tokens_right(
  958. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  959. )
  960. outputs = self.model(
  961. input_ids,
  962. attention_mask=attention_mask,
  963. decoder_input_ids=decoder_input_ids,
  964. encoder_outputs=encoder_outputs,
  965. decoder_attention_mask=decoder_attention_mask,
  966. past_key_values=past_key_values,
  967. inputs_embeds=inputs_embeds,
  968. decoder_inputs_embeds=decoder_inputs_embeds,
  969. use_cache=use_cache,
  970. output_attentions=output_attentions,
  971. output_hidden_states=output_hidden_states,
  972. return_dict=return_dict,
  973. )
  974. lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
  975. masked_lm_loss = None
  976. if labels is not None:
  977. loss_fct = CrossEntropyLoss()
  978. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  979. if not return_dict:
  980. output = (lm_logits,) + outputs[1:]
  981. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  982. return Seq2SeqLMOutput(
  983. loss=masked_lm_loss,
  984. logits=lm_logits,
  985. past_key_values=outputs.past_key_values,
  986. decoder_hidden_states=outputs.decoder_hidden_states,
  987. decoder_attentions=outputs.decoder_attentions,
  988. cross_attentions=outputs.cross_attentions,
  989. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  990. encoder_hidden_states=outputs.encoder_hidden_states,
  991. encoder_attentions=outputs.encoder_attentions,
  992. )
  993. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  994. return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
  995. @auto_docstring(
  996. custom_intro="""
  997. Mvp model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
  998. tasks.
  999. """
  1000. )
  1001. class MvpForSequenceClassification(MvpPreTrainedModel):
  1002. def __init__(self, config: MvpConfig, **kwargs):
  1003. super().__init__(config, **kwargs)
  1004. self.model = MvpModel(config)
  1005. self.classification_head = MvpClassificationHead(
  1006. config.d_model,
  1007. config.d_model,
  1008. config.num_labels,
  1009. config.classifier_dropout,
  1010. )
  1011. # Initialize weights and apply final processing
  1012. self.post_init()
  1013. def set_lightweight_tuning(self):
  1014. self.model.set_lightweight_tuning()
  1015. self.classification_head.requires_grad_(False)
  1016. @auto_docstring
  1017. def forward(
  1018. self,
  1019. input_ids: torch.LongTensor | None = None,
  1020. attention_mask: torch.Tensor | None = None,
  1021. decoder_input_ids: torch.LongTensor | None = None,
  1022. decoder_attention_mask: torch.LongTensor | None = None,
  1023. encoder_outputs: list[torch.FloatTensor] | None = None,
  1024. inputs_embeds: torch.FloatTensor | None = None,
  1025. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1026. labels: torch.LongTensor | None = None,
  1027. use_cache: bool | None = None,
  1028. output_attentions: bool | None = None,
  1029. output_hidden_states: bool | None = None,
  1030. return_dict: bool | None = None,
  1031. **kwargs,
  1032. ) -> tuple | Seq2SeqSequenceClassifierOutput:
  1033. r"""
  1034. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1035. Indices of decoder input sequence tokens in the vocabulary.
  1036. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1037. [`PreTrainedTokenizer.__call__`] for details.
  1038. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1039. Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1040. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1041. For translation and summarization training, `decoder_input_ids` should be provided. If no
  1042. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  1043. for denoising pre-training following the paper.
  1044. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1045. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1046. be used by default.
  1047. If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
  1048. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  1049. information on the default strategy.
  1050. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1051. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1052. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1053. Example of single-label classification:
  1054. Fine-tuning a model on `num_labels` classes
  1055. ```python
  1056. >>> import torch
  1057. >>> from transformers import AutoTokenizer, MvpForSequenceClassification
  1058. >>> num_labels = 2 # for example, this is a binary classification task
  1059. >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
  1060. >>> model = MvpForSequenceClassification.from_pretrained("RUCAIBox/mvp", num_labels=num_labels)
  1061. >>> inputs = tokenizer("Classify: Hello, my dog is cute", return_tensors="pt")
  1062. >>> labels = torch.tensor(1) # the real label for inputs
  1063. >>> loss = model(**inputs, labels=labels).loss
  1064. >>> loss.backward()
  1065. ```
  1066. Inference after the model fine-tuned
  1067. ```python
  1068. >>> with torch.no_grad():
  1069. ... logits = model(**inputs).logits
  1070. >>> predicted_class_id = logits.argmax()
  1071. ```
  1072. """
  1073. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1074. if labels is not None:
  1075. use_cache = False
  1076. if input_ids is None and inputs_embeds is not None:
  1077. raise NotImplementedError(
  1078. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  1079. )
  1080. outputs = self.model(
  1081. input_ids,
  1082. attention_mask=attention_mask,
  1083. decoder_input_ids=decoder_input_ids,
  1084. decoder_attention_mask=decoder_attention_mask,
  1085. encoder_outputs=encoder_outputs,
  1086. inputs_embeds=inputs_embeds,
  1087. decoder_inputs_embeds=decoder_inputs_embeds,
  1088. use_cache=use_cache,
  1089. output_attentions=output_attentions,
  1090. output_hidden_states=output_hidden_states,
  1091. return_dict=return_dict,
  1092. )
  1093. hidden_states = outputs[0] # last hidden state
  1094. eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
  1095. torch_compilable_check(
  1096. torch.unique_consecutive(eos_mask.sum(1)).numel() == 1,
  1097. "All examples must have the same number of <eos> tokens.",
  1098. )
  1099. sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
  1100. :, -1, :
  1101. ]
  1102. logits = self.classification_head(sentence_representation)
  1103. loss = None
  1104. if labels is not None:
  1105. if self.config.problem_type is None:
  1106. if self.config.num_labels == 1:
  1107. self.config.problem_type = "regression"
  1108. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1109. self.config.problem_type = "single_label_classification"
  1110. else:
  1111. self.config.problem_type = "multi_label_classification"
  1112. if self.config.problem_type == "regression":
  1113. loss_fct = MSELoss()
  1114. if self.config.num_labels == 1:
  1115. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1116. else:
  1117. loss = loss_fct(logits, labels)
  1118. elif self.config.problem_type == "single_label_classification":
  1119. loss_fct = CrossEntropyLoss()
  1120. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1121. elif self.config.problem_type == "multi_label_classification":
  1122. loss_fct = BCEWithLogitsLoss()
  1123. loss = loss_fct(logits, labels)
  1124. if not return_dict:
  1125. output = (logits,) + outputs[1:]
  1126. return ((loss,) + output) if loss is not None else output
  1127. return Seq2SeqSequenceClassifierOutput(
  1128. loss=loss,
  1129. logits=logits,
  1130. past_key_values=outputs.past_key_values,
  1131. decoder_hidden_states=outputs.decoder_hidden_states,
  1132. decoder_attentions=outputs.decoder_attentions,
  1133. cross_attentions=outputs.cross_attentions,
  1134. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1135. encoder_hidden_states=outputs.encoder_hidden_states,
  1136. encoder_attentions=outputs.encoder_attentions,
  1137. )
  1138. @auto_docstring
  1139. class MvpForQuestionAnswering(MvpPreTrainedModel):
  1140. def __init__(self, config):
  1141. super().__init__(config)
  1142. config.num_labels = 2
  1143. self.num_labels = config.num_labels
  1144. self.model = MvpModel(config)
  1145. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1146. # Initialize weights and apply final processing
  1147. self.post_init()
  1148. def set_lightweight_tuning(self):
  1149. self.model.set_lightweight_tuning()
  1150. self.qa_outputs.requires_grad_(False)
  1151. @auto_docstring
  1152. def forward(
  1153. self,
  1154. input_ids: torch.Tensor | None = None,
  1155. attention_mask: torch.Tensor | None = None,
  1156. decoder_input_ids: torch.LongTensor | None = None,
  1157. decoder_attention_mask: torch.LongTensor | None = None,
  1158. encoder_outputs: list[torch.FloatTensor] | None = None,
  1159. start_positions: torch.LongTensor | None = None,
  1160. end_positions: torch.LongTensor | None = None,
  1161. inputs_embeds: torch.FloatTensor | None = None,
  1162. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1163. use_cache: bool | None = None,
  1164. output_attentions: bool | None = None,
  1165. output_hidden_states: bool | None = None,
  1166. return_dict: bool | None = None,
  1167. **kwargs,
  1168. ) -> tuple | Seq2SeqQuestionAnsweringModelOutput:
  1169. r"""
  1170. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1171. Indices of decoder input sequence tokens in the vocabulary.
  1172. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1173. [`PreTrainedTokenizer.__call__`] for details.
  1174. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1175. Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1176. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1177. For translation and summarization training, `decoder_input_ids` should be provided. If no
  1178. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  1179. for denoising pre-training following the paper.
  1180. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1181. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1182. be used by default.
  1183. If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
  1184. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  1185. information on the default strategy.
  1186. Example:
  1187. Fine-tuning a model for extrative question answering, and our model also supports generative question answering
  1188. using `BartForConditionalGeneration`
  1189. ```python
  1190. >>> import torch
  1191. >>> from transformers import AutoTokenizer, MvpForQuestionAnswering
  1192. >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
  1193. >>> model = MvpForQuestionAnswering.from_pretrained("RUCAIBox/mvp")
  1194. >>> inputs = tokenizer(
  1195. ... "Answer the following question: Who was Jim Henson? [SEP] Jim Henson was a nice puppet",
  1196. ... return_tensors="pt",
  1197. ... )
  1198. >>> target_start_index = torch.tensor([18])
  1199. >>> target_end_index = torch.tensor([19])
  1200. >>> loss = model(**inputs, start_positions=target_start_index, end_positions=target_end_index).loss
  1201. >>> loss.backward()
  1202. ```
  1203. Inference after the model fine-tuned
  1204. ```python
  1205. >>> with torch.no_grad():
  1206. ... outputs = model(**inputs)
  1207. >>> answer_start_index = outputs.start_logits.argmax()
  1208. >>> answer_end_index = outputs.end_logits.argmax()
  1209. >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
  1210. >>> predict_answer = tokenizer.decode(predict_answer_tokens)
  1211. ```
  1212. """
  1213. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1214. if start_positions is not None and end_positions is not None:
  1215. use_cache = False
  1216. outputs = self.model(
  1217. input_ids,
  1218. attention_mask=attention_mask,
  1219. decoder_input_ids=decoder_input_ids,
  1220. decoder_attention_mask=decoder_attention_mask,
  1221. encoder_outputs=encoder_outputs,
  1222. inputs_embeds=inputs_embeds,
  1223. decoder_inputs_embeds=decoder_inputs_embeds,
  1224. use_cache=use_cache,
  1225. output_attentions=output_attentions,
  1226. output_hidden_states=output_hidden_states,
  1227. return_dict=return_dict,
  1228. )
  1229. sequence_output = outputs[0]
  1230. logits = self.qa_outputs(sequence_output)
  1231. start_logits, end_logits = logits.split(1, dim=-1)
  1232. start_logits = start_logits.squeeze(-1).contiguous()
  1233. end_logits = end_logits.squeeze(-1).contiguous()
  1234. total_loss = None
  1235. if start_positions is not None and end_positions is not None:
  1236. # If we are on multi-GPU, split add a dimension
  1237. if len(start_positions.size()) > 1:
  1238. start_positions = start_positions.squeeze(-1)
  1239. if len(end_positions.size()) > 1:
  1240. end_positions = end_positions.squeeze(-1)
  1241. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1242. ignored_index = start_logits.size(1)
  1243. start_positions = start_positions.clamp(0, ignored_index)
  1244. end_positions = end_positions.clamp(0, ignored_index)
  1245. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1246. start_loss = loss_fct(start_logits, start_positions)
  1247. end_loss = loss_fct(end_logits, end_positions)
  1248. total_loss = (start_loss + end_loss) / 2
  1249. if not return_dict:
  1250. output = (
  1251. start_logits,
  1252. end_logits,
  1253. ) + outputs[1:]
  1254. return ((total_loss,) + output) if total_loss is not None else output
  1255. return Seq2SeqQuestionAnsweringModelOutput(
  1256. loss=total_loss,
  1257. start_logits=start_logits,
  1258. end_logits=end_logits,
  1259. past_key_values=outputs.past_key_values,
  1260. decoder_hidden_states=outputs.decoder_hidden_states,
  1261. decoder_attentions=outputs.decoder_attentions,
  1262. cross_attentions=outputs.cross_attentions,
  1263. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1264. encoder_hidden_states=outputs.encoder_hidden_states,
  1265. encoder_attentions=outputs.encoder_attentions,
  1266. )
# Thin `MvpPreTrainedModel` subclass whose only submodule is an `MvpDecoder`, stored as `self.decoder`;
# `forward` hands every positional and keyword argument straight to that decoder and returns its result.
# NOTE(review): the `# Copied from` marker below declares this class a copy of BartDecoderWrapper
# (with Bart->Mvp) — presumably it must stay textually in sync with that source; confirm before
# editing the class body.
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Mvp
class MvpDecoderWrapper(MvpPreTrainedModel):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    def __init__(self, config):
        super().__init__(config)
        self.decoder = MvpDecoder(config)
        self.post_init()

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)
  1279. class MvpForCausalLM(MvpPreTrainedModel, GenerationMixin):
  1280. _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"}
  1281. def __init__(self, config):
  1282. config.is_decoder = True
  1283. config.is_encoder_decoder = False
  1284. super().__init__(config)
  1285. self.model = MvpDecoderWrapper(config)
  1286. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1287. # Initialize weights and apply final processing
  1288. self.post_init()
  1289. def get_input_embeddings(self):
  1290. return self.model.decoder.embed_tokens
  1291. def set_input_embeddings(self, value):
  1292. self.model.decoder.embed_tokens = value
  1293. def set_lightweight_tuning(self):
  1294. self.model.set_lightweight_tuning()
  1295. self.lm_head.requires_grad_(False)
  1296. @auto_docstring
  1297. def forward(
  1298. self,
  1299. input_ids: torch.LongTensor | None = None,
  1300. attention_mask: torch.Tensor | None = None,
  1301. encoder_hidden_states: torch.FloatTensor | None = None,
  1302. encoder_attention_mask: torch.FloatTensor | None = None,
  1303. past_key_values: Cache | None = None,
  1304. inputs_embeds: torch.FloatTensor | None = None,
  1305. labels: torch.LongTensor | None = None,
  1306. use_cache: bool | None = None,
  1307. output_attentions: bool | None = None,
  1308. output_hidden_states: bool | None = None,
  1309. return_dict: bool | None = None,
  1310. logits_to_keep: int | torch.Tensor = 0,
  1311. **kwargs,
  1312. ) -> tuple | CausalLMOutputWithCrossAttentions:
  1313. r"""
  1314. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1315. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1316. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1317. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1318. Example:
  1319. ```python
  1320. >>> from transformers import AutoTokenizer, MvpForCausalLM
  1321. >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
  1322. >>> model = MvpForCausalLM.from_pretrained("RUCAIBox/mvp")
  1323. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1324. >>> outputs = model(**inputs)
  1325. >>> logits = outputs.logits
  1326. >>> list(logits.shape)
  1327. [1, 8, 50267]
  1328. ```"""
  1329. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1330. output_hidden_states = (
  1331. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1332. )
  1333. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1334. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1335. outputs = self.model.decoder(
  1336. input_ids=input_ids,
  1337. attention_mask=attention_mask,
  1338. encoder_hidden_states=encoder_hidden_states,
  1339. encoder_attention_mask=encoder_attention_mask,
  1340. past_key_values=past_key_values,
  1341. inputs_embeds=inputs_embeds,
  1342. use_cache=use_cache,
  1343. output_attentions=output_attentions,
  1344. output_hidden_states=output_hidden_states,
  1345. return_dict=return_dict,
  1346. )
  1347. hidden_states = outputs[0]
  1348. # Only compute necessary logits
  1349. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1350. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1351. loss = None
  1352. if labels is not None:
  1353. loss_fct = CrossEntropyLoss()
  1354. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  1355. if not return_dict:
  1356. output = (logits,) + outputs[1:]
  1357. return (loss,) + output if loss is not None else output
  1358. return CausalLMOutputWithCrossAttentions(
  1359. loss=loss,
  1360. logits=logits,
  1361. past_key_values=outputs.past_key_values,
  1362. hidden_states=outputs.hidden_states,
  1363. attentions=outputs.attentions,
  1364. cross_attentions=outputs.cross_attentions,
  1365. )
# Public names of this module: this list is what a star-import of the module exports.
# `MvpDecoderWrapper` is not listed here.
__all__ = [
    "MvpForCausalLM",
    "MvpForConditionalGeneration",
    "MvpForQuestionAnswering",
    "MvpForSequenceClassification",
    "MvpModel",
    "MvpPreTrainedModel",
]