modeling_rembert.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142
  1. # Copyright 2021 The HuggingFace Team The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch RemBERT model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  22. from ...generation import GenerationMixin
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import (
  25. BaseModelOutputWithPastAndCrossAttentions,
  26. BaseModelOutputWithPoolingAndCrossAttentions,
  27. CausalLMOutputWithCrossAttentions,
  28. MaskedLMOutput,
  29. MultipleChoiceModelOutput,
  30. QuestionAnsweringModelOutput,
  31. SequenceClassifierOutput,
  32. TokenClassifierOutput,
  33. )
  34. from ...modeling_utils import PreTrainedModel
  35. from ...pytorch_utils import apply_chunking_to_forward
  36. from ...utils import auto_docstring, logging
  37. from .configuration_rembert import RemBertConfig
# Module-level logger from the library's logging utilities; used in this file for the
# gradient-checkpointing/use_cache warning and the is_decoder configuration warnings.
logger = logging.get_logger(__name__)
  39. class RemBertEmbeddings(nn.Module):
  40. """Construct the embeddings from word, position and token_type embeddings."""
  41. def __init__(self, config):
  42. super().__init__()
  43. self.word_embeddings = nn.Embedding(
  44. config.vocab_size, config.input_embedding_size, padding_idx=config.pad_token_id
  45. )
  46. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.input_embedding_size)
  47. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.input_embedding_size)
  48. self.LayerNorm = nn.LayerNorm(config.input_embedding_size, eps=config.layer_norm_eps)
  49. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  50. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  51. self.register_buffer(
  52. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  53. )
  54. def forward(
  55. self,
  56. input_ids: torch.LongTensor | None = None,
  57. token_type_ids: torch.LongTensor | None = None,
  58. position_ids: torch.LongTensor | None = None,
  59. inputs_embeds: torch.FloatTensor | None = None,
  60. past_key_values_length: int = 0,
  61. ) -> torch.Tensor:
  62. if input_ids is not None:
  63. input_shape = input_ids.size()
  64. else:
  65. input_shape = inputs_embeds.size()[:-1]
  66. seq_length = input_shape[1]
  67. if position_ids is None:
  68. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
  69. if token_type_ids is None:
  70. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  71. if inputs_embeds is None:
  72. inputs_embeds = self.word_embeddings(input_ids)
  73. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  74. embeddings = inputs_embeds + token_type_embeddings
  75. position_embeddings = self.position_embeddings(position_ids)
  76. embeddings += position_embeddings
  77. embeddings = self.LayerNorm(embeddings)
  78. embeddings = self.dropout(embeddings)
  79. return embeddings
# RemBertPooler: takes the hidden state at sequence position 0 (`hidden_states[:, 0]`), applies a
# hidden_size -> hidden_size Linear layer followed by tanh, and returns the result.
# Input presumably (batch, seq_len, hidden_size), giving (batch, hidden_size) -- TODO confirm with callers.
# RemBertModel only builds it when `add_pooling_layer=True`.
# NOTE(review): the class below carries a "Copied from" marker and should stay textually identical to
# the BERT original (presumably checked by the repo's copy-consistency tooling), so it is documented
# here, above the marker, rather than inside the class.
# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->RemBert
class RemBertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
  93. class RemBertSelfAttention(nn.Module):
  94. def __init__(self, config, layer_idx=None):
  95. super().__init__()
  96. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  97. raise ValueError(
  98. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  99. f"heads ({config.num_attention_heads})"
  100. )
  101. self.num_attention_heads = config.num_attention_heads
  102. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  103. self.all_head_size = self.num_attention_heads * self.attention_head_size
  104. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  105. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  106. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  107. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  108. self.is_decoder = config.is_decoder
  109. self.layer_idx = layer_idx
  110. def forward(
  111. self,
  112. hidden_states: torch.Tensor,
  113. attention_mask: torch.FloatTensor | None = None,
  114. encoder_hidden_states: torch.FloatTensor | None = None,
  115. past_key_values: Cache | None = None,
  116. output_attentions: bool = False,
  117. **kwargs,
  118. ) -> tuple:
  119. input_shape = hidden_states.shape[:-1]
  120. hidden_shape = (*input_shape, -1, self.attention_head_size)
  121. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  122. is_updated = False
  123. is_cross_attention = encoder_hidden_states is not None
  124. if past_key_values is not None:
  125. if isinstance(past_key_values, EncoderDecoderCache):
  126. is_updated = past_key_values.is_updated.get(self.layer_idx)
  127. if is_cross_attention:
  128. # after the first generated id, we can subsequently re-use all key/value_layer from cache
  129. curr_past_key_values = past_key_values.cross_attention_cache
  130. else:
  131. curr_past_key_values = past_key_values.self_attention_cache
  132. else:
  133. curr_past_key_values = past_key_values
  134. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  135. if is_cross_attention and past_key_values is not None and is_updated:
  136. # reuse k,v, cross_attentions
  137. key_layer = curr_past_key_values.layers[self.layer_idx].keys
  138. value_layer = curr_past_key_values.layers[self.layer_idx].values
  139. else:
  140. kv_shape = (*current_states.shape[:-1], -1, self.attention_head_size)
  141. key_layer = self.key(current_states).view(kv_shape).transpose(1, 2)
  142. value_layer = self.value(current_states).view(kv_shape).transpose(1, 2)
  143. if past_key_values is not None:
  144. # save all key/value_layer to cache to be re-used for fast auto-regressive generation
  145. key_layer, value_layer = curr_past_key_values.update(key_layer, value_layer, self.layer_idx)
  146. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  147. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  148. past_key_values.is_updated[self.layer_idx] = True
  149. # Take the dot product between "query" and "key" to get the raw attention scores.
  150. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  151. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  152. if attention_mask is not None:
  153. # Apply the attention mask is (precomputed for all layers in RemBertModel forward() function)
  154. attention_scores = attention_scores + attention_mask
  155. # Normalize the attention scores to probabilities.
  156. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  157. # This is actually dropping out entire tokens to attend to, which might
  158. # seem a bit unusual, but is taken from the original Transformer paper.
  159. attention_probs = self.dropout(attention_probs)
  160. context_layer = torch.matmul(attention_probs, value_layer)
  161. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  162. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  163. context_layer = context_layer.view(*new_context_layer_shape)
  164. return context_layer, attention_probs
