modeling_led.py 104 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202
  1. # Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch LED model."""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from torch.nn import CrossEntropyLoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
  27. from ...modeling_utils import PreTrainedModel
  28. from ...utils import ModelOutput, auto_docstring, logging
  29. from .configuration_led import LEDConfig
# Module-level logger for this file, obtained from the library's `logging` utilities.
logger = logging.get_logger(__name__)
  31. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  32. """
  33. Shift input ids one token to the right.
  34. """
  35. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  36. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  37. shifted_input_ids[:, 0] = decoder_start_token_id
  38. if pad_token_id is None:
  39. raise ValueError("config.pad_token_id has to be defined.")
  40. # replace possible -100 values in labels by `pad_token_id`
  41. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  42. return shifted_input_ids
  43. def _prepare_4d_attention_mask_inverted(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
  44. """
  45. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
  46. """
  47. bsz, src_len = mask.size()
  48. tgt_len = tgt_len if tgt_len is not None else src_len
  49. expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
  50. inverted_mask = 1.0 - expanded_mask
  51. expanded_attention_mask = inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
  52. # make sure that global_attn_mask is positive
  53. expanded_attention_mask = expanded_attention_mask * inverted_mask
  54. return expanded_attention_mask
  55. class LEDLearnedPositionalEmbedding(nn.Embedding):
  56. """
  57. This module learns positional embeddings up to a fixed maximum size.
  58. """
  59. def __init__(self, num_embeddings: int, embedding_dim: int):
  60. super().__init__(num_embeddings, embedding_dim)
  61. def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
  62. """`input_ids_shape` is expected to be [bsz x seqlen]."""
  63. bsz, seq_len = input_ids_shape[:2]
  64. positions = torch.arange(
  65. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  66. )
  67. return super().forward(positions)
  68. # Copied from transformers.models.longformer.modeling_longformer.LongformerSelfAttention with Longformer->LEDEncoder
  69. class LEDEncoderSelfAttention(nn.Module):
  70. def __init__(self, config, layer_id):
  71. super().__init__()
  72. if config.hidden_size % config.num_attention_heads != 0:
  73. raise ValueError(
  74. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  75. f"heads ({config.num_attention_heads})"
  76. )
  77. self.num_heads = config.num_attention_heads
  78. self.head_dim = int(config.hidden_size / config.num_attention_heads)
  79. self.embed_dim = config.hidden_size
  80. self.query = nn.Linear(config.hidden_size, self.embed_dim)
  81. self.key = nn.Linear(config.hidden_size, self.embed_dim)
  82. self.value = nn.Linear(config.hidden_size, self.embed_dim)
  83. # separate projection layers for tokens with global attention
  84. self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
  85. self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
  86. self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
  87. self.dropout = config.attention_probs_dropout_prob
  88. self.layer_id = layer_id
  89. attention_window = config.attention_window[self.layer_id]
  90. assert attention_window % 2 == 0, (
  91. f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
  92. )
  93. assert attention_window > 0, (
  94. f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"
  95. )
  96. self.one_sided_attn_window_size = attention_window // 2
  97. self.config = config
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        is_index_masked=None,
        is_index_global_attn=None,
        is_global_attn=None,
        output_attentions=False,
    ):
        """
        Sliding-window (local) self-attention, optionally combined with global attention.

        [`LEDEncoderSelfAttention`] expects *len(hidden_states)* to be multiple of *attention_window*. Padding to
        *attention_window* happens in [`LEDEncoderModel.forward`] to avoid redoing the padding on each layer.

        The *attention_mask* is changed in [`LEDEncoderModel.forward`] from 0, 1, 2 to:

            - -10000: no attention
            - 0: local attention
            - +10000: global attention

        Inside this method only `attention_mask != 0` is used. Every nonzero position is removed from the
        sliding-window scores; under the encoding above that covers both padding and global tokens. Global tokens
        are then handled by the separate global-attention path.

        Args:
            hidden_states: after `transpose(0, 1)` it is unpacked as `(seq_len, batch_size, embed_dim)`, so the
                input is `(batch_size, seq_len, embed_dim)`.
            attention_mask: indexed as `[:, :, None, None]`; presumably `(batch_size, seq_len)` — confirm
                against the caller.
            is_index_masked: boolean mask (same indexing as `attention_mask`) of positions whose attention
                probabilities are forced to 0.
            is_index_global_attn: boolean mask of positions that use global attention.
            is_global_attn: truthy when any position uses global attention; enables the global path.
            output_attentions: whether to append attention probabilities to the returned tuple.

        Returns:
            A tuple `(attn_output,)` with `attn_output` of shape `(batch_size, seq_len, embed_dim)`.
            When `output_attentions` is set, the local attention probabilities follow. These have shape
            `(batch_size, seq_len, num_heads, [max_num_global_attn_indices +] 2 * window + 1)`, with the
            global-token rows zeroed. When both `is_global_attn` and `output_attentions` are set, the global
            attention probabilities are appended as well.
        """
        # switch to sequence-first layout: (seq_len, batch_size, embed_dim)
        hidden_states = hidden_states.transpose(0, 1)

        # project hidden states
        query_vectors = self.query(hidden_states)
        key_vectors = self.key(hidden_states)
        value_vectors = self.value(hidden_states)

        seq_len, batch_size, embed_dim = hidden_states.size()
        assert embed_dim == self.embed_dim, (
            f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"
        )

        # normalize query (scaled dot-product attention), done in place
        query_vectors /= math.sqrt(self.head_dim)

        # split heads and move batch first: (batch_size, seq_len, num_heads, head_dim)
        query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
        key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

        # banded scores: (batch_size, seq_len, num_heads, 2 * one_sided_window + 1)
        attn_scores = self._sliding_chunks_query_key_matmul(
            query_vectors, key_vectors, self.one_sided_attn_window_size
        )

        # values to pad for attention probs
        # (batch_size, seq_len, 1, 1): True for every position that must not be attended to locally
        remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]

        # cast to fp32/fp16 then replace True entries with the dtype's most negative finite value
        float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
            remove_from_windowed_attention_mask, torch.finfo(query_vectors.dtype).min
        )
        # diagonal mask with zeros everywhere and a very negative value in place of padding.
        # The mask is treated as a single head with head_dim 1 (query of ones, key = mask), so the result
        # broadcasts over num_heads when added below.
        diagonal_mask = self._sliding_chunks_query_key_matmul(
            float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
        )

        # pad local attention probs
        attn_scores += diagonal_mask

        assert list(attn_scores.size()) == [
            batch_size,
            seq_len,
            self.num_heads,
            self.one_sided_attn_window_size * 2 + 1,
        ], (
            f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
            f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
        )

        # compute local attention probs from global attention keys and concatenate over window dim
        if is_global_attn:
            # compute global attn indices required throughout forward fn
            (
                max_num_global_attn_indices,
                is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero,
            ) = self._get_global_attn_indices(is_index_global_attn)
            # calculate global attn probs from global key
            global_key_attn_scores = self._concat_with_global_key_attn_probs(
                query_vectors=query_vectors,
                key_vectors=key_vectors,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
            )
            # concat to local_attn_probs
            # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
            attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)

            # free memory
            del global_key_attn_scores

        attn_probs = nn.functional.softmax(
            attn_scores, dim=-1, dtype=torch.float32
        )  # use fp32 for numerical stability

        # softmax sometimes inserts NaN if all positions are masked, replace them with 0
        attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
        # cast back from fp32 to the dtype of the scores
        attn_probs = attn_probs.type_as(attn_scores)

        # free memory
        del attn_scores

        # apply dropout
        attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training)

        # (batch_size, seq_len, num_heads, head_dim)
        value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

        # compute local attention output with global attention value and add
        if is_global_attn:
            # compute sum of global and local attn
            attn_output = self._compute_attn_output_with_global_indices(
                value_vectors=value_vectors,
                attn_probs=attn_probs,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            )
        else:
            # compute local attn only
            attn_output = self._sliding_chunks_matmul_attn_probs_value(
                attn_probs, value_vectors, self.one_sided_attn_window_size
            )

        assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
        # merge heads back into a sequence-first layout: (seq_len, batch_size, embed_dim)
        attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()

        # compute value for global attention and overwrite to attention output
        # TODO: remove the redundant computation
        if is_global_attn:
            global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
                hidden_states=hidden_states,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
                is_index_masked=is_index_masked,
            )

            # get only non zero global attn output
            nonzero_global_attn_output = global_attn_output[
                is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]
            ]

            # overwrite values with global attention
            # `is_index_global_attn_nonzero` is (batch indices, position indices); reversing it matches the
            # sequence-first layout of `attn_output`.
            attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
                len(is_local_index_global_attn_nonzero[0]), -1
            )
            # The attention weights for tokens with global attention are
            # just filler values, they were never used to compute the output.
            # Fill with 0 now, the correct values are in 'global_attn_probs'.
            attn_probs[is_index_global_attn_nonzero] = 0

        # back to batch-first: (batch_size, seq_len, embed_dim)
        outputs = (attn_output.transpose(0, 1),)

        if output_attentions:
            outputs += (attn_probs,)

        return outputs + (global_attn_probs,) if (is_global_attn and output_attentions) else outputs
  230. @staticmethod
  231. def _pad_and_transpose_last_two_dims(hidden_states_padded, padding):
  232. """pads rows and then flips rows and columns"""
  233. hidden_states_padded = nn.functional.pad(
  234. hidden_states_padded, padding
  235. ) # padding value is not important because it will be overwritten
  236. hidden_states_padded = hidden_states_padded.view(
  237. *hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2)
  238. )
  239. return hidden_states_padded
  240. @staticmethod
  241. def _pad_and_diagonalize(chunked_hidden_states):
  242. """
  243. shift every row 1 step right, converting columns into diagonals.
  244. Example:
  245. ```python
  246. chunked_hidden_states: [
  247. 0.4983,
  248. 2.6918,
  249. -0.0071,
  250. 1.0492,
  251. -1.8348,
  252. 0.7672,
  253. 0.2986,
  254. 0.0285,
  255. -0.7584,
  256. 0.4206,
  257. -0.0405,
  258. 0.1599,
  259. 2.0514,
  260. -1.1600,
  261. 0.5372,
  262. 0.2629,
  263. ]
  264. window_overlap = num_rows = 4
  265. ```
  266. (pad & diagonalize) => [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000
  267. 0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000 0.0000, 0.0000, -0.7584, 0.4206,
  268. -0.0405, 0.1599, 0.0000 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
  269. """
  270. total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
  271. chunked_hidden_states = nn.functional.pad(
  272. chunked_hidden_states, (0, window_overlap + 1)
  273. ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
  274. chunked_hidden_states = chunked_hidden_states.view(
  275. total_num_heads, num_chunks, -1
  276. ) # total_num_heads x num_chunks x window_overlap*window_overlap+window_overlap
  277. chunked_hidden_states = chunked_hidden_states[
  278. :, :, :-window_overlap
  279. ] # total_num_heads x num_chunks x window_overlap*window_overlap
  280. chunked_hidden_states = chunked_hidden_states.view(
  281. total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim
  282. )
  283. chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
  284. return chunked_hidden_states
  285. @staticmethod
  286. def _chunk(hidden_states, window_overlap, onnx_export: bool = False):
  287. """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
  288. if not onnx_export:
  289. # non-overlapping chunks of size = 2w
  290. hidden_states = hidden_states.view(
  291. hidden_states.size(0),
  292. torch.div(hidden_states.size(1), (window_overlap * 2), rounding_mode="trunc"),
  293. window_overlap * 2,
  294. hidden_states.size(2),
  295. )
  296. # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
  297. chunk_size = list(hidden_states.size())
  298. chunk_size[1] = chunk_size[1] * 2 - 1
  299. chunk_stride = list(hidden_states.stride())
  300. chunk_stride[1] = chunk_stride[1] // 2
  301. return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
  302. # When exporting to ONNX, use this separate logic
  303. # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
  304. # TODO replace this with
  305. # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
  306. # once `unfold` is supported
  307. # the case hidden_states.size(1) == window_overlap * 2 can also simply return hidden_states.unsqueeze(1), but that's control flow
  308. chunk_size = [
  309. hidden_states.size(0),
  310. torch.div(hidden_states.size(1), window_overlap, rounding_mode="trunc") - 1,
  311. window_overlap * 2,
  312. hidden_states.size(2),
  313. ]
  314. overlapping_chunks = torch.empty(chunk_size, device=hidden_states.device)
  315. for chunk in range(chunk_size[1]):
  316. overlapping_chunks[:, chunk, :, :] = hidden_states[
  317. :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
  318. ]
  319. return overlapping_chunks
  320. @staticmethod
  321. def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor:
  322. beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
  323. beginning_mask = beginning_mask_2d[None, :, None, :]
  324. ending_mask = beginning_mask.flip(dims=(1, 3))
  325. beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
  326. beginning_mask = beginning_mask.expand(beginning_input.size())
  327. input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] = torch.full_like(
  328. beginning_input, -float("inf")
  329. ).where(beginning_mask.bool(), beginning_input)
  330. ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
  331. ending_mask = ending_mask.expand(ending_input.size())
  332. input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] = torch.full_like(
  333. ending_input, -float("inf")
  334. ).where(ending_mask.bool(), ending_input)
    def _sliding_chunks_query_key_matmul(
        self, query: torch.Tensor, key: torch.Tensor, window_overlap: int
    ) -> torch.Tensor:
        """
        Matrix multiplication of query and key tensors using a sliding window attention pattern. This
        implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained LEDEncoder) with an
        overlap of size window_overlap.

        Args:
            query: unpacked below as `(batch_size, seq_len, num_heads, head_dim)`. `seq_len` must be a multiple of
                `2 * window_overlap`.
            key: same shape as `query`.
            window_overlap: one-sided window size `w`.

        Returns:
            Scores of shape `(batch_size, seq_len, num_heads, 2 * window_overlap + 1)`:
                - the first `window_overlap` columns are the `window_overlap` preceding positions;
                - column `window_overlap` is each position with itself;
                - the last `window_overlap` columns are the following positions.
            Corner entries of the first and last `window_overlap` rows are set to `-inf` by
            `_mask_invalid_locations`.
        """
        batch_size, seq_len, num_heads, head_dim = query.size()
        assert seq_len % (window_overlap * 2) == 0, (
            f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
        )
        assert query.size() == key.size()

        chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1

        # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
        query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
        key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)

        # `onnx_export` is read from the config if present (default False) to pick the ONNX-friendly chunking path
        query = self._chunk(query, window_overlap, getattr(self.config, "onnx_export", False))
        key = self._chunk(key, window_overlap, getattr(self.config, "onnx_export", False))

        # matrix multiplication
        # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
        # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
        # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
        diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key))  # multiply

        # convert diagonals into columns
        # pad one row, then view as (..., 2window_overlap, 2window_overlap + 1)
        diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
            diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)
        )

        # allocate space for the overall attention matrix where the chunks are combined. The last dimension
        # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to
        # window_overlap previous words). The following column is attention score from each word to itself, then
        # followed by window_overlap columns for the upper triangle.
        # dim 1 has chunks_count + 1 == seq_len // window_overlap blocks of window_overlap rows each.

        diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros(
            (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
        )

        # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
        # - copying the main diagonal and the upper triangle
        diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
            :, :, :window_overlap, : window_overlap + 1
        ]
        # the last block comes from the second half of the last chunk
        diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
            :, -1, window_overlap:, : window_overlap + 1
        ]
        # - copying the lower triangle
        diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
            :, :, -(window_overlap + 1) : -1, window_overlap + 1 :
        ]

        # the first block has no preceding chunk, so only its strictly-lower part can be filled
        diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
            :, 0, : window_overlap - 1, 1 - window_overlap :
        ]

        # separate batch_size and num_heads dimensions again
        diagonal_attention_scores = diagonal_attention_scores.view(
            batch_size, num_heads, seq_len, 2 * window_overlap + 1
        ).transpose(2, 1)

        # masks the out-of-sequence corners in place (through the transposed view)
        self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
        return diagonal_attention_scores
    def _sliding_chunks_matmul_attn_probs_value(
        self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int
    ) -> torch.Tensor:
        """
        Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. It multiplies banded attention
        probabilities with the value vectors.

        Args:
            attn_probs: `(batch_size, seq_len, num_heads, 2 * window_overlap + 1)`, as checked by the asserts below.
            value: `(batch_size, seq_len, num_heads, head_dim)`. `seq_len` must be a multiple of
                `2 * window_overlap`.
            window_overlap: one-sided window size `w`.

        Returns:
            A tensor of shape `(batch_size, seq_len, num_heads, head_dim)`, i.e. the same shape as `value` (not
            `attn_probs`).
        """
        batch_size, seq_len, num_heads, head_dim = value.size()

        assert seq_len % (window_overlap * 2) == 0
        assert attn_probs.size()[:3] == value.size()[:3]
        assert attn_probs.size(3) == 2 * window_overlap + 1
        chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
        # group batch_size and num_heads dimensions into one, then split seq_len into non-overlapping blocks of
        # window_overlap rows: (batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1)

        chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
            batch_size * num_heads,
            torch.div(seq_len, window_overlap, rounding_mode="trunc"),
            window_overlap,
            2 * window_overlap + 1,
        )

        # group batch_size and num_heads dimensions into one
        value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)

        # pad seq_len with w at the beginning of the sequence and another window overlap at the end
        # NOTE(review): the pad value -1 presumably never contributes, because the matching probabilities come
        # from band entries masked with -inf upstream (softmax -> 0) — not verified in this method.
        padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)

        # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
        # (chunk c covers padded positions c * window_overlap .. c * window_overlap + 3 * window_overlap - 1)
        chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
        chunked_value_stride = padded_value.stride()
        chunked_value_stride = (
            chunked_value_stride[0],
            window_overlap * chunked_value_stride[1],
            chunked_value_stride[1],
            chunked_value_stride[2],
        )
        chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)

        # shift row i of each block right by i so the band lines up with the 3 * window_overlap value rows
        chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)

        context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
        return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
  425. @staticmethod
  426. def _get_global_attn_indices(is_index_global_attn):
  427. """compute global attn indices required throughout forward pass"""
  428. # helper variable
  429. num_global_attn_indices = is_index_global_attn.long().sum(dim=1)
  430. # max number of global attn indices in batch
  431. max_num_global_attn_indices = num_global_attn_indices.max()
  432. # indices of global attn
  433. is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True)
  434. # helper variable
  435. is_local_index_global_attn = torch.arange(
  436. max_num_global_attn_indices, device=is_index_global_attn.device
  437. ) < num_global_attn_indices.unsqueeze(dim=-1)
  438. # location of the non-padding values within global attention indices
  439. is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True)
  440. # location of the padding values within global attention indices
  441. is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True)
  442. return (
  443. max_num_global_attn_indices,
  444. is_index_global_attn_nonzero,
  445. is_local_index_global_attn_nonzero,
  446. is_local_index_no_global_attn_nonzero,
  447. )
  448. def _concat_with_global_key_attn_probs(
  449. self,
  450. key_vectors,
  451. query_vectors,
  452. max_num_global_attn_indices,
  453. is_index_global_attn_nonzero,
  454. is_local_index_global_attn_nonzero,
  455. is_local_index_no_global_attn_nonzero,
  456. ):
  457. batch_size = key_vectors.shape[0]
  458. # create only global key vectors
  459. key_vectors_only_global = key_vectors.new_zeros(
  460. batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
  461. )
  462. key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero]
  463. # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
  464. attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))
  465. # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
  466. attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)
  467. attn_probs_from_global_key[
  468. is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
  469. ] = torch.finfo(attn_probs_from_global_key.dtype).min
  470. attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)
  471. return attn_probs_from_global_key
  472. def _compute_attn_output_with_global_indices(
  473. self,
  474. value_vectors,
  475. attn_probs,
  476. max_num_global_attn_indices,
  477. is_index_global_attn_nonzero,
  478. is_local_index_global_attn_nonzero,
  479. ):
  480. batch_size = attn_probs.shape[0]
  481. # cut local attn probs to global only
  482. attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices)
  483. # get value vectors for global only
  484. value_vectors_only_global = value_vectors.new_zeros(
  485. batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
  486. )
  487. value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero]
  488. # use `matmul` because `einsum` crashes sometimes with fp16
  489. # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
  490. # compute attn output only global
  491. attn_output_only_global = torch.matmul(
  492. attn_probs_only_global.transpose(1, 2).clone(), value_vectors_only_global.transpose(1, 2).clone()
  493. ).transpose(1, 2)
  494. # reshape attn probs
  495. attn_probs_without_global = attn_probs.narrow(
  496. -1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices
  497. ).contiguous()
  498. # compute attn output with global
  499. attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
  500. attn_probs_without_global, value_vectors, self.one_sided_attn_window_size
  501. )
  502. return attn_output_only_global + attn_output_without_global
  503. def _compute_global_attn_output_from_hidden(
  504. self,
  505. hidden_states,
  506. max_num_global_attn_indices,
  507. is_local_index_global_attn_nonzero,
  508. is_index_global_attn_nonzero,
  509. is_local_index_no_global_attn_nonzero,
  510. is_index_masked,
  511. ):
  512. seq_len, batch_size = hidden_states.shape[:2]
  513. # prepare global hidden states
  514. global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim)
  515. global_attn_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[
  516. is_index_global_attn_nonzero[::-1]
  517. ]
  518. # global key, query, value
  519. global_query_vectors_only_global = self.query_global(global_attn_hidden_states)
  520. global_key_vectors = self.key_global(hidden_states)
  521. global_value_vectors = self.value_global(hidden_states)
  522. # normalize
  523. global_query_vectors_only_global /= math.sqrt(self.head_dim)
  524. # reshape
  525. global_query_vectors_only_global = (
  526. global_query_vectors_only_global.contiguous()
  527. .view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim)
  528. .transpose(0, 1)
  529. ) # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim)
  530. global_key_vectors = (
  531. global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
  532. ) # batch_size * self.num_heads, seq_len, head_dim)
  533. global_value_vectors = (
  534. global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
  535. ) # batch_size * self.num_heads, seq_len, head_dim)
  536. # compute attn scores
  537. global_attn_scores = torch.bmm(global_query_vectors_only_global, global_key_vectors.transpose(1, 2))
  538. assert list(global_attn_scores.size()) == [
  539. batch_size * self.num_heads,
  540. max_num_global_attn_indices,
  541. seq_len,
  542. ], (
  543. "global_attn_scores have the wrong size. Size should be"
  544. f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
  545. f" {global_attn_scores.size()}."
  546. )
  547. global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
  548. # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
  549. global_attn_scores = global_attn_scores.transpose(1, 2)
  550. global_attn_scores[
  551. is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
  552. ] = torch.finfo(global_attn_scores.dtype).min
  553. global_attn_scores = global_attn_scores.transpose(1, 2)
  554. global_attn_scores = global_attn_scores.masked_fill(
  555. is_index_masked[:, None, None, :],
  556. torch.finfo(global_attn_scores.dtype).min,
  557. )
  558. global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
  559. # compute global attn probs
  560. global_attn_probs_float = nn.functional.softmax(
  561. global_attn_scores, dim=-1, dtype=torch.float32
  562. ) # use fp32 for numerical stability
  563. global_attn_probs = nn.functional.dropout(
  564. global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
  565. )
  566. # global attn output
  567. global_attn_output = torch.bmm(global_attn_probs, global_value_vectors)
  568. assert list(global_attn_output.size()) == [
  569. batch_size * self.num_heads,
  570. max_num_global_attn_indices,
  571. self.head_dim,
  572. ], (
  573. "global_attn_output tensor has the wrong size. Size should be"
  574. f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
  575. f" {global_attn_output.size()}."
  576. )
  577. global_attn_probs = global_attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
  578. global_attn_output = global_attn_output.view(
  579. batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim
  580. )
  581. return global_attn_output, global_attn_probs
  582. class LEDEncoderAttention(nn.Module):
  583. def __init__(self, config, layer_id):
  584. super().__init__()
  585. self.longformer_self_attn = LEDEncoderSelfAttention(config, layer_id=layer_id)
  586. self.output = nn.Linear(config.d_model, config.d_model)
  587. def forward(
  588. self,
  589. hidden_states: torch.Tensor,
  590. attention_mask: torch.Tensor | None = None,
  591. is_index_masked: torch.Tensor | None = None,
  592. is_index_global_attn: torch.Tensor | None = None,
  593. is_global_attn: bool | None = None,
  594. output_attentions: bool = False,
  595. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  596. """Input shape: Batch x Time x Channel"""
  597. self_outputs = self.longformer_self_attn(
  598. hidden_states=hidden_states,
  599. attention_mask=attention_mask,
  600. is_index_masked=is_index_masked,
  601. is_index_global_attn=is_index_global_attn,
  602. is_global_attn=is_global_attn,
  603. output_attentions=output_attentions,
  604. )
  605. attn_output = self.output(self_outputs[0])
  606. outputs = (attn_output,) + self_outputs[1:]
  607. return outputs
  608. class LEDDecoderAttention(nn.Module):
  609. """Multi-headed attention from 'Attention Is All You Need' paper"""
  610. def __init__(
  611. self,
  612. embed_dim: int,
  613. num_heads: int,
  614. dropout: float | None = 0.0,
  615. is_decoder: bool | None = False,
  616. bias: bool | None = True,
  617. layer_idx: bool | None = None,
  618. ):
  619. super().__init__()
  620. self.embed_dim = embed_dim
  621. self.num_heads = num_heads
  622. self.dropout = dropout
  623. self.head_dim = embed_dim // num_heads
  624. if self.head_dim * num_heads != self.embed_dim:
  625. raise ValueError(
  626. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  627. f" {num_heads})."
  628. )
  629. self.scaling = self.head_dim**-0.5
  630. self.is_decoder = is_decoder
  631. self.layer_idx = layer_idx
  632. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  633. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  634. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  635. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  636. def forward(
  637. self,
  638. hidden_states: torch.Tensor,
  639. key_value_states: torch.Tensor | None = None,
  640. past_key_values: Cache | None = None,
  641. attention_mask: torch.Tensor | None = None,
  642. output_attentions: bool = False,
  643. ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
  644. """Input shape: Batch x Time x Channel"""
  645. # if key_value_states are provided this layer is used as a cross-attention layer
  646. # for the decoder
  647. is_cross_attention = key_value_states is not None
  648. bsz, tgt_len, embed_dim = hidden_states.size()
  649. # get query proj
  650. query_states = self.q_proj(hidden_states) * self.scaling
  651. is_updated = False
  652. if past_key_values is not None:
  653. if isinstance(past_key_values, EncoderDecoderCache):
  654. is_updated = past_key_values.is_updated.get(self.layer_idx)
  655. if is_cross_attention:
  656. # after the first generated id, we can subsequently re-use all key/value_states from cache
  657. curr_past_key_values = past_key_values.cross_attention_cache
  658. else:
  659. curr_past_key_values = past_key_values.self_attention_cache
  660. else:
  661. curr_past_key_values = past_key_values
  662. current_states = key_value_states if is_cross_attention else hidden_states
  663. if is_cross_attention and past_key_values is not None and is_updated:
  664. # reuse k,v, cross_attentions
  665. key_states = curr_past_key_values.layers[self.layer_idx].keys
  666. value_states = curr_past_key_values.layers[self.layer_idx].values
  667. else:
  668. key_states = self.k_proj(current_states)
  669. value_states = self.v_proj(current_states)
  670. key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  671. value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  672. if past_key_values is not None:
  673. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  674. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  675. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  676. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  677. past_key_values.is_updated[self.layer_idx] = True
  678. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  679. query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
  680. query_states = query_states.reshape(*proj_shape)
  681. key_states = key_states.reshape(*proj_shape)
  682. value_states = value_states.reshape(*proj_shape)
  683. src_len = key_states.size(1)
  684. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  685. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  686. raise ValueError(
  687. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  688. f" {attn_weights.size()}"
  689. )
  690. if attention_mask is not None:
  691. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  692. raise ValueError(
  693. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  694. )
  695. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  696. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  697. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  698. if output_attentions:
  699. # this operation is a bit awkward, but it's required to
  700. # make sure that attn_weights keeps its gradient.
  701. # In order to do so, attn_weights have to be reshaped
  702. # twice and have to be reused in the following
  703. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  704. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  705. else:
  706. attn_weights_reshaped = None
  707. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  708. attn_output = torch.bmm(attn_probs, value_states)
  709. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  710. raise ValueError(
  711. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  712. f" {attn_output.size()}"
  713. )
  714. attn_output = (
  715. attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  716. .transpose(1, 2)
  717. .reshape(bsz, tgt_len, embed_dim)
  718. )
  719. attn_output = self.out_proj(attn_output)
  720. return attn_output, attn_weights_reshaped, past_key_values
  721. class LEDEncoderLayer(GradientCheckpointingLayer):
  722. def __init__(self, config: LEDConfig, layer_id: int):
  723. super().__init__()
  724. self.embed_dim = config.d_model
  725. self.self_attn = LEDEncoderAttention(config, layer_id)
  726. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  727. self.dropout = config.dropout
  728. self.activation_fn = ACT2FN[config.activation_function]
  729. self.activation_dropout = config.activation_dropout
  730. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  731. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  732. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  733. def forward(
  734. self,
  735. hidden_states: torch.Tensor,
  736. attention_mask: torch.Tensor,
  737. is_index_masked=None,
  738. is_index_global_attn=None,
  739. is_global_attn=None,
  740. output_attentions=False,
  741. ):
  742. """
  743. Args:
  744. hidden_states (`torch.FloatTensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
  745. attention_mask (`torch.FloatTensor`): attention mask of size
  746. *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
  747. """
  748. residual = hidden_states
  749. attn_outputs = self.self_attn(
  750. hidden_states=hidden_states,
  751. attention_mask=attention_mask,
  752. is_index_masked=is_index_masked,
  753. is_index_global_attn=is_index_global_attn,
  754. is_global_attn=is_global_attn,
  755. output_attentions=output_attentions,
  756. )
  757. hidden_states = attn_outputs[0]
  758. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  759. hidden_states = residual + hidden_states
  760. hidden_states = self.self_attn_layer_norm(hidden_states)
  761. residual = hidden_states
  762. hidden_states = self.activation_fn(self.fc1(hidden_states))
  763. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  764. hidden_states = self.fc2(hidden_states)
  765. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  766. hidden_states = residual + hidden_states
  767. hidden_states = self.final_layer_norm(hidden_states)
  768. if hidden_states.dtype == torch.float16 and not torch.isfinite(hidden_states).all():
  769. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  770. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  771. return (hidden_states,) + attn_outputs[1:]
  772. class LEDDecoderLayer(GradientCheckpointingLayer):
  773. def __init__(self, config: LEDConfig, layer_idx=None):
  774. super().__init__()
  775. self.embed_dim = config.d_model
  776. self.self_attn = LEDDecoderAttention(
  777. embed_dim=self.embed_dim,
  778. num_heads=config.decoder_attention_heads,
  779. dropout=config.attention_dropout,
  780. is_decoder=True,
  781. layer_idx=layer_idx,
  782. )
  783. self.dropout = config.dropout
  784. self.activation_fn = ACT2FN[config.activation_function]
  785. self.activation_dropout = config.activation_dropout
  786. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  787. self.encoder_attn = LEDDecoderAttention(
  788. self.embed_dim,
  789. config.decoder_attention_heads,
  790. dropout=config.attention_dropout,
  791. is_decoder=True,
  792. layer_idx=layer_idx,
  793. )
  794. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  795. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  796. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  797. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  798. def forward(
  799. self,
  800. hidden_states: torch.Tensor,
  801. attention_mask: torch.Tensor | None = None,
  802. encoder_hidden_states: torch.Tensor | None = None,
  803. encoder_attention_mask: torch.Tensor | None = None,
  804. past_key_values: Cache | None = None,
  805. output_attentions: bool | None = False,
  806. use_cache: bool | None = True,
  807. **kwargs,
  808. ):
  809. """
  810. Args:
  811. hidden_states (`torch.FloatTensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
  812. attention_mask (`torch.FloatTensor`): attention mask of size
  813. *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
  814. encoder_hidden_states (`torch.FloatTensor`):
  815. cross attention input to the layer of shape *(batch, seq_len, embed_dim)*
  816. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  817. *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
  818. past_key_values (`Cache`): cached past key and value projection states
  819. output_attentions (`bool`): Whether the base model outputs attentions.
  820. This requires the attentions tensor to be reshaped in this function.
  821. """
  822. residual = hidden_states
  823. # Self-Attention
  824. hidden_states, self_attn_weights, present_key_value = self.self_attn(
  825. hidden_states=hidden_states,
  826. past_key_values=past_key_values,
  827. attention_mask=attention_mask,
  828. output_attentions=output_attentions,
  829. )
  830. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  831. hidden_states = residual + hidden_states
  832. hidden_states = self.self_attn_layer_norm(hidden_states)
  833. # Cross-Attention Block
  834. cross_attn_present_key_value = None
  835. cross_attn_weights = None
  836. if encoder_hidden_states is not None:
  837. residual = hidden_states
  838. hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
  839. hidden_states=hidden_states,
  840. key_value_states=encoder_hidden_states,
  841. attention_mask=encoder_attention_mask,
  842. past_key_values=past_key_values,
  843. output_attentions=output_attentions,
  844. )
  845. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  846. hidden_states = residual + hidden_states
  847. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  848. # Fully Connected
  849. residual = hidden_states
  850. hidden_states = self.activation_fn(self.fc1(hidden_states))
  851. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  852. hidden_states = self.fc2(hidden_states)
  853. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  854. hidden_states = residual + hidden_states
  855. hidden_states = self.final_layer_norm(hidden_states)
  856. outputs = (hidden_states,)
  857. if output_attentions:
  858. outputs += (self_attn_weights, cross_attn_weights)
  859. if use_cache:
  860. outputs += (past_key_values,)
  861. return outputs
  862. class LEDClassificationHead(nn.Module):
  863. """Head for sentence-level classification tasks."""
  864. def __init__(
  865. self,
  866. input_dim: int,
  867. inner_dim: int,
  868. num_classes: int,
  869. pooler_dropout: float,
  870. ):
  871. super().__init__()
  872. self.dense = nn.Linear(input_dim, inner_dim)
  873. self.dropout = nn.Dropout(p=pooler_dropout)
  874. self.out_proj = nn.Linear(inner_dim, num_classes)
  875. def forward(self, hidden_states: torch.Tensor):
  876. hidden_states = self.dropout(hidden_states)
  877. hidden_states = self.dense(hidden_states)
  878. hidden_states = torch.tanh(hidden_states)
  879. hidden_states = self.dropout(hidden_states)
  880. hidden_states = self.out_proj(hidden_states)
  881. return hidden_states
  882. @auto_docstring
  883. class LEDPreTrainedModel(PreTrainedModel):
  884. config: LEDConfig
  885. base_model_prefix = "led"
  886. supports_gradient_checkpointing = True
  887. @property
  888. def dummy_inputs(self):
  889. pad_token = self.config.pad_token_id
  890. input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
  891. dummy_inputs = {
  892. "attention_mask": input_ids.ne(pad_token),
  893. "input_ids": input_ids,
  894. }
  895. return dummy_inputs
  896. def _init_weights(self, module):
  897. super()._init_weights(module)
  898. if isinstance(module, LEDForConditionalGeneration):
  899. init.zeros_(module.final_logits_bias)
# Output container of the LED encoder. `last_hidden_state` is the only field without a default; the optional
# `attentions` (local attention weights) and `global_attentions` tuples are described in the class docstring.
# Notes are kept here, outside the class, because the class body carries a "Copied from" marker that points
# at the Longformer original.
@dataclass
@auto_docstring(
    custom_intro="""
Base class for LEDEncoder's outputs, with potential hidden states, local and global attentions.
"""
)
# Copied from transformers.models.longformer.modeling_longformer.LongformerBaseModelOutput with Longformer->LEDEncoder
class LEDEncoderBaseModelOutput(ModelOutput):
    r"""
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
attention_window + 1)`, where `x` is the number of tokens with global attention mask.
Local attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first `x` values) and to every token in the attention window (remaining `attention_window
+ 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
(succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
If the attention window contains a token with global attention, the attention weight at the corresponding
index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be
accessed from `global_attentions`.
global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
where `x` is the number of tokens with global attention mask.
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
    last_hidden_state: torch.FloatTensor
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    global_attentions: tuple[torch.FloatTensor, ...] | None = None
# Output container of the LED encoder-decoder base model. Every field defaults to None.
@dataclass
@auto_docstring(
    custom_intro="""
Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
decoding.
"""
)
class LEDSeq2SeqModelOutput(ModelOutput):
    r"""
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see `past_key_values` input) to speed up sequential decoding.
encoder_global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
where `x` is the number of tokens with global attention mask.
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
    # Final decoder hidden states (see docstring).
    last_hidden_state: torch.FloatTensor | None = None
    # Decoder key/value cache (see docstring).
    past_key_values: Cache | None = None
    # NOTE(review): the fields below have no entry in the docstring above; their descriptions presumably
    # come from auto_docstring's shared argument docs -- confirm there.
    decoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    decoder_attentions: tuple[torch.FloatTensor, ...] | None = None
    cross_attentions: tuple[torch.FloatTensor, ...] | None = None
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    encoder_attentions: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer global-attention weights of the encoder (see docstring).
    encoder_global_attentions: tuple[torch.FloatTensor, ...] | None = None
# Output container of the LED conditional-generation head. Every field defaults to None.
@dataclass
@auto_docstring(
    custom_intro="""
Base class for sequence-to-sequence language models outputs.
"""
)
class LEDSeq2SeqLMOutput(ModelOutput):
    r"""
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss.
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see `past_key_values` input) to speed up sequential decoding.
encoder_global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
where `x` is the number of tokens with global attention mask.
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
    # Language-modeling loss, present when labels were given (see docstring).
    loss: torch.FloatTensor | None = None
    # Pre-softmax vocabulary scores (see docstring).
    logits: torch.FloatTensor | None = None
    # Decoder key/value cache (see docstring).
    past_key_values: Cache | None = None
    # NOTE(review): the fields below have no entry in the docstring above; their descriptions presumably
    # come from auto_docstring's shared argument docs -- confirm there.
    decoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    decoder_attentions: tuple[torch.FloatTensor, ...] | None = None
    cross_attentions: tuple[torch.FloatTensor, ...] | None = None
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    encoder_attentions: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer global-attention weights of the encoder (see docstring).
    encoder_global_attentions: tuple[torch.FloatTensor, ...] | None = None
