modeling_deberta.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190
  1. # Copyright 2020 Microsoft and the Hugging Face Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch DeBERTa model."""
  15. import torch
  16. from torch import nn
  17. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  18. from ... import initialization as init
  19. from ...activations import ACT2FN
  20. from ...modeling_layers import GradientCheckpointingLayer
  21. from ...modeling_outputs import (
  22. BaseModelOutput,
  23. MaskedLMOutput,
  24. QuestionAnsweringModelOutput,
  25. SequenceClassifierOutput,
  26. TokenClassifierOutput,
  27. )
  28. from ...modeling_utils import PreTrainedModel
  29. from ...utils import auto_docstring, logging
  30. from .configuration_deberta import DebertaConfig
# Module-level logger obtained from the library's `logging` utilities, named after this module.
logger = logging.get_logger(__name__)
  32. class DebertaLayerNorm(nn.Module):
  33. """LayerNorm module (epsilon inside the square root)."""
  34. def __init__(self, size, eps=1e-12):
  35. super().__init__()
  36. self.weight = nn.Parameter(torch.ones(size))
  37. self.bias = nn.Parameter(torch.zeros(size))
  38. self.variance_epsilon = eps
  39. def forward(self, hidden_states):
  40. input_type = hidden_states.dtype
  41. hidden_states = hidden_states.float()
  42. mean = hidden_states.mean(-1, keepdim=True)
  43. variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
  44. hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon)
  45. hidden_states = hidden_states.to(input_type)
  46. y = self.weight * hidden_states + self.bias
  47. return y
  48. class DebertaSelfOutput(nn.Module):
  49. def __init__(self, config):
  50. super().__init__()
  51. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  52. self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
  53. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  54. def forward(self, hidden_states, input_tensor):
  55. hidden_states = self.dense(hidden_states)
  56. hidden_states = self.dropout(hidden_states)
  57. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  58. return hidden_states
  59. @torch.jit.script
  60. def build_relative_position(query_layer, key_layer):
  61. """
  62. Build relative position according to the query and key
  63. We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
  64. \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
  65. P_k\\)
  66. Args:
  67. query_size (int): the length of query
  68. key_size (int): the length of key
  69. Return:
  70. `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
  71. """
  72. query_size = query_layer.size(-2)
  73. key_size = key_layer.size(-2)
  74. q_ids = torch.arange(query_size, dtype=torch.long, device=query_layer.device)
  75. k_ids = torch.arange(key_size, dtype=torch.long, device=key_layer.device)
  76. rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
  77. rel_pos_ids = rel_pos_ids[:query_size, :]
  78. rel_pos_ids = rel_pos_ids.unsqueeze(0)
  79. return rel_pos_ids
  80. @torch.jit.script
  81. def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
  82. return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
  83. @torch.jit.script
  84. def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
  85. return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
  86. @torch.jit.script
  87. def pos_dynamic_expand(pos_index, p2c_att, key_layer):
  88. return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