# RemBertSelfOutput: post-attention projection. `forward(hidden_states, input_tensor)` applies a
# hidden_size -> hidden_size Linear layer and dropout to `hidden_states`, adds the residual
# `input_tensor`, then applies LayerNorm. RemBertAttention calls it with the attention context and
# the attention sub-layer's input.
# NOTE(review): keep the class text identical to the BERT original named in the marker below;
# document changes here, above the marker.
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RemBert
class RemBertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
  177. class RemBertAttention(nn.Module):
  178. def __init__(self, config, layer_idx=None):
  179. super().__init__()
  180. self.self = RemBertSelfAttention(config, layer_idx=layer_idx)
  181. self.output = RemBertSelfOutput(config)
  182. def forward(
  183. self,
  184. hidden_states: torch.Tensor,
  185. attention_mask: torch.FloatTensor | None = None,
  186. encoder_hidden_states: torch.FloatTensor | None = None,
  187. past_key_values: Cache | None = None,
  188. output_attentions: bool | None = False,
  189. **kwargs,
  190. ) -> tuple[torch.Tensor]:
  191. self_outputs = self.self(
  192. hidden_states,
  193. attention_mask=attention_mask,
  194. encoder_hidden_states=encoder_hidden_states,
  195. past_key_values=past_key_values,
  196. output_attentions=output_attentions,
  197. )
  198. attention_output = self.output(self_outputs[0], hidden_states)
  199. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  200. return outputs