# Output container of the LED sequence-classification head. Every field defaults to None.
@dataclass
@auto_docstring(
    custom_intro="""
Base class for outputs of sequence-to-sequence sentence classification models.
"""
)
class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
    r"""
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see `past_key_values` input) to speed up sequential decoding.
encoder_global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
where `x` is the number of tokens with global attention mask.
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
    # Classification/regression loss, present when a label was given (see docstring).
    loss: torch.FloatTensor | None = None
    # Per-example class scores of shape (batch_size, config.num_labels) (see docstring).
    logits: torch.FloatTensor | None = None
    # Decoder key/value cache (see docstring).
    past_key_values: Cache | None = None
    # NOTE(review): the fields below have no entry in the docstring above; their descriptions presumably
    # come from auto_docstring's shared argument docs -- confirm there.
    decoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    decoder_attentions: tuple[torch.FloatTensor, ...] | None = None
    cross_attentions: tuple[torch.FloatTensor, ...] | None = None
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    encoder_attentions: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer global-attention weights of the encoder (see docstring).
    encoder_global_attentions: tuple[torch.FloatTensor, ...] | None = None
# Output container of the LED span-extraction (question answering) head. Every field defaults to None.
@dataclass
@auto_docstring(
    custom_intro="""
Base class for outputs of sequence-to-sequence question answering models.
"""
)
class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
    r"""
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see `past_key_values` input) to speed up sequential decoding.
encoder_global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
where `x` is the number of tokens with global attention mask.
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
    # Span-extraction loss: sum of start and end cross-entropies (see docstring).
    loss: torch.FloatTensor | None = None
    # NOTE(review): start_logits / end_logits have no entry in the docstring above; presumably they are the
    # span start/end scores documented by auto_docstring's shared argument docs -- confirm there.
    start_logits: torch.FloatTensor | None = None
    end_logits: torch.FloatTensor | None = None
    # Decoder key/value cache (see docstring).
    past_key_values: Cache | None = None
    # NOTE(review): the fields below have no entry in the docstring above either; same assumption as above.
    decoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    decoder_attentions: tuple[torch.FloatTensor, ...] | None = None
    cross_attentions: tuple[torch.FloatTensor, ...] | None = None
    encoder_last_hidden_state: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    encoder_attentions: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer global-attention weights of the encoder (see docstring).
    encoder_global_attentions: tuple[torch.FloatTensor, ...] | None = None
  1065. class LEDEncoder(LEDPreTrainedModel):
  1066. """
  1067. Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
  1068. [`LEDEncoderLayer`].
  1069. Args:
  1070. config: LEDConfig
  1071. embed_tokens (nn.Embedding): output embedding
  1072. """
  1073. def __init__(self, config: LEDConfig):
  1074. super().__init__(config)
  1075. self.dropout = config.dropout
  1076. self.layerdrop = config.encoder_layerdrop
  1077. embed_dim = config.d_model
  1078. self.padding_idx = config.pad_token_id
  1079. self.max_source_positions = config.max_encoder_position_embeddings
  1080. if isinstance(config.attention_window, int):
  1081. if config.attention_window % 2 != 0:
  1082. raise ValueError("`config.attention_window` has to be an even value")
  1083. if config.attention_window <= 0:
  1084. raise ValueError("`config.attention_window` has to be positive")
  1085. config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer
  1086. else:
  1087. if len(config.attention_window) != config.num_hidden_layers:
  1088. raise ValueError(
  1089. "`len(config.attention_window)` should equal `config.num_hidden_layers`. "
  1090. f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
  1091. )
  1092. self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
  1093. self.embed_positions = LEDLearnedPositionalEmbedding(
  1094. self.max_source_positions,
  1095. embed_dim,
  1096. )
  1097. self.layers = nn.ModuleList([LEDEncoderLayer(config, i) for i in range(config.encoder_layers)])
  1098. self.layernorm_embedding = nn.LayerNorm(embed_dim)
  1099. self.gradient_checkpointing = False
  1100. # Initialize weights and apply final processing
  1101. self.post_init()
  1102. def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor):
  1103. # longformer self-attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
  1104. # (global_attention_mask + 1) => 1 for local attention, 2 for global attention
  1105. # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
  1106. if attention_mask is not None:
  1107. attention_mask = attention_mask * (global_attention_mask + 1)
  1108. else:
  1109. # simply use `global_attention_mask` as `attention_mask`
  1110. # if no `attention_mask` is given
  1111. attention_mask = global_attention_mask + 1
  1112. return attention_mask
  1113. def _pad_to_window_size(
  1114. self,
  1115. input_ids: torch.Tensor,
  1116. attention_mask: torch.Tensor,
  1117. inputs_embeds: torch.Tensor,
  1118. pad_token_id: int,
  1119. ):
  1120. """A helper function to pad tokens and mask to work with implementation of Longformer self-attention."""
  1121. # padding
  1122. attention_window = (
  1123. self.config.attention_window
  1124. if isinstance(self.config.attention_window, int)
  1125. else max(self.config.attention_window)
  1126. )
  1127. if attention_window % 2 != 0:
  1128. raise ValueError(f"`attention_window` should be an even value. Given {attention_window}")
  1129. input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape
  1130. batch_size, seq_len = input_shape[:2]
  1131. padding_len = (attention_window - seq_len % attention_window) % attention_window
  1132. if padding_len > 0:
  1133. logger.warning_once(
  1134. f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of "
  1135. f"`config.attention_window`: {attention_window}"
  1136. )
  1137. if input_ids is not None:
  1138. input_ids = nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id)
  1139. if inputs_embeds is not None:
  1140. input_ids_padding = inputs_embeds.new_full(
  1141. (batch_size, padding_len),
  1142. self.config.pad_token_id,
  1143. dtype=torch.long,
  1144. )
  1145. inputs_embeds_padding = self.embed_tokens(input_ids_padding)
  1146. inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
  1147. attention_mask = nn.functional.pad(
  1148. attention_mask, (0, padding_len), value=False
  1149. ) # no attention on the padding tokens
  1150. return padding_len, input_ids, attention_mask, inputs_embeds
  1151. def forward(
  1152. self,
  1153. input_ids=None,
  1154. attention_mask=None,
  1155. global_attention_mask=None,
  1156. inputs_embeds=None,
  1157. output_attentions=None,
  1158. output_hidden_states=None,
  1159. return_dict=None,
  1160. **kwargs,
  1161. ):
  1162. r"""
  1163. Args:
  1164. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1165. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  1166. provide it.
  1167. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1168. [`PreTrainedTokenizer.__call__`] for details.
  1169. [What are input IDs?](../glossary#input-ids)
  1170. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1171. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1172. - 1 for tokens that are **not masked**,
  1173. - 0 for tokens that are **masked**.
  1174. [What are attention masks?](../glossary#attention-mask)
  1175. global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1176. Mask to decide the attention given on each token, local attention or global attention for the encoder.
  1177. Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is
  1178. important for task-specific finetuning because it makes the model more flexible at representing the
  1179. task. For example, for classification, the <s> token should be given global attention. For QA, all
  1180. question tokens should also have global attention. Please refer to the [Longformer
  1181. paper](https://huggingface.co/papers/2004.05150) for more details. Mask values selected in `[0, 1]`:
  1182. - 0 for local attention (a sliding window attention),
  1183. - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
  1184. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1185. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  1186. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  1187. than the model's internal embedding lookup matrix.
  1188. output_attentions (`bool`, *optional*):
  1189. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  1190. returned tensors for more detail.
  1191. output_hidden_states (`bool`, *optional*):
  1192. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  1193. for more detail.
  1194. return_dict (`bool`, *optional*):
  1195. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1196. """
  1197. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1198. output_hidden_states = (
  1199. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1200. )
  1201. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1202. # check input_ids and inputs_embeds
  1203. if input_ids is not None and inputs_embeds is not None:
  1204. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  1205. elif input_ids is None and inputs_embeds is None:
  1206. raise ValueError("You have to specify either input_ids or inputs_embeds")
  1207. if inputs_embeds is None:
  1208. inputs_embeds = self.embed_tokens(input_ids)
  1209. # create default attention_mask
  1210. if attention_mask is None:
  1211. attention_mask = torch.ones(inputs_embeds.size()[:-1], device=inputs_embeds.device, dtype=torch.long)
  1212. # merge `global_attention_mask` and `attention_mask`
  1213. if global_attention_mask is not None:
  1214. attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
  1215. # pad input if necessary
  1216. padding_len, input_ids, attention_mask, inputs_embeds = self._pad_to_window_size(
  1217. input_ids=input_ids,
  1218. attention_mask=attention_mask,
  1219. inputs_embeds=inputs_embeds,
  1220. pad_token_id=self.config.pad_token_id,
  1221. )
  1222. # retrieve input_shape
  1223. if input_ids is not None:
  1224. input_shape = input_ids.size()
  1225. input_ids = input_ids.view(-1, input_shape[-1])
  1226. elif inputs_embeds is not None:
  1227. input_shape = inputs_embeds.size()[:-1]
  1228. # convert attention_mask to float
  1229. if attention_mask is not None:
  1230. # [bsz, seq_len] -> [bsz, seq_len]; 1 -> 0.0; 0 -> "-inf"
  1231. attention_mask = _prepare_4d_attention_mask_inverted(attention_mask, inputs_embeds.dtype)[:, 0, 0, :]
  1232. # get masking tensors
  1233. is_index_masked = attention_mask < 0
  1234. is_index_global_attn = attention_mask > 0
  1235. is_global_attn = is_index_global_attn.flatten().any().item()
  1236. embed_pos = self.embed_positions(input_shape)
  1237. hidden_states = inputs_embeds + embed_pos
  1238. hidden_states = self.layernorm_embedding(hidden_states)
  1239. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  1240. encoder_states = () if output_hidden_states else None
  1241. all_attentions = () if output_attentions else None
  1242. all_global_attentions = () if (output_attentions and is_global_attn) else None
  1243. for idx, encoder_layer in enumerate(self.layers):
  1244. if output_hidden_states:
  1245. encoder_states = encoder_states + (hidden_states,)
  1246. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  1247. dropout_probability = torch.rand([])
  1248. if self.training and (dropout_probability < self.layerdrop): # skip the layer
  1249. layer_outputs = (None, None, None)
  1250. else:
  1251. layer_outputs = encoder_layer(
  1252. hidden_states,
  1253. attention_mask=attention_mask,
  1254. is_index_masked=is_index_masked,
  1255. is_index_global_attn=is_index_global_attn,
  1256. is_global_attn=is_global_attn,
  1257. output_attentions=output_attentions,
  1258. )
  1259. hidden_states = layer_outputs[0]
  1260. if output_attentions:
  1261. # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
  1262. all_attentions = all_attentions + (layer_outputs[1].transpose(1, 2),)
  1263. if is_global_attn:
  1264. # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
  1265. all_global_attentions = all_global_attentions + (layer_outputs[2].transpose(2, 3),)
  1266. if output_hidden_states:
  1267. encoder_states = encoder_states + (hidden_states,)
  1268. # undo padding
  1269. if padding_len > 0:
  1270. # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
  1271. hidden_states = hidden_states[:, :-padding_len]
  1272. if output_hidden_states:
  1273. encoder_states = tuple(state[:, :-padding_len] for state in encoder_states)
  1274. if output_attentions:
  1275. all_attentions = tuple(state[:, :, :-padding_len, :] for state in all_attentions)
  1276. if not return_dict:
  1277. return tuple(
  1278. v for v in [hidden_states, encoder_states, all_attentions, all_global_attentions] if v is not None
  1279. )
  1280. return LEDEncoderBaseModelOutput(
  1281. last_hidden_state=hidden_states,
  1282. hidden_states=encoder_states,
  1283. attentions=all_attentions,
  1284. global_attentions=all_global_attentions,
  1285. )
  1286. class LEDDecoder(LEDPreTrainedModel):
  1287. """
  1288. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`LEDDecoderLayer`]
  1289. Args:
  1290. config: LEDConfig
  1291. embed_tokens (nn.Embedding): output embedding
  1292. """
  1293. def __init__(self, config: LEDConfig):
  1294. super().__init__(config)
  1295. self.dropout = config.dropout
  1296. self.layerdrop = config.decoder_layerdrop
  1297. self.padding_idx = config.pad_token_id
  1298. self.max_target_positions = config.max_decoder_position_embeddings
  1299. self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
  1300. self.embed_positions = LEDLearnedPositionalEmbedding(
  1301. self.max_target_positions,
  1302. config.d_model,
  1303. )
  1304. self.layers = nn.ModuleList([LEDDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  1305. self.layernorm_embedding = nn.LayerNorm(config.d_model)
  1306. self.gradient_checkpointing = False
  1307. # Initialize weights and apply final processing
  1308. self.post_init()
  1309. def forward(
  1310. self,
  1311. input_ids=None,
  1312. attention_mask=None,
  1313. global_attention_mask=None,
  1314. encoder_hidden_states=None,
  1315. encoder_attention_mask=None,
  1316. past_key_values=None,
  1317. inputs_embeds=None,
  1318. use_cache=None,
  1319. output_attentions=None,
  1320. output_hidden_states=None,
  1321. return_dict=None,
  1322. **kwargs,
  1323. ):
  1324. r"""
  1325. Args:
  1326. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1327. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  1328. provide it.
  1329. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1330. [`PreTrainedTokenizer.__call__`] for details.
  1331. [What are input IDs?](../glossary#input-ids)
  1332. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1333. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1334. - 1 for tokens that are **not masked**,
  1335. - 0 for tokens that are **masked**.
  1336. [What are attention masks?](../glossary#attention-mask)
  1337. global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1338. Mask to decide the attention given on each token, local attention or global attention. Tokens with
  1339. global attention attends to all other tokens, and all other tokens attend to them. This is important
  1340. for task-specific finetuning because it makes the model more flexible at representing the task. For
  1341. example, for classification, the <s> token should be given global attention. For QA, all question
  1342. tokens should also have global attention. Please refer to the [Longformer
  1343. paper](https://huggingface.co/papers/2004.05150) for more details. Mask values selected in `[0, 1]`:
  1344. - 0 for local attention (a sliding window attention),
  1345. - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
  1346. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  1347. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  1348. of the decoder.
  1349. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  1350. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  1351. selected in `[0, 1]`:
  1352. - 1 for tokens that are **not masked**,
  1353. - 0 for tokens that are **masked**.
  1354. [What are attention masks?](../glossary#attention-mask)
  1355. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  1356. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  1357. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  1358. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  1359. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  1360. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  1361. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  1362. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1363. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  1364. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  1365. than the model's internal embedding lookup matrix.
  1366. output_attentions (`bool`, *optional*):
  1367. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  1368. returned tensors for more detail.
  1369. output_hidden_states (`bool`, *optional*):
  1370. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  1371. for more detail.
  1372. return_dict (`bool`, *optional*):
  1373. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1374. """
  1375. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1376. output_hidden_states = (
  1377. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1378. )
  1379. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1380. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1381. # retrieve input_ids and inputs_embeds
  1382. if input_ids is not None and inputs_embeds is not None:
  1383. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  1384. elif input_ids is not None:
  1385. input_shape = input_ids.size()
  1386. input_ids = input_ids.view(-1, input_shape[-1])
  1387. elif inputs_embeds is not None:
  1388. input_shape = inputs_embeds.size()[:-1]
  1389. else:
  1390. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  1391. if inputs_embeds is None:
  1392. inputs_embeds = self.embed_tokens(input_ids)
  1393. if self.gradient_checkpointing and self.training:
  1394. if use_cache:
  1395. logger.warning_once(
  1396. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  1397. )
  1398. use_cache = False
  1399. if use_cache and past_key_values is None:
  1400. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  1401. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  1402. combined_attention_mask = None
  1403. if input_shape[-1] > 1: # only create a causal mask when we go over a single token
  1404. combined_attention_mask = create_causal_mask(
  1405. config=self.config,
  1406. inputs_embeds=inputs_embeds,
  1407. attention_mask=attention_mask,
  1408. past_key_values=past_key_values,
  1409. )
  1410. encoder_attention_mask = create_bidirectional_mask(
  1411. config=self.config,
  1412. inputs_embeds=inputs_embeds,
  1413. attention_mask=encoder_attention_mask,
  1414. encoder_hidden_states=encoder_hidden_states,
  1415. )
  1416. # embed positions
  1417. positions = self.embed_positions(input_shape, past_key_values_length)
  1418. hidden_states = inputs_embeds + positions
  1419. hidden_states = self.layernorm_embedding(hidden_states)
  1420. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  1421. # decoder layers
  1422. all_hidden_states = () if output_hidden_states else None
  1423. all_self_attns = () if output_attentions else None
  1424. all_cross_attentions = () if output_attentions else None
  1425. for idx, decoder_layer in enumerate(self.layers):
  1426. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  1427. if output_hidden_states:
  1428. all_hidden_states += (hidden_states,)
  1429. if self.training:
  1430. dropout_probability = torch.rand([])
  1431. if dropout_probability < self.layerdrop:
  1432. continue
  1433. layer_outputs = decoder_layer(
  1434. hidden_states,
  1435. combined_attention_mask,
  1436. encoder_hidden_states, # as a positional argument for gradient checkpointing
  1437. encoder_attention_mask=encoder_attention_mask,
  1438. past_key_values=past_key_values,
  1439. output_attentions=output_attentions,
  1440. use_cache=use_cache,
  1441. )
  1442. hidden_states = layer_outputs[0]
  1443. if output_attentions:
  1444. all_self_attns += (layer_outputs[1],)
  1445. all_cross_attentions += (layer_outputs[2],)
  1446. # add hidden states from the last decoder layer
  1447. if output_hidden_states:
  1448. all_hidden_states += (hidden_states,)
  1449. if not return_dict:
  1450. return tuple(
  1451. v
  1452. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
  1453. if v is not None
  1454. )
  1455. return BaseModelOutputWithPastAndCrossAttentions(
  1456. last_hidden_state=hidden_states,
  1457. past_key_values=past_key_values,
  1458. hidden_states=all_hidden_states,
  1459. attentions=all_self_attns,
  1460. cross_attentions=all_cross_attentions,
  1461. )
  1462. @auto_docstring
  1463. class LEDModel(LEDPreTrainedModel):
  1464. _tied_weights_keys = {
  1465. "encoder.embed_tokens.weight": "shared.weight",
  1466. "decoder.embed_tokens.weight": "shared.weight",
  1467. }
  1468. def __init__(self, config: LEDConfig):
  1469. super().__init__(config)
  1470. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  1471. self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
  1472. self.encoder = LEDEncoder(config)
  1473. self.decoder = LEDDecoder(config)
  1474. # Initialize weights and apply final processing
  1475. self.post_init()
  1476. def get_input_embeddings(self):
  1477. return self.shared
  1478. def set_input_embeddings(self, value):
  1479. self.shared = value
  1480. self.encoder.embed_tokens = self.shared
  1481. self.decoder.embed_tokens = self.shared
  1482. @auto_docstring
  1483. def forward(
  1484. self,
  1485. input_ids: torch.LongTensor | None = None,
  1486. attention_mask: torch.Tensor | None = None,
  1487. decoder_input_ids: torch.LongTensor | None = None,
  1488. decoder_attention_mask: torch.LongTensor | None = None,
  1489. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  1490. global_attention_mask: torch.FloatTensor | None = None,
  1491. past_key_values: Cache | None = None,
  1492. inputs_embeds: torch.FloatTensor | None = None,
  1493. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1494. use_cache: bool | None = None,
  1495. output_attentions: bool | None = None,
  1496. output_hidden_states: bool | None = None,
  1497. return_dict: bool | None = None,
  1498. **kwargs,
  1499. ) -> tuple[torch.Tensor] | LEDSeq2SeqModelOutput:
  1500. r"""
  1501. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1502. Indices of decoder input sequence tokens in the vocabulary.
  1503. Indices can be obtained using [`LedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1504. [`PreTrainedTokenizer.__call__`] for details.
  1505. [What are input IDs?](../glossary#input-ids)
  1506. LED uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1507. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1508. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1509. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1510. be used by default.
  1511. If you want to change padding behavior, you should read [`modeling_led._prepare_decoder_inputs`] and modify
  1512. to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more information on the
  1513. default strategy.
  1514. global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1515. Mask to decide the attention given on each token, local attention or global attention for the encoder.
  1516. Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is
  1517. important for task-specific finetuning because it makes the model more flexible at representing the task.
  1518. For example, for classification, the <s> token should be given global attention. For QA, all question
  1519. tokens should also have global attention. Please refer to the [Longformer
  1520. paper](https://huggingface.co/papers/2004.05150) for more details. Mask values selected in `[0, 1]`:
  1521. - 0 for local attention (a sliding window attention),
  1522. - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
  1523. """
  1524. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1525. output_hidden_states = (
  1526. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1527. )
  1528. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1529. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1530. # Using this like Bart, as LED is derived from it. So far
  1531. # No checkpoint on the hub exists that uses that in practice.
  1532. # https://github.com/huggingface/transformers/blob/ac3cb660cad283163f7c73cad511124e845ca388/src/transformers/models/bart/modeling_bart.py#L1153
  1533. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1534. decoder_input_ids = shift_tokens_right(
  1535. input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
  1536. )
  1537. if encoder_outputs is None:
  1538. encoder_outputs = self.encoder(
  1539. input_ids=input_ids,
  1540. attention_mask=attention_mask,
  1541. global_attention_mask=global_attention_mask,
  1542. inputs_embeds=inputs_embeds,
  1543. output_attentions=output_attentions,
  1544. output_hidden_states=output_hidden_states,
  1545. return_dict=return_dict,
  1546. )
  1547. # If the user passed a tuple for encoder_outputs, we wrap it in a LEDEncoderBaseModelOutput when return_dict=False
  1548. elif return_dict and not isinstance(encoder_outputs, LEDEncoderBaseModelOutput):
  1549. encoder_outputs = LEDEncoderBaseModelOutput(
  1550. last_hidden_state=encoder_outputs[0],
  1551. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1552. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1553. global_attentions=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
  1554. )
  1555. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  1556. decoder_outputs = self.decoder(
  1557. input_ids=decoder_input_ids,
  1558. attention_mask=decoder_attention_mask,
  1559. encoder_hidden_states=encoder_outputs[0],
  1560. encoder_attention_mask=attention_mask,
  1561. past_key_values=past_key_values,
  1562. inputs_embeds=decoder_inputs_embeds,
  1563. use_cache=use_cache,
  1564. output_attentions=output_attentions,
  1565. output_hidden_states=output_hidden_states,
  1566. return_dict=return_dict,
  1567. )
  1568. if not return_dict:
  1569. return decoder_outputs + encoder_outputs
  1570. return LEDSeq2SeqModelOutput(
  1571. last_hidden_state=decoder_outputs.last_hidden_state,
  1572. past_key_values=decoder_outputs.past_key_values,
  1573. decoder_hidden_states=decoder_outputs.hidden_states,
  1574. decoder_attentions=decoder_outputs.attentions,
  1575. cross_attentions=decoder_outputs.cross_attentions,
  1576. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1577. encoder_hidden_states=encoder_outputs.hidden_states,
  1578. encoder_attentions=encoder_outputs.attentions,
  1579. encoder_global_attentions=encoder_outputs.global_attentions,
  1580. )
  1581. @auto_docstring(
  1582. custom_intro="""
  1583. The LED Model with a language modeling head. Can be used for summarization.
  1584. """
  1585. )
  1586. class LEDForConditionalGeneration(LEDPreTrainedModel, GenerationMixin):
  1587. base_model_prefix = "led"
  1588. _keys_to_ignore_on_load_missing = ["final_logits_bias"]
  1589. _tied_weights_keys = {
  1590. "lm_head.weight": "led.shared.weight",
  1591. }
  1592. def __init__(self, config: LEDConfig):
  1593. super().__init__(config)
  1594. self.led = LEDModel(config)
  1595. self.register_buffer("final_logits_bias", torch.zeros((1, self.led.shared.num_embeddings)))
  1596. self.lm_head = nn.Linear(config.d_model, self.led.shared.num_embeddings, bias=False)
  1597. # Initialize weights and apply final processing
  1598. self.post_init()
  1599. def resize_token_embeddings(
  1600. self, new_num_tokens: int, pad_to_multiple_of: int | None = None, mean_resizing: bool = True
  1601. ) -> nn.Embedding:
  1602. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  1603. self._resize_final_logits_bias(new_embeddings.weight.shape[0])
  1604. return new_embeddings
  1605. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  1606. old_num_tokens = self.final_logits_bias.shape[-1]
  1607. if new_num_tokens <= old_num_tokens:
  1608. new_bias = self.final_logits_bias[:, :new_num_tokens]
  1609. else:
  1610. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  1611. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  1612. self.register_buffer("final_logits_bias", new_bias)
  1613. @auto_docstring
  1614. def forward(
  1615. self,
  1616. input_ids: torch.LongTensor | None = None,
  1617. attention_mask: torch.Tensor | None = None,
  1618. decoder_input_ids: torch.LongTensor | None = None,
  1619. decoder_attention_mask: torch.LongTensor | None = None,
  1620. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  1621. global_attention_mask: torch.FloatTensor | None = None,
  1622. past_key_values: Cache | None = None,
  1623. inputs_embeds: torch.FloatTensor | None = None,
  1624. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1625. labels: torch.LongTensor | None = None,
  1626. use_cache: bool | None = None,
  1627. output_attentions: bool | None = None,
  1628. output_hidden_states: bool | None = None,
  1629. return_dict: bool | None = None,
  1630. **kwargs,
  1631. ) -> tuple[torch.Tensor] | LEDSeq2SeqLMOutput:
  1632. r"""
  1633. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1634. Indices of decoder input sequence tokens in the vocabulary.
  1635. Indices can be obtained using [`LedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1636. [`PreTrainedTokenizer.__call__`] for details.
  1637. [What are input IDs?](../glossary#input-ids)
  1638. LED uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1639. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1640. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1641. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1642. be used by default.
  1643. If you want to change padding behavior, you should read [`modeling_led._prepare_decoder_inputs`] and modify
  1644. to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more information on the
  1645. default strategy.
  1646. global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1647. Mask to decide the attention given on each token, local attention or global attention for the encoder.
  1648. Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is
  1649. important for task-specific finetuning because it makes the model more flexible at representing the task.
  1650. For example, for classification, the <s> token should be given global attention. For QA, all question
  1651. tokens should also have global attention. Please refer to the [Longformer
  1652. paper](https://huggingface.co/papers/2004.05150) for more details. Mask values selected in `[0, 1]`:
  1653. - 0 for local attention (a sliding window attention),
  1654. - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
  1655. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1656. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1657. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1658. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1659. Example Summarization:
  1660. ```python
  1661. >>> import torch
  1662. >>> from transformers import AutoTokenizer, LEDForConditionalGeneration
  1663. >>> model = LEDForConditionalGeneration.from_pretrained("allenai/led-large-16384-arxiv")
  1664. >>> tokenizer = AutoTokenizer.from_pretrained("allenai/led-large-16384-arxiv")
  1665. >>> ARTICLE_TO_SUMMARIZE = '''Transformers (Vaswani et al., 2017) have achieved state-of-the-art
  1666. ... results in a wide range of natural language tasks including generative language modeling
  1667. ... (Dai et al., 2019; Radford et al., 2019) and discriminative ... language understanding (Devlin et al., 2019).
  1668. ... This success is partly due to the self-attention component which enables the network to capture contextual
  1669. ... information from the entire sequence. While powerful, the memory and computational requirements of
  1670. ... self-attention grow quadratically with sequence length, making it infeasible (or very expensive) to
  1671. ... process long sequences. To address this limitation, we present Longformer, a modified Transformer
  1672. ... architecture with a self-attention operation that scales linearly with the sequence length, making it
  1673. ... versatile for processing long documents (Fig 1). This is an advantage for natural language tasks such as
  1674. ... long document classification, question answering (QA), and coreference resolution, where existing approaches
  1675. ... partition or shorten the long context into smaller sequences that fall within the typical 512 token limit
  1676. ... of BERT-style pretrained models. Such partitioning could potentially result in loss of important
  1677. ... cross-partition information, and to mitigate this problem, existing methods often rely on complex
  1678. ... architectures to address such interactions. On the other hand, our proposed Longformer is able to build
  1679. ... contextual representations of the entire context using multiple layers of attention, reducing the need for
  1680. ... task-specific architectures.'''
  1681. >>> inputs = tokenizer.encode(ARTICLE_TO_SUMMARIZE, return_tensors="pt")
  1682. >>> # Global attention on the first token (cf. Beltagy et al. 2020)
  1683. >>> global_attention_mask = torch.zeros_like(inputs)
  1684. >>> global_attention_mask[:, 0] = 1
  1685. >>> # Generate Summary
  1686. >>> summary_ids = model.generate(inputs, global_attention_mask=global_attention_mask, num_beams=3, max_length=32)
  1687. >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True))
  1688. ```
  1689. Example Conditional generation :
  1690. ```python
  1691. >>> from transformers import AutoTokenizer, LEDForConditionalGeneration
  1692. >>> tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")
  1693. >>> TXT = "My friends are <mask> but they eat too many carbs."
  1694. >>> model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
  1695. >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
  1696. >>> prediction = model.generate(input_ids)[0]
  1697. >>> print(tokenizer.decode(prediction, skip_special_tokens=True))
  1698. ```
  1699. """
  1700. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1701. if labels is not None:
  1702. if use_cache:
  1703. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  1704. use_cache = False
  1705. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1706. decoder_input_ids = shift_tokens_right(
  1707. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  1708. )
  1709. outputs = self.led(
  1710. input_ids,
  1711. attention_mask=attention_mask,
  1712. decoder_input_ids=decoder_input_ids,
  1713. decoder_attention_mask=decoder_attention_mask,
  1714. encoder_outputs=encoder_outputs,
  1715. global_attention_mask=global_attention_mask,
  1716. past_key_values=past_key_values,
  1717. inputs_embeds=inputs_embeds,
  1718. decoder_inputs_embeds=decoder_inputs_embeds,
  1719. use_cache=use_cache,
  1720. output_attentions=output_attentions,
  1721. output_hidden_states=output_hidden_states,
  1722. return_dict=return_dict,
  1723. )
  1724. lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
  1725. masked_lm_loss = None
  1726. if labels is not None:
  1727. loss_fct = CrossEntropyLoss()
  1728. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  1729. if not return_dict:
  1730. output = (lm_logits,) + outputs[1:]
  1731. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  1732. return LEDSeq2SeqLMOutput(
  1733. loss=masked_lm_loss,
  1734. logits=lm_logits,
  1735. past_key_values=outputs.past_key_values,
  1736. decoder_hidden_states=outputs.decoder_hidden_states,
  1737. decoder_attentions=outputs.decoder_attentions,
  1738. cross_attentions=outputs.cross_attentions,
  1739. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1740. encoder_hidden_states=outputs.encoder_hidden_states,
  1741. encoder_attentions=outputs.encoder_attentions,
  1742. encoder_global_attentions=outputs.encoder_global_attentions,
  1743. )
  1744. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1745. return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
  1746. @auto_docstring
  1747. class LEDForQuestionAnswering(LEDPreTrainedModel):
  1748. def __init__(self, config):
  1749. super().__init__(config)
  1750. config.num_labels = 2
  1751. self.num_labels = config.num_labels
  1752. self.led = LEDModel(config)
  1753. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1754. # Initialize weights and apply final processing
  1755. self.post_init()
  1756. @auto_docstring
  1757. def forward(
  1758. self,
  1759. input_ids: torch.LongTensor | None = None,
  1760. attention_mask: torch.Tensor | None = None,
  1761. decoder_input_ids: torch.LongTensor | None = None,
  1762. decoder_attention_mask: torch.LongTensor | None = None,
  1763. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  1764. global_attention_mask: torch.FloatTensor | None = None,
  1765. start_positions: torch.LongTensor | None = None,
  1766. end_positions: torch.LongTensor | None = None,
  1767. inputs_embeds: torch.FloatTensor | None = None,
  1768. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1769. use_cache: bool | None = None,
  1770. output_attentions: bool | None = None,
  1771. output_hidden_states: bool | None = None,
  1772. return_dict: bool | None = None,
  1773. **kwargs,
  1774. ) -> tuple[torch.Tensor] | LEDSeq2SeqQuestionAnsweringModelOutput:
  1775. r"""
  1776. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1777. Indices of decoder input sequence tokens in the vocabulary.
  1778. Indices can be obtained using [`LedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1779. [`PreTrainedTokenizer.__call__`] for details.
  1780. [What are input IDs?](../glossary#input-ids)
  1781. LED uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1782. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1783. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1784. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1785. be used by default.
  1786. If you want to change padding behavior, you should read [`modeling_led._prepare_decoder_inputs`] and modify
  1787. to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more information on the
  1788. default strategy.
  1789. global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1790. Mask to decide the attention given on each token, local attention or global attention for the encoder.
  1791. Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is
  1792. important for task-specific finetuning because it makes the model more flexible at representing the task.
  1793. For example, for classification, the <s> token should be given global attention. For QA, all question
  1794. tokens should also have global attention. Please refer to the [Longformer
  1795. paper](https://huggingface.co/papers/2004.05150) for more details. Mask values selected in `[0, 1]`:
  1796. - 0 for local attention (a sliding window attention),
  1797. - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
  1798. """
  1799. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1800. if start_positions is not None and end_positions is not None:
  1801. use_cache = False
  1802. outputs = self.led(
  1803. input_ids,
  1804. attention_mask=attention_mask,
  1805. decoder_input_ids=decoder_input_ids,
  1806. decoder_attention_mask=decoder_attention_mask,
  1807. global_attention_mask=global_attention_mask,
  1808. encoder_outputs=encoder_outputs,
  1809. inputs_embeds=inputs_embeds,
  1810. decoder_inputs_embeds=decoder_inputs_embeds,
  1811. use_cache=use_cache,
  1812. output_attentions=output_attentions,
  1813. output_hidden_states=output_hidden_states,
  1814. return_dict=return_dict,
  1815. )
  1816. sequence_output = outputs[0]
  1817. logits = self.qa_outputs(sequence_output)
  1818. start_logits, end_logits = logits.split(1, dim=-1)
  1819. start_logits = start_logits.squeeze(-1).contiguous()
  1820. end_logits = end_logits.squeeze(-1).contiguous()
  1821. total_loss = None
  1822. if start_positions is not None and end_positions is not None:
  1823. # If we are on multi-GPU, split add a dimension
  1824. if len(start_positions.size()) > 1:
  1825. start_positions = start_positions.squeeze(-1)
  1826. if len(end_positions.size()) > 1:
  1827. end_positions = end_positions.squeeze(-1)
  1828. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1829. ignored_index = start_logits.size(1)
  1830. start_positions = start_positions.clamp(0, ignored_index)
  1831. end_positions = end_positions.clamp(0, ignored_index)
  1832. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1833. start_loss = loss_fct(start_logits, start_positions)
  1834. end_loss = loss_fct(end_logits, end_positions)
  1835. total_loss = (start_loss + end_loss) / 2
  1836. if not return_dict:
  1837. output = (
  1838. start_logits,
  1839. end_logits,
  1840. ) + outputs[1:]
  1841. return ((total_loss,) + output) if total_loss is not None else output
  1842. return LEDSeq2SeqQuestionAnsweringModelOutput(
  1843. loss=total_loss,
  1844. start_logits=start_logits,
  1845. end_logits=end_logits,
  1846. past_key_values=outputs.past_key_values,
  1847. decoder_hidden_states=outputs.decoder_hidden_states,
  1848. decoder_attentions=outputs.decoder_attentions,
  1849. cross_attentions=outputs.cross_attentions,
  1850. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1851. encoder_hidden_states=outputs.encoder_hidden_states,
  1852. encoder_attentions=outputs.encoder_attentions,
  1853. encoder_global_attentions=outputs.encoder_global_attentions,
  1854. )
  1855. __all__ = [
  1856. "LEDForConditionalGeneration",
  1857. "LEDForQuestionAnswering",
  1858. "LEDModel",
  1859. "LEDPreTrainedModel",
  1860. ]