modeling_prophetnet.py 84 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846
  1. # Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
"""PyTorch ProphetNet model, ported from ProphetNet repo (fairseq version)."""
  15. import copy
  16. import math
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import Tensor, nn
  20. from torch.nn import LayerNorm
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutput
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import ModelOutput, auto_docstring, logging
  28. from .configuration_prophetnet import ProphetNetConfig
  29. logger = logging.get_logger(__name__)
  30. def softmax(hidden_state, dim, onnx_trace=False):
  31. if onnx_trace:
  32. return nn.functional.softmax(hidden_state.float(), dim=dim)
  33. else:
  34. return nn.functional.softmax(hidden_state, dim=dim, dtype=torch.float32)
  35. def ngram_attention_bias(sequence_length, ngram, device, dtype):
  36. """
  37. This function computes the bias for the predict stream
  38. """
  39. left_block = (
  40. torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * torch.finfo(dtype).min
  41. )
  42. right_block = left_block.detach().clone()
  43. # create bias
  44. for stream_idx in range(ngram):
  45. right_block[stream_idx].fill_diagonal_(0, wrap=False)
  46. left_block[stream_idx].triu_(-stream_idx + 1)
  47. left_block[:, :, 0] = 0
  48. return torch.cat([left_block, right_block], dim=2)
  49. def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):
  50. """
  51. This function computes individual parts of the relative position buckets. For more detail, see paper.
  52. """
  53. inv_relative_positions = -relative_positions
  54. rel_positions_bucket = 0
  55. if is_bidirectional:
  56. num_buckets = num_buckets // 2
  57. rel_positions_bucket = (
  58. rel_positions_bucket
  59. + torch.lt(inv_relative_positions, torch.zeros_like(inv_relative_positions)).int() * num_buckets
  60. )
  61. inv_relative_positions = torch.abs(inv_relative_positions)
  62. else:
  63. inv_relative_positions = torch.max(inv_relative_positions, torch.zeros_like(inv_relative_positions))
  64. max_exact = num_buckets // 2
  65. is_small = torch.lt(inv_relative_positions, max_exact)
  66. val_if_large = max_exact + torch.log(inv_relative_positions.float() / max_exact) / math.log(
  67. max_distance / max_exact
  68. ) * (num_buckets - max_exact)
  69. val_if_large = torch.min(val_if_large, torch.ones_like(val_if_large) * (num_buckets - 1)).int()
  70. rel_positions_bucket = rel_positions_bucket + torch.where(is_small, inv_relative_positions.int(), val_if_large)
  71. return rel_positions_bucket
  72. def compute_all_stream_relative_buckets(num_buckets, max_distance, position_ids):
  73. """
  74. This function computes both main and predict relative position buckets. For more detail, see paper.
  75. """
  76. # main stream
  77. main_stream_relative_positions = position_ids.unsqueeze(1).repeat(1, position_ids.size(-1), 1)
  78. main_stream_relative_positions = main_stream_relative_positions - position_ids.unsqueeze(-1)
  79. # predicting stream
  80. predicting_stream_relative_positions = torch.cat((position_ids - 1, position_ids), dim=-1).unsqueeze(1)
  81. predicting_stream_relative_positions = predicting_stream_relative_positions.repeat(1, position_ids.size(-1), 1)
  82. predicting_stream_relative_positions = predicting_stream_relative_positions - position_ids.unsqueeze(-1)
  83. # get both position buckets
  84. main_relative_position_buckets = compute_relative_buckets(
  85. num_buckets, max_distance, main_stream_relative_positions, is_bidirectional=False
  86. )
  87. predict_relative_position_buckets = compute_relative_buckets(
  88. num_buckets, max_distance, predicting_stream_relative_positions, is_bidirectional=False
  89. )
  90. return main_relative_position_buckets, predict_relative_position_buckets
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for sequence-to-sequence language models outputs.
    """
)
class ProphetNetSeq2SeqLMOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss.
    logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
        Prediction scores of the main stream language modeling head (scores for each vocabulary token before
        SoftMax).
    logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
        Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
        SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
        used (see `past_key_values` input) to speed up sequential decoding.
    decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.

        Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
        outputs.
    decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
        decoder_sequence_length, decoder_sequence_length)`.

        Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
        weighted average in the self-attention heads.
    encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the encoder of the model.
    """

    # NOTE(review): fields without an entry in the docstring above are presumably documented by
    # `auto_docstring` - confirm before relying on it.

    # Loss and prediction scores (main stream and predict stream).
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    logits_ngram: torch.FloatTensor | None = None
    # Decoder-side fields.
    past_key_values: Cache | None = None
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_ngram_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_attentions: tuple[torch.FloatTensor] | None = None
    decoder_ngram_attentions: tuple[torch.FloatTensor] | None = None
    cross_attentions: tuple[torch.FloatTensor] | None = None
    # Encoder-side fields.
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None
  136. @dataclass
  137. @auto_docstring(
  138. custom_intro="""
  139. Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
  140. decoding.
  141. """
  142. )
  143. class ProphetNetSeq2SeqModelOutput(ModelOutput):
  144. r"""
  145. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
  146. Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.
  147. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
  148. hidden_size)` is output.
  149. last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`, *optional*):
  150. Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
  151. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  152. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  153. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
  154. used (see `past_key_values` input) to speed up sequential decoding.
  155. decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  156. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  157. shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
  158. Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
  159. outputs.
  160. decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  161. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  162. decoder_sequence_length, decoder_sequence_length)`.
  163. Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
  164. weighted average in the
  165. encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  166. Sequence of hidden-states at the output of the last layer of the encoder of the model.
  167. """
  168. last_hidden_state: torch.FloatTensor
  169. last_hidden_state_ngram: torch.FloatTensor | None = None
  170. past_key_values: Cache | None = None
  171. decoder_hidden_states: tuple[torch.FloatTensor] | None = None
  172. decoder_ngram_hidden_states: tuple[torch.FloatTensor] | None = None
  173. decoder_attentions: tuple[torch.FloatTensor] | None = None
  174. decoder_ngram_attentions: tuple[torch.FloatTensor] | None = None
  175. cross_attentions: tuple[torch.FloatTensor] | None = None
  176. encoder_last_hidden_state: torch.FloatTensor | None = None
  177. encoder_hidden_states: tuple[torch.FloatTensor] | None = None
  178. encoder_attentions: tuple[torch.FloatTensor] | None = None
  179. @dataclass
  180. @auto_docstring(
  181. custom_intro="""
  182. Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
  183. """
  184. )
  185. class ProphetNetDecoderModelOutput(ModelOutput):
  186. r"""
  187. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
  188. Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.
  189. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
  190. hidden_size)` is output.
  191. last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
  192. Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
  193. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  194. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  195. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
  196. used (see `past_key_values` input) to speed up sequential decoding.
  197. hidden_states_ngram (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  198. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  199. shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
  200. Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
  201. outputs.
  202. ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  203. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  204. decoder_sequence_length, decoder_sequence_length)`.
  205. Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
  206. weighted average in the
  207. """
  208. last_hidden_state: torch.FloatTensor
  209. last_hidden_state_ngram: torch.FloatTensor | None = None
  210. past_key_values: Cache | None = None
  211. hidden_states: tuple[torch.FloatTensor] | None = None
  212. hidden_states_ngram: tuple[torch.FloatTensor] | None = None
  213. attentions: tuple[torch.FloatTensor] | None = None
  214. ngram_attentions: tuple[torch.FloatTensor] | None = None
  215. cross_attentions: tuple[torch.FloatTensor] | None = None
  216. @dataclass
  217. @auto_docstring(
  218. custom_intro="""
  219. Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
  220. """
  221. )
  222. class ProphetNetDecoderLMOutput(ModelOutput):
  223. r"""
  224. ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  225. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  226. shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
  227. Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
  228. outputs.
  229. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  230. Language modeling loss.
  231. logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
  232. Prediction scores of the main stream language modeling head (scores for each vocabulary token before
  233. SoftMax).
  234. logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
  235. Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
  236. SoftMax).
  237. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  238. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  239. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
  240. used (see `past_key_values` input) to speed up sequential decoding.
  241. hidden_states_ngram (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  242. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  243. shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
  244. Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
  245. outputs.
  246. ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  247. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  248. decoder_sequence_length, decoder_sequence_length)`.
  249. Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
  250. weighted average in the
  251. """
  252. loss: torch.FloatTensor | None = None
  253. logits: torch.FloatTensor | None = None
  254. logits_ngram: torch.FloatTensor | None = None
  255. past_key_values: Cache | None = None
  256. hidden_states: tuple[torch.FloatTensor] | None = None
  257. hidden_states_ngram: tuple[torch.FloatTensor] | None = None
  258. attentions: tuple[torch.FloatTensor] | None = None
  259. ngram_attentions: tuple[torch.FloatTensor] | None = None
  260. cross_attentions: tuple[torch.FloatTensor] | None = None
  261. @auto_docstring
  262. class ProphetNetPreTrainedModel(PreTrainedModel):
  263. config: ProphetNetConfig
  264. base_model_prefix = "prophetnet"
  265. supports_gradient_checkpointing = True
  266. def _shift_right(self, input_ids):
  267. decoder_start_token_id = self.config.decoder_start_token_id
  268. pad_token_id = self.config.pad_token_id
  269. assert decoder_start_token_id is not None, (
  270. "self.model.config.decoder_start_token_id has to be defined. In ProphetNet it is usually set to the"
  271. " pad_token_id. See ProphetNet docs for more information"
  272. )
  273. # shift inputs to the right
  274. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  275. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  276. shifted_input_ids[..., 0] = decoder_start_token_id
  277. assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
  278. # replace possible -100 values in labels by `pad_token_id`
  279. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  280. assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
  281. return shifted_input_ids
  282. class ProphetNetPositionalEmbeddings(nn.Embedding):
  283. """
  284. This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
  285. based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
  286. the forward function.
  287. """
  288. def __init__(self, config: ProphetNetConfig) -> None:
  289. self.max_length = config.max_position_embeddings
  290. super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)
  291. def forward(self, inputs_shape, device, attention_mask=None, past_key_values=None, position_ids=None):
  292. assert (position_ids is None) or (self.padding_idx is None), (
  293. "If position_ids is pre-computed then padding_idx should not be set."
  294. )
  295. if position_ids is None:
  296. if past_key_values is not None and past_key_values.get_seq_length() != 0:
  297. # position_ids is the same for every token when decoding a single step
  298. # Without the int() cast, it doesn't work in some cases when exporting to ONNX
  299. prev_num_input_ids = past_key_values.get_seq_length()
  300. num_input_ids = inputs_shape[1] + prev_num_input_ids
  301. position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * (
  302. int(self.padding_idx + num_input_ids)
  303. )
  304. else:
  305. if attention_mask is None:
  306. attention_mask = torch.ones(inputs_shape, dtype=torch.long, device=device)
  307. # retrieve position_ids from input_ids / attention_mask
  308. position_ids = (
  309. torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
  310. ).long() + self.padding_idx
  311. # make sure position_ids are not bigger then max_length
  312. position_ids = position_ids.clamp(0, self.max_length - 1)
  313. return super().forward(position_ids), position_ids
  314. def _forward(self, position_ids):
  315. return super().forward(position_ids)
  316. class ProphetNetAttention(nn.Module):
  317. """Multi-headed attention from 'Attention Is All You Need' paper"""
  318. def __init__(self, config: ProphetNetConfig, num_attn_heads: int, layer_idx: int | None = None):
  319. super().__init__()
  320. hidden_size = config.hidden_size
  321. self.attention_dropout = config.attention_dropout
  322. self.dropout = config.dropout
  323. self.num_attn_heads = num_attn_heads
  324. self.head_dim = hidden_size // num_attn_heads
  325. self.layer_idx = layer_idx
  326. assert self.head_dim * num_attn_heads == hidden_size, (
  327. "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and"
  328. " `config.num_decoder_attention_heads`"
  329. )
  330. self.key_proj = nn.Linear(hidden_size, hidden_size)
  331. self.value_proj = nn.Linear(hidden_size, hidden_size)
  332. self.query_proj = nn.Linear(hidden_size, hidden_size)
  333. self.out_proj = nn.Linear(hidden_size, hidden_size)
  334. def forward(
  335. self,
  336. hidden_states,
  337. key_value_states: Tensor | None = None,
  338. attention_mask: Tensor | None = None,
  339. past_key_values: Cache | None = None,
  340. output_attentions: bool | None = False,
  341. **kwargs,
  342. ) -> tuple[Tensor, Tensor | None]:
  343. batch_size, tgt_len, hidden_size = hidden_states.size()
  344. # if key_value_states are provided this layer is used as a cross-attention layer
  345. # for the decoder
  346. is_cross_attention = key_value_states is not None
  347. assert list(hidden_states.size()) == [
  348. batch_size,
  349. tgt_len,
  350. hidden_size,
  351. ], f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.size()}"
  352. # previous time steps are cached - no need to recompute key and value if they are static
  353. query_states = self.query_proj(hidden_states) / (self.head_dim**0.5)
  354. is_updated = False
  355. if past_key_values is not None:
  356. if isinstance(past_key_values, EncoderDecoderCache):
  357. is_updated = past_key_values.is_updated.get(self.layer_idx)
  358. if is_cross_attention:
  359. # after the first generated id, we can subsequently re-use all key/value_states from cache
  360. curr_past_key_values = past_key_values.cross_attention_cache
  361. else:
  362. curr_past_key_values = past_key_values.self_attention_cache
  363. else:
  364. curr_past_key_values = past_key_values
  365. current_states = key_value_states if is_cross_attention else hidden_states
  366. if is_cross_attention and past_key_values is not None and is_updated:
  367. # reuse k,v, cross_attentions
  368. key_states = curr_past_key_values.layers[self.layer_idx].keys
  369. value_states = curr_past_key_values.layers[self.layer_idx].values
  370. else:
  371. key_states = self.key_proj(current_states)
  372. value_states = self.value_proj(current_states)
  373. key_states = key_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2)
  374. value_states = value_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2)
  375. if past_key_values is not None:
  376. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  377. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  378. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  379. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  380. past_key_values.is_updated[self.layer_idx] = True
  381. query_states = query_states.view(batch_size, tgt_len, self.num_attn_heads, self.head_dim).transpose(1, 2)
  382. src_len = key_states.size(2)
  383. attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))
  384. expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len)
  385. if attn_weights.size() != expected_shape:
  386. raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}")
  387. # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
  388. if attention_mask is not None and attention_mask.dim() == 0:
  389. attention_mask = None
  390. expected_shape = (batch_size, self.num_attn_heads, 1, src_len)
  391. if attention_mask is not None and attention_mask.size() != expected_shape:
  392. raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}")
  393. if attention_mask is not None: # don't attend to padding symbols
  394. attn_weights = attn_weights + attention_mask
  395. if output_attentions:
  396. attn_weights_reshaped = attn_weights
  397. else:
  398. attn_weights_reshaped = None
  399. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  400. attn_probs = nn.functional.dropout(
  401. attn_weights,
  402. p=self.attention_dropout,
  403. training=self.training,
  404. )
  405. attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states)
  406. expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim)
  407. if attn_output.size() != expected_shape:
  408. raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}")
  409. attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)
  410. attn_output = self.out_proj(attn_output)
  411. attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
  412. return attn_output, attn_weights_reshaped
  413. class ProphetNetFeedForward(nn.Module):
  414. """
  415. This is the residual two feed-forward layer block based on the original Transformer implementation.
  416. """
  417. def __init__(self, config: ProphetNetConfig, ffn_dim: int):
  418. super().__init__()
  419. self.activation_fn = ACT2FN[config.activation_function]
  420. self.intermediate = nn.Linear(config.hidden_size, ffn_dim)
  421. self.output = nn.Linear(ffn_dim, config.hidden_size)
  422. self.activation_dropout = config.activation_dropout
  423. self.dropout = config.dropout
  424. def forward(self, hidden_states):
  425. hidden_states = self.intermediate(hidden_states)
  426. hidden_states = self.activation_fn(hidden_states)
  427. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  428. hidden_states = self.output(hidden_states)
  429. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  430. return hidden_states
  431. class ProphetNetNgramSelfAttention(nn.Module):
  432. def __init__(self, config: ProphetNetConfig, layer_idx=None):
  433. super().__init__()
  434. self.hidden_size = config.hidden_size
  435. self.num_buckets = config.num_buckets
  436. self.relative_max_distance = config.relative_max_distance
  437. self.num_attn_heads = config.num_decoder_attention_heads
  438. self.dropout = config.dropout
  439. self.attention_dropout = config.attention_dropout
  440. self.head_dim = config.hidden_size // self.num_attn_heads
  441. self.ngram = config.ngram
  442. self.layer_idx = layer_idx
  443. assert self.head_dim * self.num_attn_heads == config.hidden_size, (
  444. "config.hidden_size must be divisible by num_attn_heads"
  445. )
  446. # key, value, query projection
  447. self.key_proj = nn.Linear(config.hidden_size, config.hidden_size)
  448. self.value_proj = nn.Linear(config.hidden_size, config.hidden_size)
  449. self.query_proj = nn.Linear(config.hidden_size, config.hidden_size)
  450. # out projection
  451. self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
  452. # rel position embeddings
  453. self.relative_pos_embeddings = nn.Linear(config.hidden_size, self.num_buckets * self.num_attn_heads)
  454. # for onnx runtime
  455. self.onnx_trace = False
  456. def _shape(self, tensor, seq_len, batch_size):
  457. return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()
  458. def prepare_for_onnx_export_(self):
  459. self.onnx_trace = True
  460. def forward(
  461. self,
  462. hidden_states,
  463. past_key_values: Cache | None = None,
  464. attention_mask=None,
  465. extended_predict_attention_mask=None,
  466. main_relative_position_buckets=None,
  467. predict_relative_position_buckets=None,
  468. position_ids=None,
  469. **kwargs,
  470. ):
  471. batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
  472. assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
  473. f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
  474. f" {hidden_states.shape}"
  475. )
  476. # project
  477. query_states = self.query_proj(hidden_states)
  478. key_states = self.key_proj(hidden_states)
  479. value_states = self.value_proj(hidden_states)
  480. # normalize
  481. query_states = query_states / (self.head_dim**0.5)
  482. # reshape
  483. query_states = self._shape(query_states, ngram_sequence_length, batch_size)
  484. key_states = self._shape(key_states, -1, batch_size)
  485. value_states = self._shape(value_states, -1, batch_size)
  486. proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
  487. query_states = query_states.reshape(*proj_shape)
  488. key_states = key_states.reshape(*proj_shape)
  489. value_states = value_states.reshape(*proj_shape)
  490. # chunk into main stream and predict stream
  491. hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
  492. query_states_list = query_states.chunk(1 + self.ngram, dim=2)
  493. key_states_list = key_states.chunk(1 + self.ngram, dim=2)
  494. value_states_list = value_states.chunk(1 + self.ngram, dim=2)
  495. main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:]
  496. main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:]
  497. main_key_states, predict_key_states_list = key_states_list[0], key_states_list[1:]
  498. main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:]
  499. # ProphetNet has two separate attention layers, one for self and one for cross attention
  500. # We need to obtain the self attention only for this module, if `EncoderDecoderCache`
  501. if past_key_values is not None:
  502. if isinstance(past_key_values, EncoderDecoderCache):
  503. curr_past_key_values = past_key_values.self_attention_cache
  504. else:
  505. curr_past_key_values = past_key_values
  506. main_key_states, main_value_states = curr_past_key_values.update(
  507. main_key_states, main_value_states, self.layer_idx
  508. )
  509. # get seq_length of main stream only
  510. sequence_length = ngram_sequence_length // (1 + self.ngram)
  511. # MAIN-STREAM
  512. # main attn weights
  513. # [batch_size, number_heads, sequence_length, head_dimesion]
  514. # x [batch_size, number_heads, head_dimesion, sequence_length]
  515. # -> [batch_size, number_heads, sequence_length, sequence_length]
  516. main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3))
  517. # retrieve relative position embeddings for each layer -> see paper for more details
  518. main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(
  519. main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets
  520. )
  521. main_attn_weights = main_attn_weights + main_relative_pos_embeddings
  522. if attention_mask is not None:
  523. main_attn_weights = main_attn_weights + attention_mask
  524. main_attn_probs = softmax(
  525. main_attn_weights,
  526. dim=-1,
  527. onnx_trace=self.onnx_trace,
  528. ).type_as(main_attn_weights)
  529. main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
  530. # project to attn_output
  531. # [batch_size, number_heads, sequence_length, sequence_length]
  532. # x [batch_size, number_heads, sequence_length, head_dimesion]
  533. # -> [batch_size, number_heads, sequence_length, head_dimesion]
  534. main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states)
  535. # reshape so that num_heads dim is merged into last `head_dim` axis
  536. main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size)
  537. main_attn_output = self.out_proj(main_attn_output)
  538. # PREDICT-STREAM
  539. # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
  540. predict_query_states = torch.stack(predict_query_states_list, 1).view(
  541. batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim
  542. )
  543. # [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
  544. predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1)
  545. # [batch_size, sequence_length, ngram, hidden_size]
  546. predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2)
  547. # [batch_size, number_heads, ngram, 2*sequence_length, head_dimesion]
  548. predict_value_states = torch.cat(
  549. [torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2
  550. )
  551. # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
  552. # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
  553. # -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
  554. predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states))
  555. # retrieve relative position embeddings for each layer -> see paper for more details
  556. # [batch_size, ngram, number_heads, sequence_length, predict_relative_pos_embeddings]
  557. predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
  558. predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets
  559. )
  560. # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
  561. predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings
  562. if extended_predict_attention_mask is not None:
  563. # Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
  564. extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4)
  565. extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype)
  566. predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
  567. predict_attn_probs = softmax(
  568. predict_attn_weights,
  569. dim=-1,
  570. onnx_trace=self.onnx_trace,
  571. ).type_as(predict_attn_weights)
  572. predict_attn_probs = nn.functional.dropout(
  573. predict_attn_probs, p=self.attention_dropout, training=self.training
  574. )
  575. # project to attention output
  576. # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
  577. # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
  578. # -> [batch_size, ngram, number_heads, sequence_length, head_dimesion]
  579. predict_attn_output = torch.einsum(
  580. "bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2))
  581. )
  582. # reshape so that num_heads dim is merged into last `head_dim` axis
  583. # [batch_size, ngram, number_heads, sequence_length, head_dimesion] -> [batch_size, ngram, sequence_length, hidden_size]
  584. predict_attn_output = predict_attn_output.transpose(2, 3)
  585. predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size)
  586. predict_attn_output = self.out_proj(predict_attn_output)
  587. # concat to single attn output
  588. # [batch_size, (1+ngram)*sequence_length, hidden_size]
  589. attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
  590. # reshape into better form for `config.output_attentions`
  591. main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
  592. attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
  593. return attn_output, main_attn_probs, predict_attn_probs
  594. def get_main_relative_pos_embeddings(
  595. self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
  596. ):
  597. # input hidden_states [batch_size, sequence_length, hidden_size]
  598. # input attn_weights [batch_size, num_heads, sequence_length, sequence_length]
  599. # input position_ids [batch_size, sequence_length] or [1,1]
  600. batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape
  601. attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len)
  602. if main_relative_position_buckets is None:
  603. batch_size, sequence_length = hidden_states.shape[:2]
  604. relative_positions = (
  605. torch.arange(1, attn_weights.shape[-1] + 1)
  606. .unsqueeze(0)
  607. .unsqueeze(0)
  608. .repeat(batch_size, sequence_length, 1)
  609. .to(position_ids.device)
  610. )
  611. # [batch_size, sequence_length, sequence_length+1]
  612. relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
  613. main_relative_position_buckets = compute_relative_buckets(
  614. self.num_buckets, self.relative_max_distance, relative_positions, False
  615. )
  616. # [batch_size, sequence_length, num_buckets * num_heads]
  617. rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
  618. rel_pos_embeddings = rel_pos_embeddings.view(
  619. rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
  620. )
  621. rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2)
  622. # [batch_size, num_heads, sequence_length, num_buckets]
  623. rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,))
  624. main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
  625. # [batch_size * num_heads * sequence_length, sequence_length]
  626. main_relative_position_buckets = main_relative_position_buckets.view(
  627. -1, main_relative_position_buckets.shape[-1]
  628. )
  629. main_relative_position_buckets = main_relative_position_buckets.long()
  630. # [batch_size * num_heads * sequence_length, sequence_length]
  631. rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
  632. main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets)
  633. main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1)
  634. return main_relative_pos_embeddings
  635. def get_predict_relative_pos_embeddings(
  636. self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
  637. ):
  638. # input hidden_states [batch_size, sequence_length, ngram, hidden_size]
  639. # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length]
  640. # input position_ids [batch_size, sequence_length] or [1,1]
  641. # input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None
  642. batch_size, sequence_length = hidden_states.shape[0:2]
  643. if predict_relative_position_buckets is None:
  644. key_sequence_length = attn_weights.shape[-1]
  645. assert position_ids[0][0] == key_sequence_length - 1, (
  646. "`position_ids` are incorrect. They should be of the format 1 2 3 4 5 ... (key_sequence_length - 1)"
  647. )
  648. relative_positions = (
  649. torch.arange(0, key_sequence_length)
  650. .unsqueeze(0)
  651. .unsqueeze(0)
  652. .repeat(batch_size, sequence_length, 1)
  653. .to(position_ids.device)
  654. )
  655. relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
  656. predict_relative_position_buckets = compute_relative_buckets(
  657. self.num_buckets, self.relative_max_distance, relative_positions, False
  658. )
  659. # [batch_size, ngram, sequence_length, hidden_size]
  660. hidden_states = hidden_states.transpose(1, 2)
  661. rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
  662. # [batch_size, ngram, sequence_length, num_buckets, num_heads]
  663. rel_pos_embeddings = rel_pos_embeddings.view(
  664. hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
  665. )
  666. rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)
  667. # [batch_size * ngram * sequence_length * num_heads, num_buckets]
  668. rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)
  669. # [ngram, batch_size, num_heads * sequence_length, -1]
  670. predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
  671. predict_relative_position_buckets = predict_relative_position_buckets.repeat(
  672. self.ngram, 1, self.num_attn_heads, 1
  673. )
  674. # [ngram * batch_size * num_heads * sequence_length, -1]
  675. predict_relative_position_buckets = predict_relative_position_buckets.view(
  676. -1, predict_relative_position_buckets.size(-1)
  677. ).long()
  678. predict_relative_pos_embeddings = torch.gather(
  679. rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
  680. )
  681. # [batch_size, gram, num_heads, sequence_length, -1]
  682. predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
  683. batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
  684. )
  685. return predict_relative_pos_embeddings
  686. class ProphetNetEncoderLayer(GradientCheckpointingLayer):
  687. """
  688. Encoder block for Prophetnet
  689. """
  690. def __init__(self, config: ProphetNetConfig):
  691. super().__init__()
  692. # 1st residual block
  693. self.self_attn = ProphetNetAttention(config, config.num_encoder_attention_heads)
  694. self.self_attn_layer_norm = LayerNorm(config.hidden_size)
  695. # 2nd residual block
  696. self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
  697. self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
  698. def forward(
  699. self,
  700. hidden_states,
  701. attention_mask,
  702. output_attentions: bool = False,
  703. ):
  704. # 1st residual block
  705. attention_output, attn_weights = self.self_attn(
  706. hidden_states=hidden_states,
  707. attention_mask=attention_mask,
  708. output_attentions=output_attentions,
  709. )
  710. hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)
  711. # 2nd residual block
  712. feed_forward_output = self.feed_forward(hidden_states)
  713. hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)
  714. outputs = (hidden_states,)
  715. if output_attentions:
  716. outputs += (attn_weights,)
  717. return outputs
  718. class ProphetNetDecoderLayer(GradientCheckpointingLayer):
  719. """
  720. Decoder block for Prophetnet
  721. """
  722. def __init__(self, config: ProphetNetConfig, layer_idx=None):
  723. super().__init__()
  724. # 1st residual block
  725. self.self_attn = ProphetNetNgramSelfAttention(config, layer_idx=layer_idx)
  726. self.self_attn_layer_norm = LayerNorm(config.hidden_size)
  727. # 2nd residual block
  728. if config.add_cross_attention:
  729. self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads, layer_idx=layer_idx)
  730. self.cross_attn_layer_norm = LayerNorm(config.hidden_size)
  731. # 3rd residual block
  732. self.feed_forward = ProphetNetFeedForward(config, config.decoder_ffn_dim)
  733. self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
  734. def forward(
  735. self,
  736. hidden_states,
  737. attention_mask=None,
  738. encoder_hidden_states=None,
  739. encoder_attn_mask=None,
  740. extended_predict_attention_mask=None,
  741. main_relative_position_buckets=None,
  742. predict_relative_position_buckets=None,
  743. position_ids=None,
  744. past_key_values=None,
  745. use_cache: bool | None = True,
  746. output_attentions: bool | None = False,
  747. **kwargs,
  748. ):
  749. # 1st residual block
  750. ngram_attention_output, self_attn_weights, self_attn_weights_ngram = self.self_attn(
  751. hidden_states=hidden_states,
  752. past_key_values=past_key_values,
  753. attention_mask=attention_mask,
  754. extended_predict_attention_mask=extended_predict_attention_mask,
  755. main_relative_position_buckets=main_relative_position_buckets,
  756. predict_relative_position_buckets=predict_relative_position_buckets,
  757. position_ids=position_ids,
  758. )
  759. hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output)
  760. cross_attn_weights = None
  761. if encoder_hidden_states is not None:
  762. # 2nd residual block
  763. attention_output, cross_attn_weights = self.cross_attn(
  764. hidden_states=hidden_states,
  765. key_value_states=encoder_hidden_states,
  766. attention_mask=encoder_attn_mask,
  767. past_key_values=past_key_values,
  768. output_attentions=output_attentions,
  769. )
  770. hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states)
  771. # 3rd residual block
  772. feed_forward_output = self.feed_forward(hidden_states)
  773. hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)
  774. outputs = (hidden_states,)
  775. if output_attentions:
  776. outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights)
  777. return outputs
  778. @auto_docstring(
  779. custom_intro="""
  780. The standalone encoder part of the ProphetNetModel.
  781. """
  782. )
  783. class ProphetNetEncoder(ProphetNetPreTrainedModel):
  784. def __init__(self, config: ProphetNetConfig):
  785. super().__init__(config)
  786. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  787. self.position_embeddings = ProphetNetPositionalEmbeddings(config)
  788. self.embeddings_layer_norm = LayerNorm(config.hidden_size)
  789. self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])
  790. self.gradient_checkpointing = False
  791. # Initialize weights and apply final processing
  792. self.post_init()
  793. def get_input_embeddings(self):
  794. return self.word_embeddings
  795. def set_input_embeddings(self, value):
  796. self.word_embeddings = value
  797. @auto_docstring
  798. def forward(
  799. self,
  800. input_ids: torch.Tensor | None = None,
  801. attention_mask: torch.Tensor | None = None,
  802. inputs_embeds: torch.Tensor | None = None,
  803. output_attentions: bool | None = None,
  804. output_hidden_states: bool | None = None,
  805. return_dict: bool | None = None,
  806. **kwargs,
  807. ) -> tuple | BaseModelOutput:
  808. r"""
  809. Example:
  810. ```python
  811. >>> from transformers import AutoTokenizer, ProphetNetEncoder
  812. >>> import torch
  813. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  814. >>> model = ProphetNetEncoder.from_pretrained("patrickvonplaten/prophetnet-large-uncased-standalone")
  815. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  816. >>> outputs = model(**inputs)
  817. >>> last_hidden_states = outputs.last_hidden_state
  818. ```"""
  819. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  820. output_hidden_states = (
  821. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  822. )
  823. return_dict = return_dict if return_dict is not None else self.config.return_dict
  824. if input_ids is None and inputs_embeds is None:
  825. raise ValueError("Either input_ids or inputs_embeds has to be passed.")
  826. elif input_ids is not None and inputs_embeds is not None:
  827. raise ValueError("Make sure to only pass input_ids or inputs_embeds.")
  828. elif input_ids is not None and inputs_embeds is None:
  829. inputs_embeds = self.word_embeddings(input_ids)
  830. # prepare attention mask
  831. if attention_mask is not None:
  832. extended_attention_mask = (
  833. 1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1)
  834. ) * torch.finfo(self.dtype).min
  835. extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
  836. else:
  837. extended_attention_mask = None
  838. position_embeddings, position_ids = self.position_embeddings(inputs_embeds.shape[:2], inputs_embeds.device)
  839. hidden_states = inputs_embeds + position_embeddings
  840. hidden_states = self.embeddings_layer_norm(hidden_states)
  841. hidden_states = nn.functional.dropout(hidden_states, p=self.config.dropout, training=self.training)
  842. encoder_hidden_states = () if output_hidden_states else None
  843. all_attentions = () if output_attentions else None
  844. for idx, encoder_layer in enumerate(self.layers):
  845. if output_hidden_states:
  846. encoder_hidden_states = encoder_hidden_states + (hidden_states,)
  847. layer_outputs = encoder_layer(
  848. hidden_states,
  849. attention_mask=extended_attention_mask,
  850. output_attentions=output_attentions,
  851. )
  852. hidden_states = layer_outputs[0]
  853. if output_attentions:
  854. all_attentions = all_attentions + (layer_outputs[1],)
  855. if output_hidden_states:
  856. encoder_hidden_states = encoder_hidden_states + (hidden_states,)
  857. if not return_dict:
  858. return tuple(v for v in [hidden_states, encoder_hidden_states, all_attentions] if v is not None)
  859. return BaseModelOutput(
  860. last_hidden_state=hidden_states, hidden_states=encoder_hidden_states, attentions=all_attentions
  861. )
  862. @auto_docstring(
  863. custom_intro="""
  864. The standalone decoder part of the ProphetNetModel.
  865. """
  866. )
  867. class ProphetNetDecoder(ProphetNetPreTrainedModel):
  868. def __init__(self, config: ProphetNetConfig):
  869. super().__init__(config)
  870. self.ngram = config.ngram
  871. self.num_buckets = config.num_buckets
  872. self.relative_max_distance = config.relative_max_distance
  873. self.dropout = config.dropout
  874. self.max_target_positions = config.max_position_embeddings
  875. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  876. self.position_embeddings = ProphetNetPositionalEmbeddings(config)
  877. self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
  878. self.layers = nn.ModuleList(
  879. [ProphetNetDecoderLayer(config, layer_idx=i) for i in range(config.num_decoder_layers)]
  880. )
  881. self.embeddings_layer_norm = LayerNorm(config.hidden_size)
  882. self.gradient_checkpointing = False
  883. # Initialize weights and apply final processing
  884. self.post_init()
  885. def get_input_embeddings(self):
  886. return self.word_embeddings
  887. def set_input_embeddings(self, value):
  888. self.word_embeddings = value
  889. @auto_docstring
  890. def forward(
  891. self,
  892. input_ids: torch.Tensor | None = None,
  893. attention_mask: torch.Tensor | None = None,
  894. encoder_hidden_states: torch.Tensor | None = None,
  895. encoder_attention_mask: torch.Tensor | None = None,
  896. past_key_values: Cache | None = None,
  897. inputs_embeds: torch.Tensor | None = None,
  898. use_cache: bool | None = None,
  899. output_attentions: bool | None = None,
  900. output_hidden_states: bool | None = None,
  901. return_dict: bool | None = None,
  902. **kwargs,
  903. ) -> tuple | ProphetNetDecoderModelOutput:
  904. r"""
  905. Example:
  906. ```python
  907. >>> from transformers import AutoTokenizer, ProphetNetDecoder
  908. >>> import torch
  909. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  910. >>> model = ProphetNetDecoder.from_pretrained("microsoft/prophetnet-large-uncased", add_cross_attention=False)
  911. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  912. >>> outputs = model(**inputs)
  913. >>> last_hidden_states = outputs.last_hidden_state
  914. ```"""
  915. use_cache = use_cache if use_cache is not None else self.config.use_cache
  916. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  917. output_hidden_states = (
  918. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  919. )
  920. return_dict = return_dict if return_dict is not None else self.config.return_dict
  921. if input_ids is None and inputs_embeds is None:
  922. raise ValueError("Either `decoder_input_ids` or `decoder_inputs_embeds` has to be passed.")
  923. elif input_ids is not None and inputs_embeds is not None:
  924. raise ValueError("Make sure to only pass `decoder_input_ids` or `decoder_inputs_embeds`.")
  925. elif input_ids is not None and inputs_embeds is None:
  926. inputs_embeds = self.word_embeddings(input_ids)
  927. batch_size, sequence_length = inputs_embeds.shape[:2]
  928. if self.gradient_checkpointing and self.training:
  929. if use_cache:
  930. logger.warning_once(
  931. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  932. )
  933. use_cache = False
  934. if use_cache and past_key_values is None:
  935. past_key_values = (
  936. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  937. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  938. else DynamicCache(config=self.config)
  939. )
  940. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  941. main_stream_pos_embed, position_ids = self.position_embeddings(
  942. (batch_size, sequence_length),
  943. device=inputs_embeds.device,
  944. past_key_values=past_key_values,
  945. )
  946. if past_key_values_length != 0:
  947. main_relative_position_buckets, predict_relative_position_buckets = None, None
  948. else:
  949. (
  950. main_relative_position_buckets,
  951. predict_relative_position_buckets,
  952. ) = self.compute_buffered_relative_buckets(position_ids)
  953. predicting_stream_pos_embed = self.position_embeddings._forward(position_ids + 1)
  954. # add position embeddings
  955. hidden_states = inputs_embeds + main_stream_pos_embed
  956. ngram_embeddings = self.ngram_embeddings.weight
  957. # prepare attention mask
  958. if past_key_values_length != 0:
  959. assert hidden_states.size(1) == 1, (
  960. "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1"
  961. )
  962. ngram_hidden_states = [
  963. (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).repeat(batch_size, 1, 1)
  964. for ngram in range(self.ngram)
  965. ]
  966. extended_attention_mask = None
  967. extended_predict_attention_mask = None
  968. else:
  969. ngram_hidden_states = [
  970. (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) for ngram in range(self.ngram)
  971. ]
  972. extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask)
  973. extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask)
  974. # prepare encoder attention mask
  975. if encoder_attention_mask is not None:
  976. extended_encoder_attention_mask = (
  977. 1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1)
  978. ) * torch.finfo(self.dtype).min
  979. extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
  980. else:
  981. extended_encoder_attention_mask = None
  982. hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 1)
  983. if self.embeddings_layer_norm:
  984. hidden_states = self.embeddings_layer_norm(hidden_states)
  985. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  986. # init attentions, hidden_states and cache with empty tuples
  987. all_main_stream_hidden_states = () if output_hidden_states else None
  988. all_ngram_stream_hidden_states = () if output_hidden_states and self.config.ngram > 0 else None
  989. all_main_stream_attns = () if output_attentions else None
  990. all_ngram_stream_attns = () if output_attentions else None
  991. all_cross_attns = () if output_attentions and self.config.add_cross_attention else None
  992. for idx, decoder_layer in enumerate(self.layers):
  993. if output_hidden_states:
  994. # grad cannot be kept because tensor is sliced
  995. all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
  996. if self.config.ngram > 0:
  997. all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)
  998. layer_outputs = decoder_layer(
  999. hidden_states,
  1000. extended_attention_mask,
  1001. encoder_hidden_states, # as a positional argument for gradient checkpointing
  1002. encoder_attn_mask=extended_encoder_attention_mask,
  1003. extended_predict_attention_mask=extended_predict_attention_mask,
  1004. main_relative_position_buckets=main_relative_position_buckets,
  1005. predict_relative_position_buckets=predict_relative_position_buckets,
  1006. position_ids=position_ids,
  1007. past_key_values=past_key_values,
  1008. use_cache=use_cache,
  1009. output_attentions=output_attentions,
  1010. )
  1011. hidden_states = layer_outputs[0]
  1012. if output_attentions:
  1013. all_main_stream_attns += (layer_outputs[1],)
  1014. all_ngram_stream_attns += (layer_outputs[2],)
  1015. if self.config.add_cross_attention:
  1016. all_cross_attns += (layer_outputs[3],)
  1017. if output_hidden_states:
  1018. all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
  1019. if self.config.ngram > 0:
  1020. all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)
  1021. # split last_hidden_state for return
  1022. last_hidden_state = hidden_states[:, :sequence_length]
  1023. last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None
  1024. if not return_dict:
  1025. return tuple(
  1026. v
  1027. for v in [
  1028. last_hidden_state,
  1029. last_hidden_state_ngram,
  1030. past_key_values,
  1031. all_main_stream_hidden_states,
  1032. all_ngram_stream_hidden_states,
  1033. all_main_stream_attns,
  1034. all_ngram_stream_attns,
  1035. all_cross_attns,
  1036. ]
  1037. if v is not None
  1038. )
  1039. return ProphetNetDecoderModelOutput(
  1040. last_hidden_state=last_hidden_state,
  1041. last_hidden_state_ngram=last_hidden_state_ngram,
  1042. past_key_values=past_key_values,
  1043. hidden_states=all_main_stream_hidden_states,
  1044. hidden_states_ngram=all_ngram_stream_hidden_states,
  1045. attentions=all_main_stream_attns,
  1046. ngram_attentions=all_ngram_stream_attns,
  1047. cross_attentions=all_cross_attns,
  1048. )
  1049. def compute_buffered_relative_buckets(self, position_ids):
  1050. batch_size, sequence_length = position_ids.shape
  1051. position_ids = torch.arange(1, self.max_target_positions).to(position_ids.device).repeat(1, 1)
  1052. main_relative_buckets, predict_relative_buckets = compute_all_stream_relative_buckets(
  1053. self.num_buckets, self.relative_max_distance, position_ids
  1054. )
  1055. # buffer relative buckets
  1056. main_relative_buckets = main_relative_buckets[:, :sequence_length, :sequence_length].repeat(batch_size, 1, 1)
  1057. predict_relative_buckets = torch.cat(
  1058. [
  1059. predict_relative_buckets[:, :sequence_length, :sequence_length],
  1060. predict_relative_buckets[
  1061. :, :sequence_length, self.max_target_positions : self.max_target_positions + sequence_length
  1062. ],
  1063. ],
  1064. 2,
  1065. ).repeat(batch_size, 1, 1)
  1066. return main_relative_buckets, predict_relative_buckets
  1067. def prepare_attention_mask(self, hidden_states, attention_mask):
  1068. batch_size, seq_length = hidden_states.shape[:2]
  1069. # get causal mask
  1070. causal_mask = torch.full(
  1071. (seq_length, seq_length),
  1072. torch.finfo(hidden_states.dtype).min,
  1073. dtype=hidden_states.dtype,
  1074. device=hidden_states.device,
  1075. )
  1076. causal_mask = torch.triu(causal_mask, 1)
  1077. extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand(
  1078. (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape
  1079. )
  1080. # add usual attention mask
  1081. if attention_mask is not None:
  1082. extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min
  1083. extended_attention_mask = extended_causal_mask + extended_attention_mask
  1084. else:
  1085. extended_attention_mask = extended_causal_mask
  1086. return extended_attention_mask.to(hidden_states.dtype)
  1087. def prepare_predict_attention_mask(self, hidden_states, attention_mask):
  1088. batch_size, seq_length = hidden_states.shape[:2]
  1089. # get causal mask
  1090. predict_causal_mask = ngram_attention_bias(
  1091. self.max_target_positions, self.ngram, hidden_states.device, hidden_states.dtype
  1092. )
  1093. predict_causal_mask = torch.cat(
  1094. [
  1095. predict_causal_mask[:, :seq_length, :seq_length],
  1096. predict_causal_mask[
  1097. :, :seq_length, self.max_target_positions : self.max_target_positions + seq_length
  1098. ],
  1099. ],
  1100. dim=-1,
  1101. )
  1102. extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand(
  1103. (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape
  1104. )
  1105. # add usual attention mask
  1106. if attention_mask is not None:
  1107. extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min
  1108. extended_attention_mask = extended_attention_mask.expand(
  1109. (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length)
  1110. )
  1111. # predicted stream attention_mask should always be 0
  1112. extended_attention_mask = torch.cat(
  1113. [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1
  1114. )
  1115. extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask
  1116. else:
  1117. extended_predict_attention_mask = extended_predict_causal_mask
  1118. return extended_predict_attention_mask.to(hidden_states.dtype)
  1119. @auto_docstring
  1120. class ProphetNetModel(ProphetNetPreTrainedModel):
  1121. _tied_weights_keys = {
  1122. "encoder.word_embeddings.weight": "word_embeddings.weight",
  1123. "decoder.word_embeddings.weight": "word_embeddings.weight",
  1124. }
  1125. def __init__(self, config: ProphetNetConfig):
  1126. super().__init__(config)
  1127. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  1128. encoder_config = copy.deepcopy(config)
  1129. encoder_config.use_cache = False
  1130. self.encoder = ProphetNetEncoder(encoder_config)
  1131. decoder_config = copy.deepcopy(config)
  1132. decoder_config.is_decoder = True
  1133. self.decoder = ProphetNetDecoder(decoder_config)
  1134. # Initialize weights and apply final processing
  1135. self.post_init()
  1136. def get_input_embeddings(self):
  1137. return self.word_embeddings
  1138. def set_input_embeddings(self, value):
  1139. self.word_embeddings = value
  1140. self.encoder.word_embeddings = self.word_embeddings
  1141. self.decoder.word_embeddings = self.word_embeddings
  1142. @auto_docstring
  1143. def forward(
  1144. self,
  1145. input_ids: torch.Tensor | None = None,
  1146. attention_mask: torch.Tensor | None = None,
  1147. decoder_input_ids: torch.Tensor | None = None,
  1148. decoder_attention_mask: torch.BoolTensor | None = None,
  1149. encoder_outputs: tuple | None = None,
  1150. past_key_values: Cache | None = None,
  1151. inputs_embeds: torch.Tensor | None = None,
  1152. decoder_inputs_embeds: torch.Tensor | None = None,
  1153. use_cache: bool | None = None,
  1154. output_attentions: bool | None = None,
  1155. output_hidden_states: bool | None = None,
  1156. return_dict: bool | None = None,
  1157. **kwargs,
  1158. ) -> tuple | ProphetNetSeq2SeqModelOutput:
  1159. r"""
  1160. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1161. Indices of decoder input sequence tokens in the vocabulary.
  1162. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1163. [`PreTrainedTokenizer.__call__`] for details.
  1164. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1165. ProphetNet uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
  1166. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1167. `past_key_values`).
  1168. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1169. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1170. be used by default.
  1171. Example:
  1172. ```python
  1173. >>> from transformers import AutoTokenizer, ProphetNetModel
  1174. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  1175. >>> model = ProphetNetModel.from_pretrained("microsoft/prophetnet-large-uncased")
  1176. >>> input_ids = tokenizer(
  1177. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1178. ... ).input_ids # Batch size 1
  1179. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  1180. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  1181. >>> last_hidden_states = outputs.last_hidden_state # main stream hidden states
  1182. >>> last_hidden_states_ngram = outputs.last_hidden_state_ngram # predict hidden states
  1183. ```"""
  1184. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1185. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1186. output_hidden_states = (
  1187. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1188. )
  1189. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1190. if encoder_outputs is None:
  1191. encoder_outputs = self.encoder(
  1192. input_ids=input_ids,
  1193. attention_mask=attention_mask,
  1194. inputs_embeds=inputs_embeds,
  1195. output_attentions=output_attentions,
  1196. output_hidden_states=output_hidden_states,
  1197. return_dict=return_dict,
  1198. )
  1199. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  1200. decoder_outputs = self.decoder(
  1201. input_ids=decoder_input_ids,
  1202. attention_mask=decoder_attention_mask,
  1203. encoder_hidden_states=encoder_outputs[0],
  1204. encoder_attention_mask=attention_mask,
  1205. past_key_values=past_key_values,
  1206. inputs_embeds=decoder_inputs_embeds,
  1207. output_attentions=output_attentions,
  1208. output_hidden_states=output_hidden_states,
  1209. use_cache=use_cache,
  1210. return_dict=return_dict,
  1211. )
  1212. if not return_dict:
  1213. return decoder_outputs + encoder_outputs
  1214. return ProphetNetSeq2SeqModelOutput(
  1215. last_hidden_state=decoder_outputs.last_hidden_state,
  1216. last_hidden_state_ngram=decoder_outputs.last_hidden_state_ngram,
  1217. past_key_values=decoder_outputs.past_key_values,
  1218. decoder_hidden_states=decoder_outputs.hidden_states,
  1219. decoder_ngram_hidden_states=decoder_outputs.hidden_states_ngram,
  1220. decoder_attentions=decoder_outputs.attentions,
  1221. decoder_ngram_attentions=decoder_outputs.ngram_attentions,
  1222. cross_attentions=decoder_outputs.cross_attentions,
  1223. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1224. encoder_hidden_states=encoder_outputs.hidden_states,
  1225. encoder_attentions=encoder_outputs.attentions,
  1226. )
  1227. @auto_docstring(
  1228. custom_intro="""
  1229. The ProphetNet Model with a language modeling head. Can be used for sequence generation tasks.
  1230. """
  1231. )
  1232. class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel, GenerationMixin):
  1233. _tied_weights_keys = {
  1234. "lm_head.weight": "prophetnet.word_embeddings.weight",
  1235. }
  1236. def __init__(self, config: ProphetNetConfig):
  1237. super().__init__(config)
  1238. self.prophetnet = ProphetNetModel(config)
  1239. self.padding_idx = config.pad_token_id
  1240. self.disable_ngram_loss = config.disable_ngram_loss
  1241. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1242. # Initialize weights and apply final processing
  1243. self.post_init()
  1244. def get_input_embeddings(self):
  1245. return self.prophetnet.word_embeddings
  1246. @auto_docstring
  1247. def forward(
  1248. self,
  1249. input_ids: torch.Tensor | None = None,
  1250. attention_mask: torch.Tensor | None = None,
  1251. decoder_input_ids: torch.Tensor | None = None,
  1252. decoder_attention_mask: torch.BoolTensor | None = None,
  1253. encoder_outputs: torch.Tensor | None = None,
  1254. past_key_values: Cache | None = None,
  1255. inputs_embeds: torch.Tensor | None = None,
  1256. decoder_inputs_embeds: torch.Tensor | None = None,
  1257. labels: torch.Tensor | None = None,
  1258. use_cache: bool | None = None,
  1259. output_attentions: bool | None = None,
  1260. output_hidden_states: bool | None = None,
  1261. return_dict: bool | None = None,
  1262. **kwargs,
  1263. ) -> tuple | ProphetNetSeq2SeqLMOutput:
  1264. r"""
  1265. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1266. Indices of decoder input sequence tokens in the vocabulary.
  1267. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1268. [`PreTrainedTokenizer.__call__`] for details.
  1269. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1270. ProphetNet uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
  1271. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1272. `past_key_values`).
  1273. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1274. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1275. be used by default.
  1276. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1277. Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
  1278. config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
  1279. labels in `[0, ..., config.vocab_size]`
  1280. Example:
  1281. ```python
  1282. >>> from transformers import AutoTokenizer, ProphetNetForConditionalGeneration
  1283. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  1284. >>> model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")
  1285. >>> input_ids = tokenizer(
  1286. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1287. ... ).input_ids # Batch size 1
  1288. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  1289. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  1290. >>> logits_next_token = outputs.logits # logits to predict next token as usual
  1291. >>> logits_ngram_next_tokens = outputs.logits_ngram # logits to predict 2nd, 3rd, ... next tokens
  1292. ```"""
  1293. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1294. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  1295. # get decoder inputs from shifting lm labels to the right
  1296. decoder_input_ids = self._shift_right(labels)
  1297. outputs = self.prophetnet(
  1298. input_ids=input_ids,
  1299. attention_mask=attention_mask,
  1300. decoder_input_ids=decoder_input_ids,
  1301. decoder_attention_mask=decoder_attention_mask,
  1302. encoder_outputs=encoder_outputs,
  1303. past_key_values=past_key_values,
  1304. inputs_embeds=inputs_embeds,
  1305. decoder_inputs_embeds=decoder_inputs_embeds,
  1306. use_cache=use_cache,
  1307. output_attentions=output_attentions,
  1308. output_hidden_states=output_hidden_states,
  1309. return_dict=return_dict,
  1310. )
  1311. batch_size, sequence_length = (
  1312. decoder_input_ids.shape if decoder_input_ids is not None else decoder_inputs_embeds.shape[:2]
  1313. )
  1314. predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1)
  1315. predict_logits = self.lm_head(predicting_streams)
  1316. logits = predict_logits[:, 0]
  1317. logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None
  1318. # To use .view in loss computation, make sure that logits is contiguous.
  1319. if not logits.is_contiguous():
  1320. logits = logits.contiguous()
  1321. loss = None
  1322. if labels is not None:
  1323. loss = self._compute_loss(predict_logits, labels)
  1324. if not return_dict:
  1325. all_logits = tuple(v for v in [logits, logits_ngram] if v is not None)
  1326. return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:]
  1327. else:
  1328. return ProphetNetSeq2SeqLMOutput(
  1329. loss=loss,
  1330. logits=logits,
  1331. logits_ngram=logits_ngram,
  1332. past_key_values=outputs.past_key_values,
  1333. decoder_hidden_states=outputs.decoder_hidden_states,
  1334. decoder_ngram_hidden_states=outputs.decoder_ngram_hidden_states,
  1335. decoder_attentions=outputs.decoder_attentions,
  1336. decoder_ngram_attentions=outputs.decoder_ngram_attentions,
  1337. cross_attentions=outputs.cross_attentions,
  1338. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1339. encoder_hidden_states=outputs.encoder_hidden_states,
  1340. encoder_attentions=outputs.encoder_attentions,
  1341. )
  1342. def _compute_loss(self, logits, labels, ignore_index=-100):
  1343. expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index)
  1344. for i in range(self.config.ngram):
  1345. if i > 0 and self.disable_ngram_loss:
  1346. break
  1347. expend_targets[i, :, :] = labels
  1348. logits = logits.transpose(0, 1).contiguous()
  1349. lprobs = nn.functional.log_softmax(
  1350. logits.view(-1, logits.size(-1)),
  1351. dim=-1,
  1352. dtype=torch.float32,
  1353. )
  1354. loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")
  1355. if self.config.eps > 0.0:
  1356. smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
  1357. non_masked_tokens = expend_targets.ne(ignore_index).view(-1)
  1358. smooth_loss = smooth_loss[non_masked_tokens]
  1359. smooth_loss = smooth_loss.mean()
  1360. eps_i = self.config.eps / lprobs.size(-1)
  1361. loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss
  1362. return loss
  1363. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1364. return self._shift_right(labels)
  1365. def get_encoder(self, modality=None):
  1366. if modality is None:
  1367. return self.prophetnet.encoder
  1368. else:
  1369. return super().get_encoder(modality=modality)
  1370. @auto_docstring(
  1371. custom_intro="""
  1372. The standalone decoder part of the ProphetNetModel with a lm head on top. The model can be used for causal
  1373. """
  1374. )
  1375. class ProphetNetForCausalLM(ProphetNetPreTrainedModel, GenerationMixin):
  1376. _tied_weights_keys = {
  1377. "lm_head.weight": "prophetnet.word_embeddings.weight",
  1378. "prophetnet.decoder.word_embeddings.weight": "prophetnet.word_embeddings.weight",
  1379. }
  1380. def __init__(self, config: ProphetNetConfig):
  1381. # set config for CLM
  1382. config = copy.deepcopy(config)
  1383. config.is_decoder = True
  1384. config.is_encoder_decoder = False
  1385. super().__init__(config)
  1386. self.prophetnet = ProphetNetDecoderWrapper(config)
  1387. self.padding_idx = config.pad_token_id
  1388. self.disable_ngram_loss = config.disable_ngram_loss
  1389. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1390. # Initialize weights and apply final processing
  1391. self.post_init()
  1392. def get_input_embeddings(self):
  1393. return self.prophetnet.decoder.word_embeddings
  1394. def set_input_embeddings(self, value):
  1395. self.prophetnet.decoder.word_embeddings = value
  1396. @auto_docstring
  1397. def forward(
  1398. self,
  1399. input_ids: torch.Tensor | None = None,
  1400. attention_mask: torch.Tensor | None = None,
  1401. encoder_hidden_states: torch.Tensor | None = None,
  1402. encoder_attention_mask: torch.Tensor | None = None,
  1403. past_key_values: Cache | None = None,
  1404. inputs_embeds: torch.Tensor | None = None,
  1405. labels: torch.Tensor | None = None,
  1406. use_cache: bool | None = None,
  1407. output_attentions: bool | None = None,
  1408. output_hidden_states: bool | None = None,
  1409. return_dict: bool | None = None,
  1410. **kwargs,
  1411. ) -> tuple | ProphetNetDecoderLMOutput:
  1412. r"""
  1413. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1414. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  1415. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  1416. ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
  1417. Example:
  1418. ```python
  1419. >>> from transformers import AutoTokenizer, ProphetNetForCausalLM
  1420. >>> import torch
  1421. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  1422. >>> model = ProphetNetForCausalLM.from_pretrained("microsoft/prophetnet-large-uncased")
  1423. >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
  1424. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1425. >>> outputs = model(**inputs)
  1426. >>> logits = outputs.logits
  1427. >>> # Model can also be used with EncoderDecoder framework
  1428. >>> from transformers import BertTokenizer, EncoderDecoderModel, AutoTokenizer
  1429. >>> import torch
  1430. >>> tokenizer_enc = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")
  1431. >>> tokenizer_dec = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  1432. >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
  1433. ... "google-bert/bert-large-uncased", "microsoft/prophetnet-large-uncased"
  1434. ... )
  1435. >>> ARTICLE = (
  1436. ... "the us state department said wednesday it had received no "
  1437. ... "formal word from bolivia that it was expelling the us ambassador there "
  1438. ... "but said the charges made against him are `` baseless ."
  1439. ... )
  1440. >>> input_ids = tokenizer_enc(ARTICLE, return_tensors="pt").input_ids
  1441. >>> labels = tokenizer_dec(
  1442. ... "us rejects charges against its ambassador in bolivia", return_tensors="pt"
  1443. ... ).input_ids
  1444. >>> outputs = model(input_ids=input_ids, decoder_input_ids=labels[:, :-1], labels=labels[:, 1:])
  1445. >>> loss = outputs.loss
  1446. ```"""
  1447. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1448. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  1449. outputs = self.prophetnet.decoder(
  1450. input_ids=input_ids,
  1451. attention_mask=attention_mask,
  1452. encoder_hidden_states=encoder_hidden_states,
  1453. encoder_attention_mask=encoder_attention_mask,
  1454. past_key_values=past_key_values,
  1455. inputs_embeds=inputs_embeds,
  1456. use_cache=use_cache,
  1457. output_attentions=output_attentions,
  1458. output_hidden_states=output_hidden_states,
  1459. return_dict=return_dict,
  1460. )
  1461. batch_size, sequence_length = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2]
  1462. predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1)
  1463. predict_logits = self.lm_head(predicting_streams)
  1464. logits = predict_logits[:, 0]
  1465. logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None
  1466. loss = None
  1467. if labels is not None:
  1468. loss = self._compute_loss(predict_logits, labels)
  1469. if not return_dict:
  1470. all_logits = tuple(v for v in [logits, logits_ngram] if v is not None)
  1471. return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:]
  1472. else:
  1473. return ProphetNetDecoderLMOutput(
  1474. loss=loss,
  1475. logits=logits,
  1476. logits_ngram=logits_ngram,
  1477. past_key_values=outputs.past_key_values,
  1478. hidden_states=outputs.hidden_states,
  1479. hidden_states_ngram=outputs.hidden_states_ngram,
  1480. attentions=outputs.attentions,
  1481. ngram_attentions=outputs.ngram_attentions,
  1482. cross_attentions=outputs.cross_attentions,
  1483. )
  1484. def _compute_loss(self, logits, labels, ignore_index=-100):
  1485. expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index)
  1486. for i in range(self.config.ngram):
  1487. if i > 0 and self.disable_ngram_loss:
  1488. break
  1489. expend_targets[i, :, :] = labels
  1490. logits = logits.transpose(0, 1).contiguous()
  1491. lprobs = nn.functional.log_softmax(
  1492. logits.view(-1, logits.size(-1)),
  1493. dim=-1,
  1494. dtype=torch.float32,
  1495. )
  1496. loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")
  1497. if self.config.eps > 0.0:
  1498. smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
  1499. non_masked_tokens = expend_targets.ne(ignore_index).view(-1)
  1500. smooth_loss = smooth_loss[non_masked_tokens]
  1501. smooth_loss = smooth_loss.mean()
  1502. eps_i = self.config.eps / lprobs.size(-1)
  1503. loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss
  1504. return loss
  1505. class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
  1506. """
  1507. This is a wrapper class, so that [`ProphetNetForCausalLM`] can correctly be loaded from pretrained prophetnet
  1508. classes.
  1509. """
  1510. _tied_weights_keys = {
  1511. "decoder.word_embeddings.weight": "word_embeddings.weight",
  1512. }
  1513. def __init__(self, config: ProphetNetConfig):
  1514. super().__init__(config)
  1515. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  1516. self.decoder = ProphetNetDecoder(config)
  1517. # Initialize weights and apply final processing
  1518. self.post_init()
  1519. def forward(self, *args, **kwargs):
  1520. return self.decoder(*args, **kwargs)
# Names exported by `from <this module> import *`.
__all__ = [
    "ProphetNetDecoder",
    "ProphetNetEncoder",
    "ProphetNetForCausalLM",
    "ProphetNetForConditionalGeneration",
    "ProphetNetModel",
    "ProphetNetPreTrainedModel",
]