modeling_electra.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377
  1. # Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ELECTRA model."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN, get_activation
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutputWithCrossAttentions,
  28. BaseModelOutputWithPastAndCrossAttentions,
  29. CausalLMOutputWithCrossAttentions,
  30. MaskedLMOutput,
  31. MultipleChoiceModelOutput,
  32. QuestionAnsweringModelOutput,
  33. SequenceClassifierOutput,
  34. TokenClassifierOutput,
  35. )
  36. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  37. from ...processing_utils import Unpack
  38. from ...pytorch_utils import apply_chunking_to_forward
  39. from ...utils import (
  40. ModelOutput,
  41. TransformersKwargs,
  42. auto_docstring,
  43. logging,
  44. )
  45. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  46. from ...utils.output_capturing import capture_outputs
  47. from .configuration_electra import ElectraConfig
  48. logger = logging.get_logger(__name__)
  49. class ElectraEmbeddings(nn.Module):
  50. """Construct the embeddings from word, position and token_type embeddings."""
  51. def __init__(self, config):
  52. super().__init__()
  53. self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
  54. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
  55. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
  56. self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
  57. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  58. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  59. self.register_buffer(
  60. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  61. )
  62. self.register_buffer(
  63. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  64. )
  65. # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
  66. def forward(
  67. self,
  68. input_ids: torch.LongTensor | None = None,
  69. token_type_ids: torch.LongTensor | None = None,
  70. position_ids: torch.LongTensor | None = None,
  71. inputs_embeds: torch.FloatTensor | None = None,
  72. past_key_values_length: int = 0,
  73. ) -> torch.Tensor:
  74. if input_ids is not None:
  75. input_shape = input_ids.size()
  76. else:
  77. input_shape = inputs_embeds.size()[:-1]
  78. batch_size, seq_length = input_shape
  79. if position_ids is None:
  80. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
  81. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  82. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  83. # issue #5664
  84. if token_type_ids is None:
  85. if hasattr(self, "token_type_ids"):
  86. # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
  87. buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
  88. buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
  89. token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
  90. else:
  91. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  92. if inputs_embeds is None:
  93. inputs_embeds = self.word_embeddings(input_ids)
  94. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  95. embeddings = inputs_embeds + token_type_embeddings
  96. position_embeddings = self.position_embeddings(position_ids)
  97. embeddings = embeddings + position_embeddings
  98. embeddings = self.LayerNorm(embeddings)
  99. embeddings = self.dropout(embeddings)
  100. return embeddings
  101. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  102. def eager_attention_forward(
  103. module: nn.Module,
  104. query: torch.Tensor,
  105. key: torch.Tensor,
  106. value: torch.Tensor,
  107. attention_mask: torch.Tensor | None,
  108. scaling: float | None = None,
  109. dropout: float = 0.0,
  110. **kwargs: Unpack[TransformersKwargs],
  111. ):
  112. if scaling is None:
  113. scaling = query.size(-1) ** -0.5
  114. # Take the dot product between "query" and "key" to get the raw attention scores.
  115. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  116. if attention_mask is not None:
  117. attn_weights = attn_weights + attention_mask
  118. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  119. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  120. attn_output = torch.matmul(attn_weights, value)
  121. attn_output = attn_output.transpose(1, 2).contiguous()
  122. return attn_output, attn_weights
  123. # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Electra
  124. class ElectraSelfAttention(nn.Module):
  125. def __init__(self, config, is_causal=False, layer_idx=None):
  126. super().__init__()
  127. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  128. raise ValueError(
  129. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  130. f"heads ({config.num_attention_heads})"
  131. )
  132. self.config = config
  133. self.num_attention_heads = config.num_attention_heads
  134. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  135. self.all_head_size = self.num_attention_heads * self.attention_head_size
  136. self.scaling = self.attention_head_size**-0.5
  137. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  138. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  139. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  140. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  141. self.is_decoder = config.is_decoder
  142. self.is_causal = is_causal
  143. self.layer_idx = layer_idx
  144. def forward(
  145. self,
  146. hidden_states: torch.Tensor,
  147. attention_mask: torch.FloatTensor | None = None,
  148. past_key_values: Cache | None = None,
  149. **kwargs: Unpack[TransformersKwargs],
  150. ) -> tuple[torch.Tensor]:
  151. input_shape = hidden_states.shape[:-1]
  152. hidden_shape = (*input_shape, -1, self.attention_head_size)
  153. # get all proj
  154. query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
  155. key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
  156. value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)
  157. if past_key_values is not None:
  158. # decoder-only bert can have a simple dynamic cache for example
  159. current_past_key_values = past_key_values
  160. if isinstance(past_key_values, EncoderDecoderCache):
  161. current_past_key_values = past_key_values.self_attention_cache
  162. # save all key/value_layer to cache to be re-used for fast auto-regressive generation
  163. key_layer, value_layer = current_past_key_values.update(key_layer, value_layer, self.layer_idx)
  164. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  165. self.config._attn_implementation, eager_attention_forward
  166. )
  167. attn_output, attn_weights = attention_interface(
  168. self,
  169. query_layer,
  170. key_layer,
  171. value_layer,
  172. attention_mask,
  173. dropout=0.0 if not self.training else self.dropout.p,
  174. scaling=self.scaling,
  175. **kwargs,
  176. )
  177. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  178. return attn_output, attn_weights
  179. # Copied from transformers.models.bert.modeling_bert.BertCrossAttention with Bert->Electra
  180. class ElectraCrossAttention(nn.Module):
  181. def __init__(self, config, is_causal=False, layer_idx=None):
  182. super().__init__()
  183. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  184. raise ValueError(
  185. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  186. f"heads ({config.num_attention_heads})"
  187. )
  188. self.config = config
  189. self.num_attention_heads = config.num_attention_heads
  190. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  191. self.all_head_size = self.num_attention_heads * self.attention_head_size
  192. self.scaling = self.attention_head_size**-0.5
  193. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  194. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  195. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  196. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  197. self.is_causal = is_causal
  198. self.layer_idx = layer_idx
  199. def forward(
  200. self,
  201. hidden_states: torch.Tensor,
  202. encoder_hidden_states: torch.FloatTensor | None = None,
  203. attention_mask: torch.FloatTensor | None = None,
  204. past_key_values: EncoderDecoderCache | None = None,
  205. **kwargs: Unpack[TransformersKwargs],
  206. ) -> tuple[torch.Tensor]:
  207. # determine input shapes
  208. input_shape = hidden_states.shape[:-1]
  209. hidden_shape = (*input_shape, -1, self.attention_head_size)
  210. # get query proj
  211. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  212. is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
  213. if past_key_values is not None and is_updated:
  214. # reuse k,v, cross_attentions
  215. key_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
  216. value_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].values
  217. else:
  218. kv_shape = (*encoder_hidden_states.shape[:-1], -1, self.attention_head_size)
  219. key_layer = self.key(encoder_hidden_states).view(kv_shape).transpose(1, 2)
  220. value_layer = self.value(encoder_hidden_states).view(kv_shape).transpose(1, 2)
  221. if past_key_values is not None:
  222. # save all states to the cache
  223. key_layer, value_layer = past_key_values.cross_attention_cache.update(
  224. key_layer, value_layer, self.layer_idx
  225. )
  226. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  227. past_key_values.is_updated[self.layer_idx] = True
  228. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  229. self.config._attn_implementation, eager_attention_forward
  230. )
  231. attn_output, attn_weights = attention_interface(
  232. self,
  233. query_layer,
  234. key_layer,
  235. value_layer,
  236. attention_mask,
  237. dropout=0.0 if not self.training else self.dropout.p,
  238. scaling=self.scaling,
  239. **kwargs,
  240. )
  241. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  242. return attn_output, attn_weights
  243. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
  244. class ElectraSelfOutput(nn.Module):
  245. def __init__(self, config):
  246. super().__init__()
  247. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  248. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  249. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  250. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  251. hidden_states = self.dense(hidden_states)
  252. hidden_states = self.dropout(hidden_states)
  253. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  254. return hidden_states
  255. # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra,BERT->ELECTRA
  256. class ElectraAttention(nn.Module):
  257. def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
  258. super().__init__()
  259. self.is_cross_attention = is_cross_attention
  260. attention_class = ElectraCrossAttention if is_cross_attention else ElectraSelfAttention
  261. self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
  262. self.output = ElectraSelfOutput(config)
  263. def forward(
  264. self,
  265. hidden_states: torch.Tensor,
  266. attention_mask: torch.FloatTensor | None = None,
  267. encoder_hidden_states: torch.FloatTensor | None = None,
  268. encoder_attention_mask: torch.FloatTensor | None = None,
  269. past_key_values: Cache | None = None,
  270. **kwargs: Unpack[TransformersKwargs],
  271. ) -> tuple[torch.Tensor]:
  272. attention_mask = attention_mask if not self.is_cross_attention else encoder_attention_mask
  273. attention_output, attn_weights = self.self(
  274. hidden_states,
  275. encoder_hidden_states=encoder_hidden_states,
  276. attention_mask=attention_mask,
  277. past_key_values=past_key_values,
  278. **kwargs,
  279. )
  280. attention_output = self.output(attention_output, hidden_states)
  281. return attention_output, attn_weights
  282. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  283. class ElectraIntermediate(nn.Module):
  284. def __init__(self, config):
  285. super().__init__()
  286. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  287. if isinstance(config.hidden_act, str):
  288. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  289. else:
  290. self.intermediate_act_fn = config.hidden_act
  291. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  292. hidden_states = self.dense(hidden_states)
  293. hidden_states = self.intermediate_act_fn(hidden_states)
  294. return hidden_states
  295. # Copied from transformers.models.bert.modeling_bert.BertOutput
  296. class ElectraOutput(nn.Module):
  297. def __init__(self, config):
  298. super().__init__()
  299. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  300. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  301. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  302. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  303. hidden_states = self.dense(hidden_states)
  304. hidden_states = self.dropout(hidden_states)
  305. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  306. return hidden_states
  307. # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Electra
  308. class ElectraLayer(GradientCheckpointingLayer):
  309. def __init__(self, config, layer_idx=None):
  310. super().__init__()
  311. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  312. self.seq_len_dim = 1
  313. self.attention = ElectraAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
  314. self.is_decoder = config.is_decoder
  315. self.add_cross_attention = config.add_cross_attention
  316. if self.add_cross_attention:
  317. if not self.is_decoder:
  318. raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
  319. self.crossattention = ElectraAttention(
  320. config,
  321. is_causal=False,
  322. layer_idx=layer_idx,
  323. is_cross_attention=True,
  324. )
  325. self.intermediate = ElectraIntermediate(config)
  326. self.output = ElectraOutput(config)
  327. def forward(
  328. self,
  329. hidden_states: torch.Tensor,
  330. attention_mask: torch.FloatTensor | None = None,
  331. encoder_hidden_states: torch.FloatTensor | None = None,
  332. encoder_attention_mask: torch.FloatTensor | None = None,
  333. past_key_values: Cache | None = None,
  334. **kwargs: Unpack[TransformersKwargs],
  335. ) -> torch.Tensor:
  336. self_attention_output, _ = self.attention(
  337. hidden_states,
  338. attention_mask,
  339. past_key_values=past_key_values,
  340. **kwargs,
  341. )
  342. attention_output = self_attention_output
  343. if self.is_decoder and encoder_hidden_states is not None:
  344. if not hasattr(self, "crossattention"):
  345. raise ValueError(
  346. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  347. " by setting `config.add_cross_attention=True`"
  348. )
  349. cross_attention_output, _ = self.crossattention(
  350. self_attention_output,
  351. None, # attention_mask
  352. encoder_hidden_states,
  353. encoder_attention_mask,
  354. past_key_values=past_key_values,
  355. **kwargs,
  356. )
  357. attention_output = cross_attention_output
  358. layer_output = apply_chunking_to_forward(
  359. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  360. )
  361. return layer_output
  362. def feed_forward_chunk(self, attention_output):
  363. intermediate_output = self.intermediate(attention_output)
  364. layer_output = self.output(intermediate_output, attention_output)
  365. return layer_output
  366. # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Electra
  367. class ElectraEncoder(nn.Module):
  368. def __init__(self, config):
  369. super().__init__()
  370. self.config = config
  371. self.layer = nn.ModuleList([ElectraLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  372. def forward(
  373. self,
  374. hidden_states: torch.Tensor,
  375. attention_mask: torch.FloatTensor | None = None,
  376. encoder_hidden_states: torch.FloatTensor | None = None,
  377. encoder_attention_mask: torch.FloatTensor | None = None,
  378. past_key_values: Cache | None = None,
  379. use_cache: bool | None = None,
  380. **kwargs: Unpack[TransformersKwargs],
  381. ) -> tuple[torch.Tensor] | BaseModelOutputWithPastAndCrossAttentions:
  382. for i, layer_module in enumerate(self.layer):
  383. hidden_states = layer_module(
  384. hidden_states,
  385. attention_mask,
  386. encoder_hidden_states, # as a positional argument for gradient checkpointing
  387. encoder_attention_mask=encoder_attention_mask,
  388. past_key_values=past_key_values,
  389. **kwargs,
  390. )
  391. return BaseModelOutputWithPastAndCrossAttentions(
  392. last_hidden_state=hidden_states,
  393. past_key_values=past_key_values if use_cache else None,
  394. )
  395. class ElectraDiscriminatorPredictions(nn.Module):
  396. """Prediction module for the discriminator, made up of two dense layers."""
  397. def __init__(self, config):
  398. super().__init__()
  399. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  400. self.activation = get_activation(config.hidden_act)
  401. self.dense_prediction = nn.Linear(config.hidden_size, 1)
  402. self.config = config
  403. def forward(self, discriminator_hidden_states):
  404. hidden_states = self.dense(discriminator_hidden_states)
  405. hidden_states = self.activation(hidden_states)
  406. logits = self.dense_prediction(hidden_states).squeeze(-1)
  407. return logits
  408. class ElectraGeneratorPredictions(nn.Module):
  409. """Prediction module for the generator, made up of two dense layers."""
  410. def __init__(self, config):
  411. super().__init__()
  412. self.activation = get_activation("gelu")
  413. self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
  414. self.dense = nn.Linear(config.hidden_size, config.embedding_size)
  415. def forward(self, generator_hidden_states):
  416. hidden_states = self.dense(generator_hidden_states)
  417. hidden_states = self.activation(hidden_states)
  418. hidden_states = self.LayerNorm(hidden_states)
  419. return hidden_states
  420. @auto_docstring
  421. class ElectraPreTrainedModel(PreTrainedModel):
  422. config_class = ElectraConfig
  423. base_model_prefix = "electra"
  424. supports_gradient_checkpointing = True
  425. _supports_flash_attn = True
  426. _supports_sdpa = True
  427. _supports_flex_attn = True
  428. _supports_attention_backend = True
  429. _can_record_outputs = {
  430. "hidden_states": ElectraLayer,
  431. "attentions": ElectraSelfAttention,
  432. "cross_attentions": ElectraCrossAttention,
  433. }
  434. def _init_weights(self, module):
  435. super()._init_weights(module)
  436. if isinstance(module, ElectraEmbeddings):
  437. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  438. init.zeros_(module.token_type_ids)
  439. @dataclass
  440. @auto_docstring(
  441. custom_intro="""
  442. Output type of [`ElectraForPreTraining`].
  443. """
  444. )
  445. class ElectraForPreTrainingOutput(ModelOutput):
  446. r"""
  447. loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  448. Total loss of the ELECTRA objective.
  449. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  450. Prediction scores of the head (scores for each token before SoftMax).
  451. """
  452. loss: torch.FloatTensor | None = None
  453. logits: torch.FloatTensor | None = None
  454. hidden_states: tuple[torch.FloatTensor] | None = None
  455. attentions: tuple[torch.FloatTensor] | None = None
  456. @auto_docstring
  457. class ElectraModel(ElectraPreTrainedModel):
  458. def __init__(self, config):
  459. super().__init__(config)
  460. self.embeddings = ElectraEmbeddings(config)
  461. if config.embedding_size != config.hidden_size:
  462. self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)
  463. self.encoder = ElectraEncoder(config)
  464. self.config = config
  465. self.gradient_checkpointing = False
  466. # Initialize weights and apply final processing
  467. self.post_init()
  468. def get_input_embeddings(self):
  469. return self.embeddings.word_embeddings
  470. def set_input_embeddings(self, value):
  471. self.embeddings.word_embeddings = value
  472. @merge_with_config_defaults
  473. @capture_outputs
  474. @auto_docstring
  475. def forward(
  476. self,
  477. input_ids: torch.Tensor | None = None,
  478. attention_mask: torch.Tensor | None = None,
  479. token_type_ids: torch.Tensor | None = None,
  480. position_ids: torch.Tensor | None = None,
  481. inputs_embeds: torch.Tensor | None = None,
  482. encoder_hidden_states: torch.Tensor | None = None,
  483. encoder_attention_mask: torch.Tensor | None = None,
  484. past_key_values: list[torch.FloatTensor] | None = None,
  485. use_cache: bool | None = None,
  486. **kwargs: Unpack[TransformersKwargs],
  487. ) -> tuple[torch.Tensor] | BaseModelOutputWithCrossAttentions:
  488. if (input_ids is None) ^ (inputs_embeds is not None):
  489. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  490. if self.config.is_decoder:
  491. use_cache = use_cache if use_cache is not None else self.config.use_cache
  492. else:
  493. use_cache = False
  494. if use_cache and past_key_values is None:
  495. past_key_values = (
  496. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  497. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  498. else DynamicCache(config=self.config)
  499. )
  500. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  501. embedding_output = self.embeddings(
  502. input_ids=input_ids,
  503. position_ids=position_ids,
  504. token_type_ids=token_type_ids,
  505. inputs_embeds=inputs_embeds,
  506. past_key_values_length=past_key_values_length,
  507. )
  508. if hasattr(self, "embeddings_project"):
  509. embedding_output = self.embeddings_project(embedding_output)
  510. attention_mask, encoder_attention_mask = self._create_attention_masks(
  511. attention_mask=attention_mask,
  512. encoder_attention_mask=encoder_attention_mask,
  513. embedding_output=embedding_output,
  514. encoder_hidden_states=encoder_hidden_states,
  515. past_key_values=past_key_values,
  516. )
  517. encoder_outputs = self.encoder(
  518. embedding_output,
  519. attention_mask=attention_mask,
  520. encoder_hidden_states=encoder_hidden_states,
  521. encoder_attention_mask=encoder_attention_mask,
  522. past_key_values=past_key_values,
  523. use_cache=use_cache,
  524. position_ids=position_ids,
  525. **kwargs,
  526. )
  527. return BaseModelOutputWithPastAndCrossAttentions(
  528. last_hidden_state=encoder_outputs.last_hidden_state,
  529. past_key_values=encoder_outputs.past_key_values,
  530. )
  531. # Copied from transformers.models.bert.modeling_bert.BertModel._create_attention_masks
  532. def _create_attention_masks(
  533. self,
  534. attention_mask,
  535. encoder_attention_mask,
  536. embedding_output,
  537. encoder_hidden_states,
  538. past_key_values,
  539. ):
  540. if self.config.is_decoder:
  541. attention_mask = create_causal_mask(
  542. config=self.config,
  543. inputs_embeds=embedding_output,
  544. attention_mask=attention_mask,
  545. past_key_values=past_key_values,
  546. )
  547. else:
  548. attention_mask = create_bidirectional_mask(
  549. config=self.config,
  550. inputs_embeds=embedding_output,
  551. attention_mask=attention_mask,
  552. )
  553. if encoder_attention_mask is not None:
  554. encoder_attention_mask = create_bidirectional_mask(
  555. config=self.config,
  556. inputs_embeds=embedding_output,
  557. attention_mask=encoder_attention_mask,
  558. encoder_hidden_states=encoder_hidden_states,
  559. )
  560. return attention_mask, encoder_attention_mask
  561. class ElectraClassificationHead(nn.Module):
  562. """Head for sentence-level classification tasks."""
  563. def __init__(self, config):
  564. super().__init__()
  565. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  566. classifier_dropout = (
  567. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  568. )
  569. self.activation = get_activation("gelu")
  570. self.dropout = nn.Dropout(classifier_dropout)
  571. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  572. def forward(self, features, **kwargs):
  573. x = features[:, 0, :] # take <s> token (equiv. to [CLS])
  574. x = self.dropout(x)
  575. x = self.dense(x)
  576. x = self.activation(x) # although BERT uses tanh here, it seems Electra authors used gelu here
  577. x = self.dropout(x)
  578. x = self.out_proj(x)
  579. return x
  580. # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Electra
  581. class ElectraSequenceSummary(nn.Module):
  582. r"""
  583. Compute a single vector summary of a sequence hidden states.
  584. Args:
  585. config ([`ElectraConfig`]):
  586. The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
  587. config class of your model for the default values it uses):
  588. - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
  589. - `"last"` -- Take the last token hidden state (like XLNet)
  590. - `"first"` -- Take the first token hidden state (like Bert)
  591. - `"mean"` -- Take the mean of all tokens hidden states
  592. - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
  593. - `"attn"` -- Not implemented now, use multi-head attention
  594. - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
  595. - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
  596. (otherwise to `config.hidden_size`).
  597. - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
  598. another string or `None` will add no activation.
  599. - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
  600. - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
  601. """
  602. def __init__(self, config: ElectraConfig):
  603. super().__init__()
  604. self.summary_type = getattr(config, "summary_type", "last")
  605. if self.summary_type == "attn":
  606. # We should use a standard multi-head attention module with absolute positional embedding for that.
  607. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  608. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  609. raise NotImplementedError
  610. self.summary = nn.Identity()
  611. if hasattr(config, "summary_use_proj") and config.summary_use_proj:
  612. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  613. num_classes = config.num_labels
  614. else:
  615. num_classes = config.hidden_size
  616. self.summary = nn.Linear(config.hidden_size, num_classes)
  617. activation_string = getattr(config, "summary_activation", None)
  618. self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
  619. self.first_dropout = nn.Identity()
  620. if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
  621. self.first_dropout = nn.Dropout(config.summary_first_dropout)
  622. self.last_dropout = nn.Identity()
  623. if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
  624. self.last_dropout = nn.Dropout(config.summary_last_dropout)
  625. def forward(
  626. self, hidden_states: torch.FloatTensor, cls_index: torch.LongTensor | None = None
  627. ) -> torch.FloatTensor:
  628. """
  629. Compute a single vector summary of a sequence hidden states.
  630. Args:
  631. hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
  632. The hidden states of the last layer.
  633. cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
  634. Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
  635. Returns:
  636. `torch.FloatTensor`: The summary of the sequence hidden states.
  637. """
  638. if self.summary_type == "last":
  639. output = hidden_states[:, -1]
  640. elif self.summary_type == "first":
  641. output = hidden_states[:, 0]
  642. elif self.summary_type == "mean":
  643. output = hidden_states.mean(dim=1)
  644. elif self.summary_type == "cls_index":
  645. if cls_index is None:
  646. cls_index = torch.full_like(
  647. hidden_states[..., :1, :],
  648. hidden_states.shape[-2] - 1,
  649. dtype=torch.long,
  650. )
  651. else:
  652. cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
  653. cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
  654. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  655. output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
  656. elif self.summary_type == "attn":
  657. raise NotImplementedError
  658. output = self.first_dropout(output)
  659. output = self.summary(output)
  660. output = self.activation(output)
  661. output = self.last_dropout(output)
  662. return output
  663. @auto_docstring(
  664. custom_intro="""
  665. ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  666. pooled output) e.g. for GLUE tasks.
  667. """
  668. )
  669. class ElectraForSequenceClassification(ElectraPreTrainedModel):
  670. def __init__(self, config):
  671. super().__init__(config)
  672. self.num_labels = config.num_labels
  673. self.config = config
  674. self.electra = ElectraModel(config)
  675. self.classifier = ElectraClassificationHead(config)
  676. # Initialize weights and apply final processing
  677. self.post_init()
  678. @can_return_tuple
  679. @auto_docstring
  680. def forward(
  681. self,
  682. input_ids: torch.Tensor | None = None,
  683. attention_mask: torch.Tensor | None = None,
  684. token_type_ids: torch.Tensor | None = None,
  685. position_ids: torch.Tensor | None = None,
  686. inputs_embeds: torch.Tensor | None = None,
  687. labels: torch.Tensor | None = None,
  688. **kwargs: Unpack[TransformersKwargs],
  689. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  690. r"""
  691. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  692. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  693. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  694. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  695. """
  696. discriminator_hidden_states = self.electra(
  697. input_ids,
  698. attention_mask=attention_mask,
  699. token_type_ids=token_type_ids,
  700. position_ids=position_ids,
  701. inputs_embeds=inputs_embeds,
  702. return_dict=True,
  703. **kwargs,
  704. )
  705. sequence_output = discriminator_hidden_states[0]
  706. logits = self.classifier(sequence_output)
  707. loss = None
  708. if labels is not None:
  709. if self.config.problem_type is None:
  710. if self.num_labels == 1:
  711. self.config.problem_type = "regression"
  712. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  713. self.config.problem_type = "single_label_classification"
  714. else:
  715. self.config.problem_type = "multi_label_classification"
  716. if self.config.problem_type == "regression":
  717. loss_fct = MSELoss()
  718. if self.num_labels == 1:
  719. loss = loss_fct(logits.squeeze(), labels.squeeze())
  720. else:
  721. loss = loss_fct(logits, labels)
  722. elif self.config.problem_type == "single_label_classification":
  723. loss_fct = CrossEntropyLoss()
  724. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  725. elif self.config.problem_type == "multi_label_classification":
  726. loss_fct = BCEWithLogitsLoss()
  727. loss = loss_fct(logits, labels)
  728. return SequenceClassifierOutput(
  729. loss=loss,
  730. logits=logits,
  731. hidden_states=discriminator_hidden_states.hidden_states,
  732. attentions=discriminator_hidden_states.attentions,
  733. )
  734. @auto_docstring(
  735. custom_intro="""
  736. Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.
  737. It is recommended to load the discriminator checkpoint into that model.
  738. """
  739. )
  740. class ElectraForPreTraining(ElectraPreTrainedModel):
  741. def __init__(self, config):
  742. super().__init__(config)
  743. self.electra = ElectraModel(config)
  744. self.discriminator_predictions = ElectraDiscriminatorPredictions(config)
  745. # Initialize weights and apply final processing
  746. self.post_init()
  747. @can_return_tuple
  748. @auto_docstring
  749. def forward(
  750. self,
  751. input_ids: torch.Tensor | None = None,
  752. attention_mask: torch.Tensor | None = None,
  753. token_type_ids: torch.Tensor | None = None,
  754. position_ids: torch.Tensor | None = None,
  755. inputs_embeds: torch.Tensor | None = None,
  756. labels: torch.Tensor | None = None,
  757. **kwargs: Unpack[TransformersKwargs],
  758. ) -> tuple[torch.Tensor] | ElectraForPreTrainingOutput:
  759. r"""
  760. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  761. Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see `input_ids` docstring)
  762. Indices should be in `[0, 1]`:
  763. - 0 indicates the token is an original token,
  764. - 1 indicates the token was replaced.
  765. Examples:
  766. ```python
  767. >>> from transformers import ElectraForPreTraining, AutoTokenizer
  768. >>> import torch
  769. >>> discriminator = ElectraForPreTraining.from_pretrained("google/electra-base-discriminator")
  770. >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
  771. >>> sentence = "The quick brown fox jumps over the lazy dog"
  772. >>> fake_sentence = "The quick brown fox fake over the lazy dog"
  773. >>> fake_tokens = tokenizer.tokenize(fake_sentence, add_special_tokens=True)
  774. >>> fake_inputs = tokenizer.encode(fake_sentence, return_tensors="pt")
  775. >>> discriminator_outputs = discriminator(fake_inputs)
  776. >>> predictions = torch.round((torch.sign(discriminator_outputs[0]) + 1) / 2)
  777. >>> fake_tokens
  778. ['[CLS]', 'the', 'quick', 'brown', 'fox', 'fake', 'over', 'the', 'lazy', 'dog', '[SEP]']
  779. >>> predictions.squeeze().tolist()
  780. [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
  781. ```"""
  782. discriminator_hidden_states = self.electra(
  783. input_ids,
  784. attention_mask=attention_mask,
  785. token_type_ids=token_type_ids,
  786. position_ids=position_ids,
  787. inputs_embeds=inputs_embeds,
  788. return_dict=True,
  789. **kwargs,
  790. )
  791. discriminator_sequence_output = discriminator_hidden_states[0]
  792. logits = self.discriminator_predictions(discriminator_sequence_output)
  793. loss = None
  794. if labels is not None:
  795. loss_fct = nn.BCEWithLogitsLoss()
  796. if attention_mask is not None:
  797. active_loss = attention_mask.view(-1, discriminator_sequence_output.shape[1]) == 1
  798. active_logits = logits.view(-1, discriminator_sequence_output.shape[1])[active_loss]
  799. active_labels = labels[active_loss]
  800. loss = loss_fct(active_logits, active_labels.float())
  801. else:
  802. loss = loss_fct(logits.view(-1, discriminator_sequence_output.shape[1]), labels.float())
  803. return ElectraForPreTrainingOutput(
  804. loss=loss,
  805. logits=logits,
  806. hidden_states=discriminator_hidden_states.hidden_states,
  807. attentions=discriminator_hidden_states.attentions,
  808. )
  809. @auto_docstring(
  810. custom_intro="""
  811. Electra model with a language modeling head on top.
  812. Even though both the discriminator and generator may be loaded into this model, the generator is the only model of
  813. the two to have been trained for the masked language modeling task.
  814. """
  815. )
  816. class ElectraForMaskedLM(ElectraPreTrainedModel):
  817. _tied_weights_keys = {"generator_lm_head.weight": "electra.embeddings.word_embeddings.weight"}
  818. def __init__(self, config):
  819. super().__init__(config)
  820. self.electra = ElectraModel(config)
  821. self.generator_predictions = ElectraGeneratorPredictions(config)
  822. self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
  823. # Initialize weights and apply final processing
  824. self.post_init()
  825. def get_output_embeddings(self):
  826. return self.generator_lm_head
  827. def set_output_embeddings(self, word_embeddings):
  828. self.generator_lm_head = word_embeddings
  829. @can_return_tuple
  830. @auto_docstring
  831. def forward(
  832. self,
  833. input_ids: torch.Tensor | None = None,
  834. attention_mask: torch.Tensor | None = None,
  835. token_type_ids: torch.Tensor | None = None,
  836. position_ids: torch.Tensor | None = None,
  837. inputs_embeds: torch.Tensor | None = None,
  838. labels: torch.Tensor | None = None,
  839. **kwargs: Unpack[TransformersKwargs],
  840. ) -> tuple[torch.Tensor] | MaskedLMOutput:
  841. r"""
  842. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  843. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  844. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  845. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  846. """
  847. generator_hidden_states = self.electra(
  848. input_ids,
  849. attention_mask=attention_mask,
  850. token_type_ids=token_type_ids,
  851. position_ids=position_ids,
  852. inputs_embeds=inputs_embeds,
  853. return_dict=True,
  854. **kwargs,
  855. )
  856. generator_sequence_output = generator_hidden_states[0]
  857. prediction_scores = self.generator_predictions(generator_sequence_output)
  858. prediction_scores = self.generator_lm_head(prediction_scores)
  859. loss = None
  860. # Masked language modeling softmax layer
  861. if labels is not None:
  862. loss_fct = nn.CrossEntropyLoss() # -100 index = padding token
  863. loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  864. return MaskedLMOutput(
  865. loss=loss,
  866. logits=prediction_scores,
  867. hidden_states=generator_hidden_states.hidden_states,
  868. attentions=generator_hidden_states.attentions,
  869. )
  870. @auto_docstring(
  871. custom_intro="""
  872. Electra model with a token classification head on top.
  873. Both the discriminator and generator may be loaded into this model.
  874. """
  875. )
  876. class ElectraForTokenClassification(ElectraPreTrainedModel):
  877. def __init__(self, config):
  878. super().__init__(config)
  879. self.num_labels = config.num_labels
  880. self.electra = ElectraModel(config)
  881. classifier_dropout = (
  882. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  883. )
  884. self.dropout = nn.Dropout(classifier_dropout)
  885. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  886. # Initialize weights and apply final processing
  887. self.post_init()
  888. @can_return_tuple
  889. @auto_docstring
  890. def forward(
  891. self,
  892. input_ids: torch.Tensor | None = None,
  893. attention_mask: torch.Tensor | None = None,
  894. token_type_ids: torch.Tensor | None = None,
  895. position_ids: torch.Tensor | None = None,
  896. inputs_embeds: torch.Tensor | None = None,
  897. labels: torch.Tensor | None = None,
  898. **kwargs: Unpack[TransformersKwargs],
  899. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  900. r"""
  901. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  902. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  903. """
  904. discriminator_hidden_states = self.electra(
  905. input_ids,
  906. attention_mask=attention_mask,
  907. token_type_ids=token_type_ids,
  908. position_ids=position_ids,
  909. inputs_embeds=inputs_embeds,
  910. return_dict=True,
  911. **kwargs,
  912. )
  913. discriminator_sequence_output = discriminator_hidden_states[0]
  914. discriminator_sequence_output = self.dropout(discriminator_sequence_output)
  915. logits = self.classifier(discriminator_sequence_output)
  916. loss = None
  917. if labels is not None:
  918. loss_fct = CrossEntropyLoss()
  919. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  920. return TokenClassifierOutput(
  921. loss=loss,
  922. logits=logits,
  923. hidden_states=discriminator_hidden_states.hidden_states,
  924. attentions=discriminator_hidden_states.attentions,
  925. )
  926. @auto_docstring
  927. class ElectraForQuestionAnswering(ElectraPreTrainedModel):
  928. config_class = ElectraConfig
  929. base_model_prefix = "electra"
  930. def __init__(self, config):
  931. super().__init__(config)
  932. self.num_labels = config.num_labels
  933. self.electra = ElectraModel(config)
  934. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  935. # Initialize weights and apply final processing
  936. self.post_init()
  937. @can_return_tuple
  938. @auto_docstring
  939. def forward(
  940. self,
  941. input_ids: torch.Tensor | None = None,
  942. attention_mask: torch.Tensor | None = None,
  943. token_type_ids: torch.Tensor | None = None,
  944. position_ids: torch.Tensor | None = None,
  945. inputs_embeds: torch.Tensor | None = None,
  946. start_positions: torch.Tensor | None = None,
  947. end_positions: torch.Tensor | None = None,
  948. **kwargs: Unpack[TransformersKwargs],
  949. ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
  950. discriminator_hidden_states = self.electra(
  951. input_ids,
  952. attention_mask=attention_mask,
  953. token_type_ids=token_type_ids,
  954. position_ids=position_ids,
  955. inputs_embeds=inputs_embeds,
  956. return_dict=True,
  957. **kwargs,
  958. )
  959. sequence_output = discriminator_hidden_states[0]
  960. logits = self.qa_outputs(sequence_output)
  961. start_logits, end_logits = logits.split(1, dim=-1)
  962. start_logits = start_logits.squeeze(-1).contiguous()
  963. end_logits = end_logits.squeeze(-1).contiguous()
  964. total_loss = None
  965. if start_positions is not None and end_positions is not None:
  966. # If we are on multi-GPU, split add a dimension
  967. if len(start_positions.size()) > 1:
  968. start_positions = start_positions.squeeze(-1)
  969. if len(end_positions.size()) > 1:
  970. end_positions = end_positions.squeeze(-1)
  971. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  972. ignored_index = start_logits.size(1)
  973. start_positions = start_positions.clamp(0, ignored_index)
  974. end_positions = end_positions.clamp(0, ignored_index)
  975. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  976. start_loss = loss_fct(start_logits, start_positions)
  977. end_loss = loss_fct(end_logits, end_positions)
  978. total_loss = (start_loss + end_loss) / 2
  979. return QuestionAnsweringModelOutput(
  980. loss=total_loss,
  981. start_logits=start_logits,
  982. end_logits=end_logits,
  983. hidden_states=discriminator_hidden_states.hidden_states,
  984. attentions=discriminator_hidden_states.attentions,
  985. )
  986. @auto_docstring
  987. class ElectraForMultipleChoice(ElectraPreTrainedModel):
  988. def __init__(self, config):
  989. super().__init__(config)
  990. self.electra = ElectraModel(config)
  991. self.sequence_summary = ElectraSequenceSummary(config)
  992. self.classifier = nn.Linear(config.hidden_size, 1)
  993. # Initialize weights and apply final processing
  994. self.post_init()
  995. @can_return_tuple
  996. @auto_docstring
  997. def forward(
  998. self,
  999. input_ids: torch.Tensor | None = None,
  1000. attention_mask: torch.Tensor | None = None,
  1001. token_type_ids: torch.Tensor | None = None,
  1002. position_ids: torch.Tensor | None = None,
  1003. inputs_embeds: torch.Tensor | None = None,
  1004. labels: torch.Tensor | None = None,
  1005. **kwargs: Unpack[TransformersKwargs],
  1006. ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
  1007. r"""
  1008. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  1009. Indices of input sequence tokens in the vocabulary.
  1010. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1011. [`PreTrainedTokenizer.__call__`] for details.
  1012. [What are input IDs?](../glossary#input-ids)
  1013. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1014. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1015. 1]`:
  1016. - 0 corresponds to a *sentence A* token,
  1017. - 1 corresponds to a *sentence B* token.
  1018. [What are token type IDs?](../glossary#token-type-ids)
  1019. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1020. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1021. config.max_position_embeddings - 1]`.
  1022. [What are position IDs?](../glossary#position-ids)
  1023. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  1024. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1025. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1026. model's internal embedding lookup matrix.
  1027. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1028. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1029. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1030. `input_ids` above)
  1031. """
  1032. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1033. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1034. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1035. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1036. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1037. inputs_embeds = (
  1038. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1039. if inputs_embeds is not None
  1040. else None
  1041. )
  1042. discriminator_hidden_states = self.electra(
  1043. input_ids,
  1044. attention_mask=attention_mask,
  1045. token_type_ids=token_type_ids,
  1046. position_ids=position_ids,
  1047. inputs_embeds=inputs_embeds,
  1048. return_dict=True,
  1049. **kwargs,
  1050. )
  1051. sequence_output = discriminator_hidden_states[0]
  1052. pooled_output = self.sequence_summary(sequence_output)
  1053. logits = self.classifier(pooled_output)
  1054. reshaped_logits = logits.view(-1, num_choices)
  1055. loss = None
  1056. if labels is not None:
  1057. loss_fct = CrossEntropyLoss()
  1058. loss = loss_fct(reshaped_logits, labels)
  1059. return MultipleChoiceModelOutput(
  1060. loss=loss,
  1061. logits=reshaped_logits,
  1062. hidden_states=discriminator_hidden_states.hidden_states,
  1063. attentions=discriminator_hidden_states.attentions,
  1064. )
  1065. @auto_docstring(
  1066. custom_intro="""
  1067. ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.
  1068. """
  1069. )
  1070. class ElectraForCausalLM(ElectraPreTrainedModel, GenerationMixin):
  1071. _tied_weights_keys = {"generator_lm_head.weight": "electra.embeddings.word_embeddings.weight"}
  1072. def __init__(self, config):
  1073. super().__init__(config)
  1074. if not config.is_decoder:
  1075. logger.warning("If you want to use `ElectraForCausalLM` as a standalone, add `is_decoder=True.`")
  1076. self.electra = ElectraModel(config)
  1077. self.generator_predictions = ElectraGeneratorPredictions(config)
  1078. self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
  1079. self.post_init()
  1080. def get_output_embeddings(self):
  1081. return self.generator_lm_head
  1082. def set_output_embeddings(self, new_embeddings):
  1083. self.generator_lm_head = new_embeddings
  1084. @can_return_tuple
  1085. @auto_docstring
  1086. def forward(
  1087. self,
  1088. input_ids: torch.Tensor | None = None,
  1089. attention_mask: torch.Tensor | None = None,
  1090. token_type_ids: torch.Tensor | None = None,
  1091. position_ids: torch.Tensor | None = None,
  1092. inputs_embeds: torch.Tensor | None = None,
  1093. encoder_hidden_states: torch.Tensor | None = None,
  1094. encoder_attention_mask: torch.Tensor | None = None,
  1095. labels: torch.Tensor | None = None,
  1096. past_key_values: Cache | None = None,
  1097. use_cache: bool | None = None,
  1098. logits_to_keep: int | torch.Tensor = 0,
  1099. **kwargs: Unpack[TransformersKwargs],
  1100. ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
  1101. r"""
  1102. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1103. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  1104. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  1105. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  1106. Example:
  1107. ```python
  1108. >>> from transformers import AutoTokenizer, ElectraForCausalLM, ElectraConfig
  1109. >>> import torch
  1110. >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-generator")
  1111. >>> config = ElectraConfig.from_pretrained("google/electra-base-generator")
  1112. >>> config.is_decoder = True
  1113. >>> model = ElectraForCausalLM.from_pretrained("google/electra-base-generator", config=config)
  1114. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1115. >>> outputs = model(**inputs)
  1116. >>> prediction_logits = outputs.logits
  1117. ```"""
  1118. if labels is not None:
  1119. use_cache = False
  1120. outputs: BaseModelOutputWithPastAndCrossAttentions = self.electra(
  1121. input_ids,
  1122. attention_mask=attention_mask,
  1123. token_type_ids=token_type_ids,
  1124. position_ids=position_ids,
  1125. inputs_embeds=inputs_embeds,
  1126. encoder_hidden_states=encoder_hidden_states,
  1127. encoder_attention_mask=encoder_attention_mask,
  1128. past_key_values=past_key_values,
  1129. use_cache=use_cache,
  1130. return_dict=True,
  1131. **kwargs,
  1132. )
  1133. hidden_states = outputs.last_hidden_state
  1134. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1135. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1136. logits = self.generator_lm_head(self.generator_predictions(hidden_states[:, slice_indices, :]))
  1137. loss = None
  1138. if labels is not None:
  1139. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  1140. return CausalLMOutputWithCrossAttentions(
  1141. loss=loss,
  1142. logits=logits,
  1143. past_key_values=outputs.past_key_values,
  1144. hidden_states=outputs.hidden_states,
  1145. attentions=outputs.attentions,
  1146. cross_attentions=outputs.cross_attentions,
  1147. )
# Public names exported by this module.
# NOTE(review): presumably read statically by transformers' lazy-import tooling — keep it a plain literal list.
__all__ = [
    "ElectraForCausalLM",
    "ElectraForMaskedLM",
    "ElectraForMultipleChoice",
    "ElectraForPreTraining",
    "ElectraForQuestionAnswering",
    "ElectraForSequenceClassification",
    "ElectraForTokenClassification",
    "ElectraModel",
    "ElectraPreTrainedModel",
]