###### To support a general trace, these operations have to be defined as scripted functions because they use Python objects (sizes) ######
# that are not supported by torch.jit.trace.
# Full credits to @Szustarol
  92. @torch.jit.script
  93. def scaled_size_sqrt(query_layer: torch.Tensor, scale_factor: int):
  94. return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
  95. @torch.jit.script
  96. def build_rpos(query_layer: torch.Tensor, key_layer: torch.Tensor, relative_pos):
  97. if query_layer.size(-2) != key_layer.size(-2):
  98. return build_relative_position(query_layer, key_layer)
  99. else:
  100. return relative_pos
  101. @torch.jit.script
  102. def compute_attention_span(query_layer: torch.Tensor, key_layer: torch.Tensor, max_relative_positions: int):
  103. return torch.tensor(min(max(query_layer.size(-2), key_layer.size(-2)), max_relative_positions))
  104. @torch.jit.script
  105. def uneven_size_corrected(p2c_att, query_layer: torch.Tensor, key_layer: torch.Tensor, relative_pos):
  106. if query_layer.size(-2) != key_layer.size(-2):
  107. pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
  108. return torch.gather(p2c_att, dim=2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer))
  109. else:
  110. return p2c_att
  111. ########################################################################################################################
  112. class DisentangledSelfAttention(nn.Module):
  113. """
  114. Disentangled self-attention module
  115. Parameters:
  116. config (`str`):
  117. A model config class instance with the configuration to build a new model. The schema is similar to
  118. *BertConfig*, for more details, please refer [`DebertaConfig`]
  119. """
  120. def __init__(self, config):
  121. super().__init__()
  122. if config.hidden_size % config.num_attention_heads != 0:
  123. raise ValueError(
  124. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  125. f"heads ({config.num_attention_heads})"
  126. )
  127. self.num_attention_heads = config.num_attention_heads
  128. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  129. self.all_head_size = self.num_attention_heads * self.attention_head_size
  130. self.in_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False)
  131. self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
  132. self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
  133. self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
  134. self.relative_attention = getattr(config, "relative_attention", False)
  135. self.talking_head = getattr(config, "talking_head", False)
  136. if self.talking_head:
  137. self.head_logits_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
  138. self.head_weights_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
  139. else:
  140. self.head_logits_proj = None
  141. self.head_weights_proj = None
  142. if self.relative_attention:
  143. self.max_relative_positions = getattr(config, "max_relative_positions", -1)
  144. if self.max_relative_positions < 1:
  145. self.max_relative_positions = config.max_position_embeddings
  146. self.pos_dropout = nn.Dropout(config.hidden_dropout_prob)
  147. if "c2p" in self.pos_att_type:
  148. self.pos_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
  149. if "p2c" in self.pos_att_type:
  150. self.pos_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
  151. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  152. def transpose_for_scores(self, x):
  153. new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
  154. x = x.view(new_x_shape)
  155. return x.permute(0, 2, 1, 3)
  156. def forward(
  157. self,
  158. hidden_states: torch.Tensor,
  159. attention_mask: torch.Tensor,
  160. output_attentions: bool = False,
  161. query_states: torch.Tensor | None = None,
  162. relative_pos: torch.Tensor | None = None,
  163. rel_embeddings: torch.Tensor | None = None,
  164. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  165. """
  166. Call the module
  167. Args:
  168. hidden_states (`torch.FloatTensor`):
  169. Input states to the module usually the output from previous layer, it will be the Q,K and V in
  170. *Attention(Q,K,V)*
  171. attention_mask (`torch.BoolTensor`):
  172. An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
  173. sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
  174. th token.
  175. output_attentions (`bool`, *optional*):
  176. Whether return the attention matrix.
  177. query_states (`torch.FloatTensor`, *optional*):
  178. The *Q* state in *Attention(Q,K,V)*.
  179. relative_pos (`torch.LongTensor`):
  180. The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
  181. values ranging in [*-max_relative_positions*, *max_relative_positions*].
  182. rel_embeddings (`torch.FloatTensor`):
  183. The embedding of relative distances. It's a tensor of shape [\\(2 \\times
  184. \\text{max_relative_positions}\\), *hidden_size*].
  185. """
  186. if query_states is None:
  187. qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1)
  188. query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)
  189. else:
  190. ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0)
  191. qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
  192. q = torch.matmul(qkvw[0], query_states.t().to(dtype=qkvw[0].dtype))
  193. k = torch.matmul(qkvw[1], hidden_states.t().to(dtype=qkvw[1].dtype))
  194. v = torch.matmul(qkvw[2], hidden_states.t().to(dtype=qkvw[2].dtype))
  195. query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
  196. query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
  197. value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])
  198. rel_att: int = 0
  199. # Take the dot product between "query" and "key" to get the raw attention scores.
  200. scale_factor = 1 + len(self.pos_att_type)
  201. scale = scaled_size_sqrt(query_layer, scale_factor)
  202. query_layer = query_layer / scale.to(dtype=query_layer.dtype)
  203. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  204. if self.relative_attention and rel_embeddings is not None and relative_pos is not None:
  205. rel_embeddings = self.pos_dropout(rel_embeddings)
  206. rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
  207. if rel_att is not None:
  208. attention_scores = attention_scores + rel_att
  209. # bxhxlxd
  210. if self.head_logits_proj is not None:
  211. attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
  212. attention_mask = attention_mask.bool()
  213. attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
  214. # bsz x height x length x dimension
  215. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  216. attention_probs = self.dropout(attention_probs)
  217. if self.head_weights_proj is not None:
  218. attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
  219. context_layer = torch.matmul(attention_probs, value_layer)
  220. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  221. new_context_layer_shape = context_layer.size()[:-2] + (-1,)
  222. context_layer = context_layer.view(new_context_layer_shape)
  223. if not output_attentions:
  224. return (context_layer, None)
  225. return (context_layer, attention_probs)
  226. def disentangled_att_bias(
  227. self,
  228. query_layer: torch.Tensor,
  229. key_layer: torch.Tensor,
  230. relative_pos: torch.Tensor,
  231. rel_embeddings: torch.Tensor,
  232. scale_factor: int,
  233. ):
  234. if relative_pos is None:
  235. relative_pos = build_relative_position(query_layer, key_layer, query_layer.device)
  236. if relative_pos.dim() == 2:
  237. relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
  238. elif relative_pos.dim() == 3:
  239. relative_pos = relative_pos.unsqueeze(1)
  240. # bxhxqxk
  241. elif relative_pos.dim() != 4:
  242. raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
  243. att_span = compute_attention_span(query_layer, key_layer, self.max_relative_positions)
  244. relative_pos = relative_pos.long()
  245. rel_embeddings = rel_embeddings[
  246. self.max_relative_positions - att_span : self.max_relative_positions + att_span, :
  247. ].unsqueeze(0)
  248. score = 0
  249. # content->position
  250. if "c2p" in self.pos_att_type:
  251. pos_key_layer = self.pos_proj(rel_embeddings)
  252. pos_key_layer = self.transpose_for_scores(pos_key_layer)
  253. c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
  254. c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
  255. c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos))
  256. score += c2p_att
  257. # position->content
  258. if "p2c" in self.pos_att_type:
  259. pos_query_layer = self.pos_q_proj(rel_embeddings)
  260. pos_query_layer = self.transpose_for_scores(pos_query_layer)
  261. pos_query_layer /= scaled_size_sqrt(pos_query_layer, scale_factor)
  262. r_pos = build_rpos(
  263. query_layer,
  264. key_layer,
  265. relative_pos,
  266. )
  267. p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
  268. p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2).to(dtype=key_layer.dtype))
  269. p2c_att = torch.gather(
  270. p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
  271. ).transpose(-1, -2)
  272. p2c_att = uneven_size_corrected(p2c_att, query_layer, key_layer, relative_pos)
  273. score += p2c_att
  274. return score
  275. class DebertaEmbeddings(nn.Module):
  276. """Construct the embeddings from word, position and token_type embeddings."""
  277. def __init__(self, config):
  278. super().__init__()
  279. pad_token_id = getattr(config, "pad_token_id", 0)
  280. self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
  281. self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
  282. self.position_biased_input = getattr(config, "position_biased_input", True)
  283. if not self.position_biased_input:
  284. self.position_embeddings = None
  285. else:
  286. self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
  287. if config.type_vocab_size > 0:
  288. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
  289. else:
  290. self.token_type_embeddings = None
  291. if self.embedding_size != config.hidden_size:
  292. self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
  293. else:
  294. self.embed_proj = None
  295. self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
  296. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  297. self.config = config
  298. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  299. self.register_buffer(
  300. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  301. )
  302. def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
  303. if input_ids is not None:
  304. input_shape = input_ids.size()
  305. else:
  306. input_shape = inputs_embeds.size()[:-1]
  307. seq_length = input_shape[1]
  308. if position_ids is None:
  309. position_ids = self.position_ids[:, :seq_length]
  310. if token_type_ids is None:
  311. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  312. if inputs_embeds is None:
  313. inputs_embeds = self.word_embeddings(input_ids)
  314. if self.position_embeddings is not None:
  315. position_embeddings = self.position_embeddings(position_ids.long())
  316. else:
  317. position_embeddings = torch.zeros_like(inputs_embeds)
  318. embeddings = inputs_embeds
  319. if self.position_biased_input:
  320. embeddings = embeddings + position_embeddings
  321. if self.token_type_embeddings is not None:
  322. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  323. embeddings = embeddings + token_type_embeddings
  324. if self.embed_proj is not None:
  325. embeddings = self.embed_proj(embeddings)
  326. embeddings = self.LayerNorm(embeddings)
  327. if mask is not None:
  328. if mask.dim() != embeddings.dim():
  329. if mask.dim() == 4:
  330. mask = mask.squeeze(1).squeeze(1)
  331. mask = mask.unsqueeze(2)
  332. mask = mask.to(embeddings.dtype)
  333. embeddings = embeddings * mask
  334. embeddings = self.dropout(embeddings)
  335. return embeddings
  336. class DebertaAttention(nn.Module):
  337. def __init__(self, config):
  338. super().__init__()
  339. self.self = DisentangledSelfAttention(config)
  340. self.output = DebertaSelfOutput(config)
  341. self.config = config
  342. def forward(
  343. self,
  344. hidden_states,
  345. attention_mask,
  346. output_attentions: bool = False,
  347. query_states=None,
  348. relative_pos=None,
  349. rel_embeddings=None,
  350. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  351. self_output, att_matrix = self.self(
  352. hidden_states,
  353. attention_mask,
  354. output_attentions,
  355. query_states=query_states,
  356. relative_pos=relative_pos,
  357. rel_embeddings=rel_embeddings,
  358. )
  359. if query_states is None:
  360. query_states = hidden_states
  361. attention_output = self.output(self_output, query_states)
  362. if output_attentions:
  363. return (attention_output, att_matrix)
  364. else:
  365. return (attention_output, None)
  366. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Deberta
  367. class DebertaIntermediate(nn.Module):
  368. def __init__(self, config):
  369. super().__init__()
  370. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  371. if isinstance(config.hidden_act, str):
  372. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  373. else:
  374. self.intermediate_act_fn = config.hidden_act
  375. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  376. hidden_states = self.dense(hidden_states)
  377. hidden_states = self.intermediate_act_fn(hidden_states)
  378. return hidden_states
  379. class DebertaOutput(nn.Module):
  380. def __init__(self, config):
  381. super().__init__()
  382. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  383. self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
  384. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  385. self.config = config
  386. def forward(self, hidden_states, input_tensor):
  387. hidden_states = self.dense(hidden_states)
  388. hidden_states = self.dropout(hidden_states)
  389. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  390. return hidden_states
  391. class DebertaLayer(GradientCheckpointingLayer):
  392. def __init__(self, config):
  393. super().__init__()
  394. self.attention = DebertaAttention(config)
  395. self.intermediate = DebertaIntermediate(config)
  396. self.output = DebertaOutput(config)
  397. def forward(
  398. self,
  399. hidden_states,
  400. attention_mask,
  401. query_states=None,
  402. relative_pos=None,
  403. rel_embeddings=None,
  404. output_attentions: bool = False,
  405. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  406. attention_output, att_matrix = self.attention(
  407. hidden_states,
  408. attention_mask,
  409. output_attentions=output_attentions,
  410. query_states=query_states,
  411. relative_pos=relative_pos,
  412. rel_embeddings=rel_embeddings,
  413. )
  414. intermediate_output = self.intermediate(attention_output)
  415. layer_output = self.output(intermediate_output, attention_output)
  416. if output_attentions:
  417. return (layer_output, att_matrix)
  418. else:
  419. return (layer_output, None)
  420. class DebertaEncoder(nn.Module):
  421. """Modified BertEncoder with relative position bias support"""
  422. def __init__(self, config):
  423. super().__init__()
  424. self.layer = nn.ModuleList([DebertaLayer(config) for _ in range(config.num_hidden_layers)])
  425. self.relative_attention = getattr(config, "relative_attention", False)
  426. if self.relative_attention:
  427. self.max_relative_positions = getattr(config, "max_relative_positions", -1)
  428. if self.max_relative_positions < 1:
  429. self.max_relative_positions = config.max_position_embeddings
  430. self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size)
  431. self.gradient_checkpointing = False
  432. def get_rel_embedding(self):
  433. rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
  434. return rel_embeddings
  435. def get_attention_mask(self, attention_mask):
  436. if attention_mask.dim() <= 2:
  437. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  438. attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
  439. elif attention_mask.dim() == 3:
  440. attention_mask = attention_mask.unsqueeze(1)
  441. return attention_mask
  442. def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
  443. if self.relative_attention and relative_pos is None:
  444. if query_states is not None:
  445. relative_pos = build_relative_position(query_states, hidden_states)
  446. else:
  447. relative_pos = build_relative_position(hidden_states, hidden_states)
  448. return relative_pos
  449. def forward(
  450. self,
  451. hidden_states: torch.Tensor,
  452. attention_mask: torch.Tensor,
  453. output_hidden_states: bool = True,
  454. output_attentions: bool = False,
  455. query_states=None,
  456. relative_pos=None,
  457. return_dict: bool = True,
  458. ):
  459. attention_mask = self.get_attention_mask(attention_mask)
  460. relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
  461. all_hidden_states: tuple[torch.Tensor] | None = (hidden_states,) if output_hidden_states else None
  462. all_attentions = () if output_attentions else None
  463. next_kv = hidden_states
  464. rel_embeddings = self.get_rel_embedding()
  465. for i, layer_module in enumerate(self.layer):
  466. hidden_states, att_m = layer_module(
  467. next_kv,
  468. attention_mask,
  469. query_states=query_states,
  470. relative_pos=relative_pos,
  471. rel_embeddings=rel_embeddings,
  472. output_attentions=output_attentions,
  473. )
  474. if output_hidden_states:
  475. all_hidden_states = all_hidden_states + (hidden_states,)
  476. if query_states is not None:
  477. query_states = hidden_states
  478. else:
  479. next_kv = hidden_states
  480. if output_attentions:
  481. all_attentions = all_attentions + (att_m,)
  482. if not return_dict:
  483. return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
  484. return BaseModelOutput(
  485. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
  486. )
  487. @auto_docstring
  488. class DebertaPreTrainedModel(PreTrainedModel):
  489. config: DebertaConfig
  490. base_model_prefix = "deberta"
  491. _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
  492. supports_gradient_checkpointing = True
  493. @torch.no_grad()
  494. def _init_weights(self, module):
  495. """Initialize the weights."""
  496. super()._init_weights(module)
  497. if isinstance(module, DisentangledSelfAttention):
  498. init.zeros_(module.q_bias)
  499. init.zeros_(module.v_bias)
  500. elif isinstance(module, (LegacyDebertaLMPredictionHead, DebertaLMPredictionHead)):
  501. init.zeros_(module.bias)
  502. elif isinstance(module, DebertaEmbeddings):
  503. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  504. @auto_docstring
  505. class DebertaModel(DebertaPreTrainedModel):
  506. def __init__(self, config):
  507. super().__init__(config)
  508. self.embeddings = DebertaEmbeddings(config)
  509. self.encoder = DebertaEncoder(config)
  510. self.z_steps = 0
  511. self.config = config
  512. # Initialize weights and apply final processing
  513. self.post_init()
  514. def get_input_embeddings(self):
  515. return self.embeddings.word_embeddings
  516. def set_input_embeddings(self, new_embeddings):
  517. self.embeddings.word_embeddings = new_embeddings
  518. @auto_docstring
  519. def forward(
  520. self,
  521. input_ids: torch.Tensor | None = None,
  522. attention_mask: torch.Tensor | None = None,
  523. token_type_ids: torch.Tensor | None = None,
  524. position_ids: torch.Tensor | None = None,
  525. inputs_embeds: torch.Tensor | None = None,
  526. output_attentions: bool | None = None,
  527. output_hidden_states: bool | None = None,
  528. return_dict: bool | None = None,
  529. **kwargs,
  530. ) -> tuple | BaseModelOutput:
  531. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  532. output_hidden_states = (
  533. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  534. )
  535. return_dict = return_dict if return_dict is not None else self.config.return_dict
  536. if input_ids is not None and inputs_embeds is not None:
  537. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  538. elif input_ids is not None:
  539. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  540. input_shape = input_ids.size()
  541. elif inputs_embeds is not None:
  542. input_shape = inputs_embeds.size()[:-1]
  543. else:
  544. raise ValueError("You have to specify either input_ids or inputs_embeds")
  545. device = input_ids.device if input_ids is not None else inputs_embeds.device
  546. if attention_mask is None:
  547. attention_mask = torch.ones(input_shape, device=device)
  548. if token_type_ids is None:
  549. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  550. embedding_output = self.embeddings(
  551. input_ids=input_ids,
  552. token_type_ids=token_type_ids,
  553. position_ids=position_ids,
  554. mask=attention_mask,
  555. inputs_embeds=inputs_embeds,
  556. )
  557. encoder_outputs = self.encoder(
  558. embedding_output,
  559. attention_mask,
  560. output_hidden_states=True,
  561. output_attentions=output_attentions,
  562. return_dict=return_dict,
  563. )
  564. encoded_layers = encoder_outputs[1]
  565. if self.z_steps > 1:
  566. hidden_states = encoded_layers[-2]
  567. layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
  568. query_states = encoded_layers[-1]
  569. rel_embeddings = self.encoder.get_rel_embedding()
  570. attention_mask = self.encoder.get_attention_mask(attention_mask)
  571. rel_pos = self.encoder.get_rel_pos(embedding_output)
  572. for layer in layers[1:]:
  573. query_states = layer(
  574. hidden_states,
  575. attention_mask,
  576. output_attentions=False,
  577. query_states=query_states,
  578. relative_pos=rel_pos,
  579. rel_embeddings=rel_embeddings,
  580. )
  581. encoded_layers.append(query_states)
  582. sequence_output = encoded_layers[-1]
  583. if not return_dict:
  584. return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]
  585. return BaseModelOutput(
  586. last_hidden_state=sequence_output,
  587. hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
  588. attentions=encoder_outputs.attentions,
  589. )
  class LegacyDebertaPredictionHeadTransform(nn.Module):
      """Project hidden states to the embedding width, apply the activation, then LayerNorm.

      Args:
          config: model config; reads `hidden_size`, `hidden_act`, `layer_norm_eps` and the optional
              `embedding_size` (falls back to `hidden_size` when the config has no such attribute).
      """

      def __init__(self, config):
          super().__init__()
          hidden_size = config.hidden_size
          activation = config.hidden_act
          self.embedding_size = getattr(config, "embedding_size", hidden_size)
          self.dense = nn.Linear(hidden_size, self.embedding_size)
          # `hidden_act` is either a registry key looked up in ACT2FN or an already-callable activation.
          self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
          self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps)

      def forward(self, hidden_states):
          """Return `LayerNorm(act(dense(hidden_states)))`."""
          projected = self.dense(hidden_states)
          return self.LayerNorm(self.transform_act_fn(projected))
  class LegacyDebertaLMPredictionHead(nn.Module):
      """Legacy masked-LM head: a transform block followed by a linear projection onto the vocabulary.

      Args:
          config: model config; reads `hidden_size`, `vocab_size` and the optional `embedding_size`
              (falls back to `hidden_size`), and is forwarded to `LegacyDebertaPredictionHeadTransform`.
      """

      def __init__(self, config):
          super().__init__()
          self.transform = LegacyDebertaPredictionHeadTransform(config)
          self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
          vocab_size = config.vocab_size
          # The output weights are meant to be shared with the input embeddings, with a separate
          # output-only bias per token. `bias` is not read in `forward` here; the owning model's
          # tied-weights mapping links it to `decoder.bias`.
          self.decoder = nn.Linear(self.embedding_size, vocab_size, bias=True)
          self.bias = nn.Parameter(torch.zeros(vocab_size))

      def forward(self, hidden_states):
          """Return vocabulary logits computed as `decoder(transform(hidden_states))`."""
          transformed = self.transform(hidden_states)
          return self.decoder(transformed)
  # Thin wrapper that owns the legacy LM head under the attribute name `predictions`. When used as
  # `DebertaForMaskedLM.cls`, its parameters are therefore named `cls.predictions.*`.
  # NOTE: the class body is left verbatim because of the "Copied from" marker below (presumably checked by
  # the repository's copy-consistency tooling), so its documentation lives here instead of inside the class.
  # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->LegacyDeberta
  class LegacyDebertaOnlyMLMHead(nn.Module):
      def __init__(self, config):
          super().__init__()
          self.predictions = LegacyDebertaLMPredictionHead(config)

      def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
          prediction_scores = self.predictions(sequence_output)
          return prediction_scores
  class DebertaLMPredictionHead(nn.Module):
      """Non-legacy DeBERTa masked-LM head.

      Reference: https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py#L270

      Applies dense -> activation -> LayerNorm. It then scores every vocabulary entry with a dot product
      against the word-embedding matrix passed to `forward`, plus a learned per-token bias. The head owns
      no vocabulary projection of its own.
      """

      def __init__(self, config):
          super().__init__()
          width = config.hidden_size
          activation = config.hidden_act
          self.dense = nn.Linear(width, width)
          # `hidden_act` is either a registry key looked up in ACT2FN or an already-callable activation.
          self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
          self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps, elementwise_affine=True)
          self.bias = nn.Parameter(torch.zeros(config.vocab_size))

      def forward(self, hidden_states, word_embeddings):
          """Return vocabulary logits.

          Args:
              hidden_states: tensor whose last dimension is `config.hidden_size`.
              word_embeddings: module exposing `.weight`. Its transpose is matmul'ed with the normalized
                  states, so the weight is expected as (vocab_size, hidden_size). The caller must pass it;
                  the model passes its input word embeddings.
          """
          activated = self.transform_act_fn(self.dense(hidden_states))
          # The original implementation used MaskedLayerNorm without a mask, which is plain LayerNorm.
          normed = self.LayerNorm(activated)
          vocab_matrix = word_embeddings.weight.t()
          return torch.matmul(normed, vocab_matrix) + self.bias
  class DebertaOnlyMLMHead(nn.Module):
      """Wrapper that owns the non-legacy LM head under the attribute name `lm_head`."""

      def __init__(self, config):
          super().__init__()
          self.lm_head = DebertaLMPredictionHead(config)

      def forward(self, sequence_output, word_embeddings):
          """Score `sequence_output` against `word_embeddings`; the caller must supply the input embeddings."""
          return self.lm_head(sequence_output, word_embeddings)
  # Masked language model. `config.legacy` selects the head:
  #   * True  -> `LegacyDebertaOnlyMLMHead`, stored as `self.cls` (with its own vocabulary projection `decoder`);
  #   * False -> `DebertaOnlyMLMHead`, stored as `self.lm_predictions`, which scores against the input
  #              word-embedding module handed to it in `forward`.
  @auto_docstring
  class DebertaForMaskedLM(DebertaPreTrainedModel):
      # Tie map written in terms of the legacy head's parameter names (`cls.predictions.*`); presumably
      # "tied parameter -> source parameter" and consumed by the base class's weight-tying logic — confirm
      # against `PreTrainedModel`.
      _tied_weights_keys = {
          "cls.predictions.decoder.bias": "cls.predictions.bias",
          "cls.predictions.decoder.weight": "deberta.embeddings.word_embeddings.weight",
      }

      def __init__(self, config):
          super().__init__(config)
          self.legacy = config.legacy
          self.deberta = DebertaModel(config)
          if self.legacy:
              self.cls = LegacyDebertaOnlyMLMHead(config)
          else:
              # Instance attribute shadows the class-level map for the non-legacy head. It is set before
              # `post_init()` below.
              # NOTE(review): `DebertaLMPredictionHead` in this file has no `weight` attribute (it defines
              # `dense`, `LayerNorm`, `bias`), so "lm_predictions.lm_head.weight" names no existing parameter.
              # That head already uses the embeddings passed to `forward`; confirm whether this tie is intended.
              self._tied_weights_keys = {
                  "lm_predictions.lm_head.weight": "deberta.embeddings.word_embeddings.weight",
              }
              self.lm_predictions = DebertaOnlyMLMHead(config)
          # Initialize weights and apply final processing
          self.post_init()

      def get_output_embeddings(self):
          # Legacy: the vocabulary projection `nn.Linear` of the legacy head.
          # NOTE(review): non-legacy returns `lm_head.dense`, a hidden_size -> hidden_size layer rather than a
          # vocabulary projection (that head scores against the input embeddings). Confirm this is what the
          # resize/tie utilities that call this method expect.
          if self.legacy:
              return self.cls.predictions.decoder
          else:
              return self.lm_predictions.lm_head.dense

      def set_output_embeddings(self, new_embeddings):
          # Install `new_embeddings` (expected to carry a `.bias`) and re-point the head's separate `bias`
          # attribute at that same parameter.
          if self.legacy:
              self.cls.predictions.decoder = new_embeddings
              self.cls.predictions.bias = new_embeddings.bias
          else:
              self.lm_predictions.lm_head.dense = new_embeddings
              self.lm_predictions.lm_head.bias = new_embeddings.bias

      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          labels: torch.Tensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple | MaskedLMOutput:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
              config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
              loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
          """
          # Fall back to the config's output format when the caller does not choose one.
          return_dict = return_dict if return_dict is not None else self.config.return_dict
          outputs = self.deberta(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          sequence_output = outputs[0]
          # The legacy head only needs the sequence output. The non-legacy head also needs the input
          # word-embedding module to score against.
          if self.legacy:
              prediction_scores = self.cls(sequence_output)
          else:
              prediction_scores = self.lm_predictions(sequence_output, self.deberta.embeddings.word_embeddings)
          masked_lm_loss = None
          if labels is not None:
              loss_fct = CrossEntropyLoss()  # -100 index = padding token
              # Flatten to (N, vocab_size) logits vs (N,) targets.
              masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
          if not return_dict:
              # Tuple layout: (loss?, prediction_scores, *remaining encoder outputs).
              output = (prediction_scores,) + outputs[1:]
              return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
          return MaskedLMOutput(
              loss=masked_lm_loss,
              logits=prediction_scores,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  class ContextPooler(nn.Module):
      """Pool a sequence by transforming the hidden state of its first token.

      The first-token vector (width `config.hidden_size`) goes through dropout, then a linear projection
      to `config.pooler_hidden_size`, then the activation named by `config.pooler_hidden_act`.

      Args:
          config: reads `hidden_size`, `pooler_hidden_size`, `pooler_dropout` and (at call time)
              `pooler_hidden_act`.
      """

      def __init__(self, config):
          super().__init__()
          # The input width is the encoder's hidden size and the output width is the pooler size.
          # (Previously the input was sized with `pooler_hidden_size`, which crashed in `forward` whenever the
          # two differed.) When they are equal, the weight shape is unchanged.
          self.dense = nn.Linear(config.hidden_size, config.pooler_hidden_size)
          self.dropout = nn.Dropout(config.pooler_dropout)
          self.config = config

      def forward(self, hidden_states):
          """Return the pooled representation of `hidden_states[:, 0]`.

          Args:
              hidden_states: assumed (batch, seq, hidden_size); only position 0 along dim 1 is used.
          """
          # We "pool" the model by simply taking the hidden state corresponding
          # to the first token.
          context_token = hidden_states[:, 0]
          context_token = self.dropout(context_token)
          pooled_output = self.dense(context_token)
          # Looked up per call, so the activation always follows `config.pooler_hidden_act`.
          pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
          return pooled_output

      @property
      def output_dim(self):
          """Width of the tensor returned by `forward`, i.e. the projection's output size."""
          # Previously returned `config.hidden_size`, which is only correct when it equals `pooler_hidden_size`.
          return self.dense.out_features
  @auto_docstring(
      custom_intro="""
      DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
      pooled output) e.g. for GLUE tasks.
      """
  )
  class DebertaForSequenceClassification(DebertaPreTrainedModel):
      # Sequence classifier/regressor: first-token pooling (`ContextPooler`) -> dropout -> linear layer.

      def __init__(self, config):
          super().__init__(config)
          # Default to two labels when the config does not define `num_labels`.
          num_labels = getattr(config, "num_labels", 2)
          self.num_labels = num_labels
          self.deberta = DebertaModel(config)
          self.pooler = ContextPooler(config)
          output_dim = self.pooler.output_dim
          self.classifier = nn.Linear(output_dim, num_labels)
          # `cls_dropout` takes precedence over `hidden_dropout_prob` when the config sets it.
          drop_out = getattr(config, "cls_dropout", None)
          drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
          self.dropout = nn.Dropout(drop_out)
          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.deberta.get_input_embeddings()

      def set_input_embeddings(self, new_embeddings):
          self.deberta.set_input_embeddings(new_embeddings)

      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          labels: torch.Tensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple | SequenceClassifierOutput:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy). When `config.problem_type` is
              unset, hard labels below 0 are ignored.
          """
          return_dict = return_dict if return_dict is not None else self.config.return_dict
          outputs = self.deberta(
              input_ids,
              token_type_ids=token_type_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          encoder_layer = outputs[0]
          pooled_output = self.pooler(encoder_layer)
          pooled_output = self.dropout(pooled_output)
          logits = self.classifier(pooled_output)
          loss = None
          if labels is not None:
              if self.config.problem_type is None:
                  if self.num_labels == 1:
                      # regression task. Flatten/cast a separate tensor so the returned `logits` keep the same
                      # shape and dtype as when no labels are given.
                      loss_fn = nn.MSELoss()
                      regression_logits = logits.view(-1).to(labels.dtype)
                      loss = loss_fn(regression_logits, labels.view(-1))
                  elif labels.dim() == 1 or labels.size(-1) == 1:
                      # Hard class labels; entries < 0 are dropped before the loss. Flatten first so that
                      # (batch_size, 1) labels give a (k, 1) index that expands to (k, num_labels) and a 1-D
                      # gather on the labels.
                      labels = labels.reshape(-1)
                      label_index = (labels >= 0).nonzero()
                      labels = labels.long()
                      if label_index.size(0) > 0:
                          labeled_logits = torch.gather(
                              logits, 0, label_index.expand(label_index.size(0), logits.size(1))
                          )
                          labels = torch.gather(labels, 0, label_index.view(-1))
                          loss_fct = CrossEntropyLoss()
                          loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
                      else:
                          # Every label in the batch was ignored.
                          loss = torch.tensor(0).to(logits)
                  else:
                      # Label distributions over classes: cross-entropy against a soft target.
                      log_softmax = nn.LogSoftmax(-1)
                      loss = -((log_softmax(logits) * labels).sum(-1)).mean()
              elif self.config.problem_type == "regression":
                  loss_fct = MSELoss()
                  if self.num_labels == 1:
                      loss = loss_fct(logits.squeeze(), labels.squeeze())
                  else:
                      loss = loss_fct(logits, labels)
              elif self.config.problem_type == "single_label_classification":
                  loss_fct = CrossEntropyLoss()
                  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
              elif self.config.problem_type == "multi_label_classification":
                  loss_fct = BCEWithLogitsLoss()
                  loss = loss_fct(logits, labels)
          if not return_dict:
              # Tuple layout: (loss?, logits, *remaining encoder outputs).
              output = (logits,) + outputs[1:]
              return ((loss,) + output) if loss is not None else output
          return SequenceClassifierOutput(
              loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
          )
  @auto_docstring
  class DebertaForTokenClassification(DebertaPreTrainedModel):
      # Token-level classifier: dropout followed by a linear layer on every position of the encoder output.

      def __init__(self, config):
          super().__init__(config)
          label_count = config.num_labels
          self.num_labels = label_count
          self.deberta = DebertaModel(config)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          self.classifier = nn.Linear(config.hidden_size, label_count)
          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          labels: torch.Tensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple | TokenClassifierOutput:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
          """
          if return_dict is None:
              # Defer to the model config when the caller did not choose an output format.
              return_dict = self.config.return_dict

          outputs = self.deberta(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          token_states = self.dropout(outputs[0])
          logits = self.classifier(token_states)

          loss = None
          if labels is not None:
              # Flatten all positions into one axis. CrossEntropyLoss skips targets equal to its default
              # ignore_index (-100).
              criterion = CrossEntropyLoss()
              loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))

          if return_dict:
              return TokenClassifierOutput(
                  loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
              )
          # Tuple layout: (loss?, logits, *remaining encoder outputs).
          legacy_tuple = (logits,) + outputs[1:]
          return legacy_tuple if loss is None else (loss,) + legacy_tuple
  @auto_docstring
  class DebertaForQuestionAnswering(DebertaPreTrainedModel):
      # Extractive QA head: a linear layer maps each token's hidden state to `num_labels` scores, which are
      # split into span-start and span-end logits (the two-way unpacking in `forward` requires num_labels == 2).

      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels
          self.deberta = DebertaModel(config)
          self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          start_positions: torch.Tensor | None = None,
          end_positions: torch.Tensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple | QuestionAnsweringModelOutput:
          if return_dict is None:
              # Defer to the model config when the caller did not choose an output format.
              return_dict = self.config.return_dict

          outputs = self.deberta(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          span_scores = self.qa_outputs(outputs[0])
          start_scores, end_scores = span_scores.split(1, dim=-1)
          start_logits = start_scores.squeeze(-1).contiguous()
          end_logits = end_scores.squeeze(-1).contiguous()

          total_loss = None
          if start_positions is not None and end_positions is not None:
              # Drop a trailing size-1 dimension if present (the upstream note attributes it to multi-GPU runs).
              if start_positions.dim() > 1:
                  start_positions = start_positions.squeeze(-1)
              if end_positions.dim() > 1:
                  end_positions = end_positions.squeeze(-1)
              # Targets past the sequence are clamped to `seq_len`, which the loss is told to ignore.
              # Negative targets clamp to 0.
              seq_len = start_logits.size(1)
              start_positions = start_positions.clamp(0, seq_len)
              end_positions = end_positions.clamp(0, seq_len)
              criterion = CrossEntropyLoss(ignore_index=seq_len)
              start_loss = criterion(start_logits, start_positions)
              end_loss = criterion(end_logits, end_positions)
              total_loss = (start_loss + end_loss) / 2

          if return_dict:
              return QuestionAnsweringModelOutput(
                  loss=total_loss,
                  start_logits=start_logits,
                  end_logits=end_logits,
                  hidden_states=outputs.hidden_states,
                  attentions=outputs.attentions,
              )
          # Tuple layout: (loss?, start_logits, end_logits, *remaining encoder outputs).
          legacy_tuple = (start_logits, end_logits) + outputs[1:]
          return legacy_tuple if total_loss is None else (total_loss,) + legacy_tuple
  # Public names of this module; these are what `from <module> import *` exports.
  __all__ = [
      "DebertaForMaskedLM",
      "DebertaForQuestionAnswering",
      "DebertaForSequenceClassification",
      "DebertaForTokenClassification",
      "DebertaModel",
      "DebertaPreTrainedModel",
  ]