# RemBertIntermediate: first half of the feed-forward network. It is a hidden_size -> intermediate_size
# Linear layer followed by the activation from `config.hidden_act`, which is either an ACT2FN key
# (string) or a callable used as-is.
# NOTE(review): keep the class text identical to the BERT original named in the marker below;
# document changes here, above the marker.
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->RemBert
class RemBertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
# RemBertOutput: second half of the feed-forward network. `forward(hidden_states, input_tensor)`
# projects intermediate_size back to hidden_size, applies dropout, adds the residual `input_tensor`,
# then applies LayerNorm.
# NOTE(review): keep the class text identical to the BERT original named in the marker below;
# document changes here, above the marker.
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->RemBert
class RemBertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
  226. class RemBertLayer(GradientCheckpointingLayer):
  227. def __init__(self, config, layer_idx=None):
  228. super().__init__()
  229. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  230. self.seq_len_dim = 1
  231. self.attention = RemBertAttention(config, layer_idx)
  232. self.is_decoder = config.is_decoder
  233. self.add_cross_attention = config.add_cross_attention
  234. if self.add_cross_attention:
  235. if not self.is_decoder:
  236. raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
  237. self.crossattention = RemBertAttention(config, layer_idx=layer_idx)
  238. self.intermediate = RemBertIntermediate(config)
  239. self.output = RemBertOutput(config)
  240. # copied from transformers.models.bert.modeling_bert.BertLayer.forward
  241. def forward(
  242. self,
  243. hidden_states: torch.Tensor,
  244. attention_mask: torch.FloatTensor | None = None,
  245. encoder_hidden_states: torch.FloatTensor | None = None,
  246. encoder_attention_mask: torch.FloatTensor | None = None,
  247. past_key_values: Cache | None = None,
  248. output_attentions: bool | None = False,
  249. **kwargs,
  250. ) -> tuple[torch.Tensor]:
  251. self_attention_outputs = self.attention(
  252. hidden_states,
  253. attention_mask=attention_mask,
  254. output_attentions=output_attentions,
  255. past_key_values=past_key_values,
  256. )
  257. attention_output = self_attention_outputs[0]
  258. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  259. if self.is_decoder and encoder_hidden_states is not None:
  260. if not hasattr(self, "crossattention"):
  261. raise ValueError(
  262. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  263. " by setting `config.add_cross_attention=True`"
  264. )
  265. cross_attention_outputs = self.crossattention(
  266. attention_output,
  267. attention_mask=encoder_attention_mask,
  268. encoder_hidden_states=encoder_hidden_states,
  269. past_key_values=past_key_values,
  270. output_attentions=output_attentions,
  271. )
  272. attention_output = cross_attention_outputs[0]
  273. outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
  274. layer_output = apply_chunking_to_forward(
  275. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  276. )
  277. outputs = (layer_output,) + outputs
  278. return outputs
  279. # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
  280. def feed_forward_chunk(self, attention_output):
  281. intermediate_output = self.intermediate(attention_output)
  282. layer_output = self.output(intermediate_output, attention_output)
  283. return layer_output
  284. class RemBertEncoder(nn.Module):
  285. def __init__(self, config):
  286. super().__init__()
  287. self.config = config
  288. self.embedding_hidden_mapping_in = nn.Linear(config.input_embedding_size, config.hidden_size)
  289. self.layer = nn.ModuleList([RemBertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  290. self.gradient_checkpointing = False
  291. def forward(
  292. self,
  293. hidden_states: torch.Tensor,
  294. attention_mask: torch.FloatTensor | None = None,
  295. encoder_hidden_states: torch.FloatTensor | None = None,
  296. encoder_attention_mask: torch.FloatTensor | None = None,
  297. past_key_values: Cache | None = None,
  298. use_cache: bool | None = None,
  299. output_attentions: bool = False,
  300. output_hidden_states: bool = False,
  301. return_dict: bool = True,
  302. **kwargs,
  303. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  304. if self.gradient_checkpointing and self.training:
  305. if use_cache:
  306. logger.warning_once(
  307. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  308. )
  309. use_cache = False
  310. if use_cache and past_key_values is None:
  311. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  312. hidden_states = self.embedding_hidden_mapping_in(hidden_states)
  313. all_hidden_states = () if output_hidden_states else None
  314. all_self_attentions = () if output_attentions else None
  315. all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
  316. for i, layer_module in enumerate(self.layer):
  317. if output_hidden_states:
  318. all_hidden_states = all_hidden_states + (hidden_states,)
  319. layer_outputs = layer_module(
  320. hidden_states,
  321. attention_mask,
  322. encoder_hidden_states,
  323. encoder_attention_mask,
  324. past_key_values,
  325. output_attentions,
  326. )
  327. hidden_states = layer_outputs[0]
  328. if output_attentions:
  329. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  330. if self.config.add_cross_attention:
  331. all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
  332. if output_hidden_states:
  333. all_hidden_states = all_hidden_states + (hidden_states,)
  334. if not return_dict:
  335. return tuple(
  336. v
  337. for v in [
  338. hidden_states,
  339. past_key_values,
  340. all_hidden_states,
  341. all_self_attentions,
  342. all_cross_attentions,
  343. ]
  344. if v is not None
  345. )
  346. return BaseModelOutputWithPastAndCrossAttentions(
  347. last_hidden_state=hidden_states,
  348. past_key_values=past_key_values,
  349. hidden_states=all_hidden_states,
  350. attentions=all_self_attentions,
  351. cross_attentions=all_cross_attentions,
  352. )
# RemBertPredictionHeadTransform: a hidden_size -> hidden_size Linear layer, then the activation from
# `config.hidden_act` (an ACT2FN key or a callable), then LayerNorm.
# NOTE(review): in the visible part of this file nothing instantiates this class;
# RemBertLMPredictionHead has its own dense/activation/LayerNorm. Confirm whether it is still needed.
# Keep the class text identical to the BERT original named in the marker below.
# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->RemBert
class RemBertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states
  368. class RemBertLMPredictionHead(nn.Module):
  369. def __init__(self, config):
  370. super().__init__()
  371. self.dense = nn.Linear(config.hidden_size, config.output_embedding_size)
  372. self.decoder = nn.Linear(config.output_embedding_size, config.vocab_size)
  373. self.activation = ACT2FN[config.hidden_act]
  374. self.LayerNorm = nn.LayerNorm(config.output_embedding_size, eps=config.layer_norm_eps)
  375. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  376. hidden_states = self.dense(hidden_states)
  377. hidden_states = self.activation(hidden_states)
  378. hidden_states = self.LayerNorm(hidden_states)
  379. hidden_states = self.decoder(hidden_states)
  380. return hidden_states
# RemBertOnlyMLMHead: thin wrapper that holds RemBertLMPredictionHead as `.predictions` and returns its
# vocabulary logits for every position. The `predictions` attribute name is what the MLM/CLM models'
# get_output_embeddings()/set_output_embeddings() reach through (`self.cls.predictions.decoder`).
# NOTE(review): keep the class text identical to the BERT original named in the marker below.
# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->RemBert
class RemBertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = RemBertLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores
  389. @auto_docstring
  390. class RemBertPreTrainedModel(PreTrainedModel):
  391. config: RemBertConfig
  392. base_model_prefix = "rembert"
  393. supports_gradient_checkpointing = True
  394. def _init_weights(self, module):
  395. super()._init_weights(module)
  396. if isinstance(module, RemBertEmbeddings):
  397. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  398. @auto_docstring(
  399. custom_intro="""
  400. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  401. cross-attention is added between the self-attention layers, following the architecture described in [Attention is
  402. all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  403. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  404. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
  405. to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
  406. `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
  407. """
  408. )
  409. class RemBertModel(RemBertPreTrainedModel):
  410. def __init__(self, config, add_pooling_layer=True):
  411. r"""
  412. add_pooling_layer (bool, *optional*, defaults to `True`):
  413. Whether to add a pooling layer
  414. """
  415. super().__init__(config)
  416. self.config = config
  417. self.embeddings = RemBertEmbeddings(config)
  418. self.encoder = RemBertEncoder(config)
  419. self.pooler = RemBertPooler(config) if add_pooling_layer else None
  420. # Initialize weights and apply final processing
  421. self.post_init()
  422. def get_input_embeddings(self):
  423. return self.embeddings.word_embeddings
  424. def set_input_embeddings(self, value):
  425. self.embeddings.word_embeddings = value
  426. @auto_docstring
  427. def forward(
  428. self,
  429. input_ids: torch.LongTensor | None = None,
  430. attention_mask: torch.LongTensor | None = None,
  431. token_type_ids: torch.LongTensor | None = None,
  432. position_ids: torch.LongTensor | None = None,
  433. inputs_embeds: torch.FloatTensor | None = None,
  434. encoder_hidden_states: torch.FloatTensor | None = None,
  435. encoder_attention_mask: torch.FloatTensor | None = None,
  436. past_key_values: Cache | None = None,
  437. use_cache: bool | None = None,
  438. output_attentions: bool | None = None,
  439. output_hidden_states: bool | None = None,
  440. return_dict: bool | None = None,
  441. **kwargs,
  442. ) -> tuple | BaseModelOutputWithPoolingAndCrossAttentions:
  443. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  444. output_hidden_states = (
  445. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  446. )
  447. return_dict = return_dict if return_dict is not None else self.config.return_dict
  448. if self.config.is_decoder:
  449. use_cache = use_cache if use_cache is not None else self.config.use_cache
  450. else:
  451. use_cache = False
  452. if input_ids is not None and inputs_embeds is not None:
  453. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  454. elif input_ids is not None:
  455. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  456. input_shape = input_ids.size()
  457. elif inputs_embeds is not None:
  458. input_shape = inputs_embeds.size()[:-1]
  459. else:
  460. raise ValueError("You have to specify either input_ids or inputs_embeds")
  461. batch_size, seq_length = input_shape
  462. device = input_ids.device if input_ids is not None else inputs_embeds.device
  463. past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
  464. if attention_mask is None:
  465. attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
  466. if token_type_ids is None:
  467. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  468. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  469. # ourselves in which case we just need to make it broadcastable to all heads.
  470. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  471. # If a 2D or 3D attention mask is provided for the cross-attention
  472. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  473. if self.config.is_decoder and encoder_hidden_states is not None:
  474. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  475. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  476. if encoder_attention_mask is None:
  477. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  478. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  479. else:
  480. encoder_extended_attention_mask = None
  481. embedding_output = self.embeddings(
  482. input_ids=input_ids,
  483. position_ids=position_ids,
  484. token_type_ids=token_type_ids,
  485. inputs_embeds=inputs_embeds,
  486. past_key_values_length=past_key_values_length,
  487. )
  488. encoder_outputs = self.encoder(
  489. embedding_output,
  490. attention_mask=extended_attention_mask,
  491. encoder_hidden_states=encoder_hidden_states,
  492. encoder_attention_mask=encoder_extended_attention_mask,
  493. past_key_values=past_key_values,
  494. use_cache=use_cache,
  495. output_attentions=output_attentions,
  496. output_hidden_states=output_hidden_states,
  497. return_dict=return_dict,
  498. )
  499. sequence_output = encoder_outputs[0]
  500. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  501. if not return_dict:
  502. return (sequence_output, pooled_output) + encoder_outputs[1:]
  503. return BaseModelOutputWithPoolingAndCrossAttentions(
  504. last_hidden_state=sequence_output,
  505. pooler_output=pooled_output,
  506. past_key_values=encoder_outputs.past_key_values,
  507. hidden_states=encoder_outputs.hidden_states,
  508. attentions=encoder_outputs.attentions,
  509. cross_attentions=encoder_outputs.cross_attentions,
  510. )
  511. @auto_docstring
  512. class RemBertForMaskedLM(RemBertPreTrainedModel):
  513. def __init__(self, config):
  514. super().__init__(config)
  515. if config.is_decoder:
  516. logger.warning(
  517. "If you want to use `RemBertForMaskedLM` make sure `config.is_decoder=False` for "
  518. "bi-directional self-attention."
  519. )
  520. self.rembert = RemBertModel(config, add_pooling_layer=False)
  521. self.cls = RemBertOnlyMLMHead(config)
  522. # Initialize weights and apply final processing
  523. self.post_init()
  524. def get_output_embeddings(self):
  525. return self.cls.predictions.decoder
  526. def set_output_embeddings(self, new_embeddings):
  527. self.cls.predictions.decoder = new_embeddings
  528. @auto_docstring
  529. def forward(
  530. self,
  531. input_ids: torch.LongTensor | None = None,
  532. attention_mask: torch.LongTensor | None = None,
  533. token_type_ids: torch.LongTensor | None = None,
  534. position_ids: torch.LongTensor | None = None,
  535. inputs_embeds: torch.FloatTensor | None = None,
  536. encoder_hidden_states: torch.FloatTensor | None = None,
  537. encoder_attention_mask: torch.FloatTensor | None = None,
  538. labels: torch.LongTensor | None = None,
  539. output_attentions: bool | None = None,
  540. output_hidden_states: bool | None = None,
  541. return_dict: bool | None = None,
  542. **kwargs,
  543. ) -> tuple | MaskedLMOutput:
  544. r"""
  545. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  546. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  547. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  548. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  549. """
  550. return_dict = return_dict if return_dict is not None else self.config.return_dict
  551. outputs = self.rembert(
  552. input_ids,
  553. attention_mask=attention_mask,
  554. token_type_ids=token_type_ids,
  555. position_ids=position_ids,
  556. inputs_embeds=inputs_embeds,
  557. encoder_hidden_states=encoder_hidden_states,
  558. encoder_attention_mask=encoder_attention_mask,
  559. output_attentions=output_attentions,
  560. output_hidden_states=output_hidden_states,
  561. return_dict=return_dict,
  562. )
  563. sequence_output = outputs[0]
  564. prediction_scores = self.cls(sequence_output)
  565. masked_lm_loss = None
  566. if labels is not None:
  567. loss_fct = CrossEntropyLoss() # -100 index = padding token
  568. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  569. if not return_dict:
  570. output = (prediction_scores,) + outputs[2:]
  571. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  572. return MaskedLMOutput(
  573. loss=masked_lm_loss,
  574. logits=prediction_scores,
  575. hidden_states=outputs.hidden_states,
  576. attentions=outputs.attentions,
  577. )
  578. @auto_docstring(
  579. custom_intro="""
  580. RemBERT Model with a `language modeling` head on top for CLM fine-tuning.
  581. """
  582. )
  583. class RemBertForCausalLM(RemBertPreTrainedModel, GenerationMixin):
  584. def __init__(self, config):
  585. super().__init__(config)
  586. if not config.is_decoder:
  587. logger.warning("If you want to use `RemBertForCausalLM` as a standalone, add `is_decoder=True.`")
  588. self.rembert = RemBertModel(config, add_pooling_layer=False)
  589. self.cls = RemBertOnlyMLMHead(config)
  590. # Initialize weights and apply final processing
  591. self.post_init()
  592. def get_output_embeddings(self):
  593. return self.cls.predictions.decoder
  594. def set_output_embeddings(self, new_embeddings):
  595. self.cls.predictions.decoder = new_embeddings
  596. @auto_docstring
  597. def forward(
  598. self,
  599. input_ids: torch.LongTensor | None = None,
  600. attention_mask: torch.LongTensor | None = None,
  601. token_type_ids: torch.LongTensor | None = None,
  602. position_ids: torch.LongTensor | None = None,
  603. inputs_embeds: torch.FloatTensor | None = None,
  604. encoder_hidden_states: torch.FloatTensor | None = None,
  605. encoder_attention_mask: torch.FloatTensor | None = None,
  606. past_key_values: Cache | None = None,
  607. labels: torch.LongTensor | None = None,
  608. use_cache: bool | None = None,
  609. output_attentions: bool | None = None,
  610. output_hidden_states: bool | None = None,
  611. return_dict: bool | None = None,
  612. logits_to_keep: int | torch.Tensor = 0,
  613. **kwargs,
  614. ) -> tuple | CausalLMOutputWithCrossAttentions:
  615. r"""
  616. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  617. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  618. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  619. ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.
  620. Example:
  621. ```python
  622. >>> from transformers import AutoTokenizer, RemBertForCausalLM, RemBertConfig
  623. >>> import torch
  624. >>> tokenizer = AutoTokenizer.from_pretrained("google/rembert")
  625. >>> config = RemBertConfig.from_pretrained("google/rembert")
  626. >>> config.is_decoder = True
  627. >>> model = RemBertForCausalLM.from_pretrained("google/rembert", config=config)
  628. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  629. >>> outputs = model(**inputs)
  630. >>> prediction_logits = outputs.logits
  631. ```"""
  632. return_dict = return_dict if return_dict is not None else self.config.return_dict
  633. outputs = self.rembert(
  634. input_ids,
  635. attention_mask=attention_mask,
  636. token_type_ids=token_type_ids,
  637. position_ids=position_ids,
  638. inputs_embeds=inputs_embeds,
  639. encoder_hidden_states=encoder_hidden_states,
  640. encoder_attention_mask=encoder_attention_mask,
  641. past_key_values=past_key_values,
  642. use_cache=use_cache,
  643. output_attentions=output_attentions,
  644. output_hidden_states=output_hidden_states,
  645. return_dict=return_dict,
  646. )
  647. hidden_states = outputs[0]
  648. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  649. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  650. logits = self.cls(hidden_states[:, slice_indices, :])
  651. loss = None
  652. if labels is not None:
  653. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  654. if not return_dict:
  655. output = (logits,) + outputs[2:]
  656. return ((loss,) + output) if loss is not None else output
  657. return CausalLMOutputWithCrossAttentions(
  658. loss=loss,
  659. logits=logits,
  660. past_key_values=outputs.past_key_values,
  661. hidden_states=outputs.hidden_states,
  662. attentions=outputs.attentions,
  663. cross_attentions=outputs.cross_attentions,
  664. )
  665. @auto_docstring(
  666. custom_intro="""
  667. RemBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  668. pooled output) e.g. for GLUE tasks.
  669. """
  670. )
  671. class RemBertForSequenceClassification(RemBertPreTrainedModel):
  672. def __init__(self, config):
  673. super().__init__(config)
  674. self.num_labels = config.num_labels
  675. self.rembert = RemBertModel(config)
  676. self.dropout = nn.Dropout(config.classifier_dropout_prob)
  677. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  678. # Initialize weights and apply final processing
  679. self.post_init()
  680. @auto_docstring
  681. def forward(
  682. self,
  683. input_ids: torch.FloatTensor | None = None,
  684. attention_mask: torch.FloatTensor | None = None,
  685. token_type_ids: torch.LongTensor | None = None,
  686. position_ids: torch.FloatTensor | None = None,
  687. inputs_embeds: torch.FloatTensor | None = None,
  688. labels: torch.LongTensor | None = None,
  689. output_attentions: bool | None = None,
  690. output_hidden_states: bool | None = None,
  691. return_dict: bool | None = None,
  692. **kwargs,
  693. ) -> tuple | SequenceClassifierOutput:
  694. r"""
  695. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  696. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  697. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  698. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  699. """
  700. return_dict = return_dict if return_dict is not None else self.config.return_dict
  701. outputs = self.rembert(
  702. input_ids,
  703. attention_mask=attention_mask,
  704. token_type_ids=token_type_ids,
  705. position_ids=position_ids,
  706. inputs_embeds=inputs_embeds,
  707. output_attentions=output_attentions,
  708. output_hidden_states=output_hidden_states,
  709. return_dict=return_dict,
  710. )
  711. pooled_output = outputs[1]
  712. pooled_output = self.dropout(pooled_output)
  713. logits = self.classifier(pooled_output)
  714. loss = None
  715. if labels is not None:
  716. if self.config.problem_type is None:
  717. if self.num_labels == 1:
  718. self.config.problem_type = "regression"
  719. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  720. self.config.problem_type = "single_label_classification"
  721. else:
  722. self.config.problem_type = "multi_label_classification"
  723. if self.config.problem_type == "regression":
  724. loss_fct = MSELoss()
  725. if self.num_labels == 1:
  726. loss = loss_fct(logits.squeeze(), labels.squeeze())
  727. else:
  728. loss = loss_fct(logits, labels)
  729. elif self.config.problem_type == "single_label_classification":
  730. loss_fct = CrossEntropyLoss()
  731. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  732. elif self.config.problem_type == "multi_label_classification":
  733. loss_fct = BCEWithLogitsLoss()
  734. loss = loss_fct(logits, labels)
  735. if not return_dict:
  736. output = (logits,) + outputs[2:]
  737. return ((loss,) + output) if loss is not None else output
  738. return SequenceClassifierOutput(
  739. loss=loss,
  740. logits=logits,
  741. hidden_states=outputs.hidden_states,
  742. attentions=outputs.attentions,
  743. )
  744. @auto_docstring
  745. class RemBertForMultipleChoice(RemBertPreTrainedModel):
  746. def __init__(self, config):
  747. super().__init__(config)
  748. self.rembert = RemBertModel(config)
  749. self.dropout = nn.Dropout(config.classifier_dropout_prob)
  750. self.classifier = nn.Linear(config.hidden_size, 1)
  751. # Initialize weights and apply final processing
  752. self.post_init()
  753. @auto_docstring
  754. def forward(
  755. self,
  756. input_ids: torch.FloatTensor | None = None,
  757. attention_mask: torch.FloatTensor | None = None,
  758. token_type_ids: torch.LongTensor | None = None,
  759. position_ids: torch.FloatTensor | None = None,
  760. inputs_embeds: torch.FloatTensor | None = None,
  761. labels: torch.LongTensor | None = None,
  762. output_attentions: bool | None = None,
  763. output_hidden_states: bool | None = None,
  764. return_dict: bool | None = None,
  765. **kwargs,
  766. ) -> tuple | MultipleChoiceModelOutput:
  767. r"""
  768. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  769. Indices of input sequence tokens in the vocabulary.
  770. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  771. [`PreTrainedTokenizer.__call__`] for details.
  772. [What are input IDs?](../glossary#input-ids)
  773. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  774. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  775. 1]`:
  776. - 0 corresponds to a *sentence A* token,
  777. - 1 corresponds to a *sentence B* token.
  778. [What are token type IDs?](../glossary#token-type-ids)
  779. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  780. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  781. config.max_position_embeddings - 1]`.
  782. [What are position IDs?](../glossary#position-ids)
  783. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  784. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  785. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  786. model's internal embedding lookup matrix.
  787. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  788. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  789. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  790. `input_ids` above)
  791. """
  792. return_dict = return_dict if return_dict is not None else self.config.return_dict
  793. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  794. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  795. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  796. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  797. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  798. inputs_embeds = (
  799. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  800. if inputs_embeds is not None
  801. else None
  802. )
  803. outputs = self.rembert(
  804. input_ids,
  805. attention_mask=attention_mask,
  806. token_type_ids=token_type_ids,
  807. position_ids=position_ids,
  808. inputs_embeds=inputs_embeds,
  809. output_attentions=output_attentions,
  810. output_hidden_states=output_hidden_states,
  811. return_dict=return_dict,
  812. )
  813. pooled_output = outputs[1]
  814. pooled_output = self.dropout(pooled_output)
  815. logits = self.classifier(pooled_output)
  816. reshaped_logits = logits.view(-1, num_choices)
  817. loss = None
  818. if labels is not None:
  819. loss_fct = CrossEntropyLoss()
  820. loss = loss_fct(reshaped_logits, labels)
  821. if not return_dict:
  822. output = (reshaped_logits,) + outputs[2:]
  823. return ((loss,) + output) if loss is not None else output
  824. return MultipleChoiceModelOutput(
  825. loss=loss,
  826. logits=reshaped_logits,
  827. hidden_states=outputs.hidden_states,
  828. attentions=outputs.attentions,
  829. )
  830. @auto_docstring
  831. class RemBertForTokenClassification(RemBertPreTrainedModel):
  832. def __init__(self, config):
  833. super().__init__(config)
  834. self.num_labels = config.num_labels
  835. self.rembert = RemBertModel(config, add_pooling_layer=False)
  836. self.dropout = nn.Dropout(config.classifier_dropout_prob)
  837. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  838. # Initialize weights and apply final processing
  839. self.post_init()
  840. @auto_docstring
  841. def forward(
  842. self,
  843. input_ids: torch.FloatTensor | None = None,
  844. attention_mask: torch.FloatTensor | None = None,
  845. token_type_ids: torch.LongTensor | None = None,
  846. position_ids: torch.FloatTensor | None = None,
  847. inputs_embeds: torch.FloatTensor | None = None,
  848. labels: torch.LongTensor | None = None,
  849. output_attentions: bool | None = None,
  850. output_hidden_states: bool | None = None,
  851. return_dict: bool | None = None,
  852. **kwargs,
  853. ) -> tuple | TokenClassifierOutput:
  854. r"""
  855. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  856. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  857. """
  858. return_dict = return_dict if return_dict is not None else self.config.return_dict
  859. outputs = self.rembert(
  860. input_ids,
  861. attention_mask=attention_mask,
  862. token_type_ids=token_type_ids,
  863. position_ids=position_ids,
  864. inputs_embeds=inputs_embeds,
  865. output_attentions=output_attentions,
  866. output_hidden_states=output_hidden_states,
  867. return_dict=return_dict,
  868. )
  869. sequence_output = outputs[0]
  870. sequence_output = self.dropout(sequence_output)
  871. logits = self.classifier(sequence_output)
  872. loss = None
  873. if labels is not None:
  874. loss_fct = CrossEntropyLoss()
  875. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  876. if not return_dict:
  877. output = (logits,) + outputs[2:]
  878. return ((loss,) + output) if loss is not None else output
  879. return TokenClassifierOutput(
  880. loss=loss,
  881. logits=logits,
  882. hidden_states=outputs.hidden_states,
  883. attentions=outputs.attentions,
  884. )
  885. @auto_docstring
  886. class RemBertForQuestionAnswering(RemBertPreTrainedModel):
  887. def __init__(self, config):
  888. super().__init__(config)
  889. self.num_labels = config.num_labels
  890. self.rembert = RemBertModel(config, add_pooling_layer=False)
  891. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  892. # Initialize weights and apply final processing
  893. self.post_init()
  894. @auto_docstring
  895. def forward(
  896. self,
  897. input_ids: torch.FloatTensor | None = None,
  898. attention_mask: torch.FloatTensor | None = None,
  899. token_type_ids: torch.LongTensor | None = None,
  900. position_ids: torch.FloatTensor | None = None,
  901. inputs_embeds: torch.FloatTensor | None = None,
  902. start_positions: torch.LongTensor | None = None,
  903. end_positions: torch.LongTensor | None = None,
  904. output_attentions: bool | None = None,
  905. output_hidden_states: bool | None = None,
  906. return_dict: bool | None = None,
  907. **kwargs,
  908. ) -> tuple | QuestionAnsweringModelOutput:
  909. return_dict = return_dict if return_dict is not None else self.config.return_dict
  910. outputs = self.rembert(
  911. input_ids,
  912. attention_mask=attention_mask,
  913. token_type_ids=token_type_ids,
  914. position_ids=position_ids,
  915. inputs_embeds=inputs_embeds,
  916. output_attentions=output_attentions,
  917. output_hidden_states=output_hidden_states,
  918. return_dict=return_dict,
  919. )
  920. sequence_output = outputs[0]
  921. logits = self.qa_outputs(sequence_output)
  922. start_logits, end_logits = logits.split(1, dim=-1)
  923. start_logits = start_logits.squeeze(-1)
  924. end_logits = end_logits.squeeze(-1)
  925. total_loss = None
  926. if start_positions is not None and end_positions is not None:
  927. # If we are on multi-GPU, split add a dimension
  928. if len(start_positions.size()) > 1:
  929. start_positions = start_positions.squeeze(-1)
  930. if len(end_positions.size()) > 1:
  931. end_positions = end_positions.squeeze(-1)
  932. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  933. ignored_index = start_logits.size(1)
  934. start_positions.clamp_(0, ignored_index)
  935. end_positions.clamp_(0, ignored_index)
  936. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  937. start_loss = loss_fct(start_logits, start_positions)
  938. end_loss = loss_fct(end_logits, end_positions)
  939. total_loss = (start_loss + end_loss) / 2
  940. if not return_dict:
  941. output = (start_logits, end_logits) + outputs[2:]
  942. return ((total_loss,) + output) if total_loss is not None else output
  943. return QuestionAnsweringModelOutput(
  944. loss=total_loss,
  945. start_logits=start_logits,
  946. end_logits=end_logits,
  947. hidden_states=outputs.hidden_states,
  948. attentions=outputs.attentions,
  949. )
  950. __all__ = [
  951. "RemBertForCausalLM",
  952. "RemBertForMaskedLM",
  953. "RemBertForMultipleChoice",
  954. "RemBertForQuestionAnswering",
  955. "RemBertForSequenceClassification",
  956. "RemBertForTokenClassification",
  957. "RemBertLayer",
  958. "RemBertModel",
  959. "RemBertPreTrainedModel",
  960. ]