modeling_splinter.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742
  1. # Copyright 2021 Tel AViv University, AllenAI and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Splinter model."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from torch.nn import CrossEntropyLoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, ModelOutput, QuestionAnsweringModelOutput
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...pytorch_utils import apply_chunking_to_forward
  27. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
  28. from ...utils.generic import merge_with_config_defaults
  29. from ...utils.output_capturing import capture_outputs
  30. from .configuration_splinter import SplinterConfig
# Module-level logger. NOTE(review): `logger` is not referenced anywhere else in this module.
logger = logging.get_logger(__name__)
  32. class SplinterEmbeddings(nn.Module):
  33. """Construct the embeddings from word, position and token_type embeddings."""
  34. def __init__(self, config):
  35. super().__init__()
  36. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  37. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  38. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  39. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  40. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  41. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  42. self.register_buffer(
  43. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  44. )
  45. def forward(
  46. self,
  47. input_ids: torch.LongTensor | None = None,
  48. token_type_ids: torch.LongTensor | None = None,
  49. position_ids: torch.LongTensor | None = None,
  50. inputs_embeds: torch.FloatTensor | None = None,
  51. ) -> tuple:
  52. if input_ids is not None:
  53. input_shape = input_ids.size()
  54. else:
  55. input_shape = inputs_embeds.size()[:-1]
  56. seq_length = input_shape[1]
  57. if position_ids is None:
  58. position_ids = self.position_ids[:, :seq_length]
  59. if token_type_ids is None:
  60. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  61. if inputs_embeds is None:
  62. inputs_embeds = self.word_embeddings(input_ids)
  63. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  64. embeddings = inputs_embeds + token_type_embeddings
  65. position_embeddings = self.position_embeddings(position_ids)
  66. embeddings += position_embeddings
  67. embeddings = self.LayerNorm(embeddings)
  68. embeddings = self.dropout(embeddings)
  69. return embeddings
  70. # Copied from transformers.models.align.modeling_align.eager_attention_forward
  71. def eager_attention_forward(
  72. module: nn.Module,
  73. query: torch.Tensor,
  74. key: torch.Tensor,
  75. value: torch.Tensor,
  76. attention_mask: torch.Tensor | None,
  77. scaling: float,
  78. dropout: float = 0.0,
  79. **kwargs,
  80. ):
  81. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  82. if attention_mask is not None:
  83. attn_weights = attn_weights + attention_mask
  84. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  85. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  86. attn_output = torch.matmul(attn_weights, value)
  87. attn_output = attn_output.transpose(1, 2).contiguous()
  88. return attn_output, attn_weights
  89. # Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->Splinter
  90. class SplinterSelfAttention(nn.Module):
  91. def __init__(self, config):
  92. super().__init__()
  93. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  94. raise ValueError(
  95. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  96. f"heads ({config.num_attention_heads})"
  97. )
  98. self.config = config
  99. self.num_attention_heads = config.num_attention_heads
  100. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  101. self.all_head_size = self.num_attention_heads * self.attention_head_size
  102. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  103. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  104. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  105. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  106. self.attention_dropout = config.attention_probs_dropout_prob
  107. self.scaling = self.attention_head_size**-0.5
  108. def forward(
  109. self,
  110. hidden_states: torch.Tensor,
  111. attention_mask: torch.FloatTensor | None = None,
  112. **kwargs: Unpack[TransformersKwargs],
  113. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  114. input_shape = hidden_states.shape[:-1]
  115. hidden_shape = (*input_shape, -1, self.attention_head_size)
  116. query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  117. key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  118. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  119. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  120. self.config._attn_implementation, eager_attention_forward
  121. )
  122. attn_output, attn_weights = attention_interface(
  123. self,
  124. query_states,
  125. key_states,
  126. value_states,
  127. attention_mask,
  128. dropout=0.0 if not self.training else self.attention_dropout,
  129. scaling=self.scaling,
  130. **kwargs,
  131. )
  132. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  133. return attn_output, attn_weights
  134. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Splinter
  135. class SplinterSelfOutput(nn.Module):
  136. def __init__(self, config):
  137. super().__init__()
  138. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  139. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  140. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  141. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  142. hidden_states = self.dense(hidden_states)
  143. hidden_states = self.dropout(hidden_states)
  144. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  145. return hidden_states
  146. # Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->Splinter
  147. class SplinterAttention(nn.Module):
  148. def __init__(self, config):
  149. super().__init__()
  150. self.self = SplinterSelfAttention(config)
  151. self.output = SplinterSelfOutput(config)
  152. def forward(
  153. self,
  154. hidden_states: torch.Tensor,
  155. attention_mask: torch.FloatTensor | None = None,
  156. **kwargs: Unpack[TransformersKwargs],
  157. ) -> torch.Tensor:
  158. residual = hidden_states
  159. hidden_states, _ = self.self(
  160. hidden_states,
  161. attention_mask=attention_mask,
  162. **kwargs,
  163. )
  164. hidden_states = self.output(hidden_states, residual)
  165. return hidden_states
  166. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Splinter
  167. class SplinterIntermediate(nn.Module):
  168. def __init__(self, config):
  169. super().__init__()
  170. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  171. if isinstance(config.hidden_act, str):
  172. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  173. else:
  174. self.intermediate_act_fn = config.hidden_act
  175. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  176. hidden_states = self.dense(hidden_states)
  177. hidden_states = self.intermediate_act_fn(hidden_states)
  178. return hidden_states
  179. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Splinter
  180. class SplinterOutput(nn.Module):
  181. def __init__(self, config):
  182. super().__init__()
  183. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  184. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  185. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  186. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  187. hidden_states = self.dense(hidden_states)
  188. hidden_states = self.dropout(hidden_states)
  189. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  190. return hidden_states
  191. # Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->Splinter
  192. class SplinterLayer(GradientCheckpointingLayer):
  193. def __init__(self, config):
  194. super().__init__()
  195. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  196. self.seq_len_dim = 1
  197. self.attention = SplinterAttention(config)
  198. self.intermediate = SplinterIntermediate(config)
  199. self.output = SplinterOutput(config)
  200. def forward(
  201. self,
  202. hidden_states: torch.Tensor,
  203. attention_mask: torch.FloatTensor | None = None,
  204. **kwargs: Unpack[TransformersKwargs],
  205. ) -> torch.Tensor:
  206. hidden_states = self.attention(
  207. hidden_states,
  208. attention_mask=attention_mask,
  209. **kwargs,
  210. )
  211. hidden_states = apply_chunking_to_forward(
  212. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, hidden_states
  213. )
  214. return hidden_states
  215. def feed_forward_chunk(self, attention_output):
  216. intermediate_output = self.intermediate(attention_output)
  217. layer_output = self.output(intermediate_output, attention_output)
  218. return layer_output
  219. # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->Splinter
  220. class SplinterEncoder(nn.Module):
  221. def __init__(self, config):
  222. super().__init__()
  223. self.config = config
  224. self.layer = nn.ModuleList([SplinterLayer(config) for i in range(config.num_hidden_layers)])
  225. self.gradient_checkpointing = False
  226. def forward(
  227. self,
  228. hidden_states: torch.Tensor,
  229. attention_mask: torch.FloatTensor | None = None,
  230. **kwargs: Unpack[TransformersKwargs],
  231. ) -> BaseModelOutput:
  232. for layer_module in self.layer:
  233. hidden_states = layer_module(
  234. hidden_states,
  235. attention_mask,
  236. **kwargs,
  237. )
  238. return BaseModelOutput(
  239. last_hidden_state=hidden_states,
  240. )
  241. @auto_docstring
  242. class SplinterPreTrainedModel(PreTrainedModel):
  243. config: SplinterConfig
  244. base_model_prefix = "splinter"
  245. supports_gradient_checkpointing = True
  246. _can_record_outputs = {
  247. "hidden_states": SplinterLayer,
  248. "attentions": SplinterSelfAttention,
  249. }
  250. def _init_weights(self, module):
  251. super()._init_weights(module)
  252. if isinstance(module, SplinterEmbeddings):
  253. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  254. @auto_docstring
  255. class SplinterModel(SplinterPreTrainedModel):
  256. """
  257. The model is an encoder (with only self-attention) following the architecture described in [Attention is all you
  258. need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
  259. Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  260. """
  261. def __init__(self, config):
  262. super().__init__(config)
  263. self.config = config
  264. self.embeddings = SplinterEmbeddings(config)
  265. self.encoder = SplinterEncoder(config)
  266. # Initialize weights and apply final processing
  267. self.post_init()
  268. def get_input_embeddings(self):
  269. return self.embeddings.word_embeddings
  270. def set_input_embeddings(self, value):
  271. self.embeddings.word_embeddings = value
  272. @merge_with_config_defaults
  273. @capture_outputs
  274. @auto_docstring
  275. def forward(
  276. self,
  277. input_ids: torch.Tensor | None = None,
  278. attention_mask: torch.Tensor | None = None,
  279. token_type_ids: torch.Tensor | None = None,
  280. position_ids: torch.Tensor | None = None,
  281. inputs_embeds: torch.Tensor | None = None,
  282. **kwargs: Unpack[TransformersKwargs],
  283. ) -> tuple | BaseModelOutput:
  284. r"""
  285. token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  286. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  287. 1]`:
  288. - 0 corresponds to a *sentence A* token,
  289. - 1 corresponds to a *sentence B* token.
  290. [What are token type IDs?](../glossary#token-type-ids)
  291. position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  292. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  293. config.max_position_embeddings - 1]`.
  294. [What are position IDs?](../glossary#position-ids)
  295. """
  296. if input_ids is not None and inputs_embeds is not None:
  297. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  298. elif input_ids is not None:
  299. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  300. input_shape = input_ids.size()
  301. elif inputs_embeds is not None:
  302. input_shape = inputs_embeds.size()[:-1]
  303. else:
  304. raise ValueError("You have to specify either input_ids or inputs_embeds")
  305. batch_size, seq_length = input_shape
  306. device = input_ids.device if input_ids is not None else inputs_embeds.device
  307. if attention_mask is None:
  308. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  309. if token_type_ids is None:
  310. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  311. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  312. # ourselves in which case we just need to make it broadcastable to all heads.
  313. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  314. embedding_output = self.embeddings(
  315. input_ids=input_ids,
  316. position_ids=position_ids,
  317. token_type_ids=token_type_ids,
  318. inputs_embeds=inputs_embeds,
  319. )
  320. encoder_outputs = self.encoder(
  321. embedding_output,
  322. attention_mask=extended_attention_mask,
  323. **kwargs,
  324. )
  325. sequence_output = encoder_outputs[0]
  326. return BaseModelOutput(
  327. last_hidden_state=sequence_output,
  328. )
  329. class SplinterFullyConnectedLayer(nn.Module):
  330. def __init__(self, input_dim, output_dim, hidden_act="gelu"):
  331. super().__init__()
  332. self.input_dim = input_dim
  333. self.output_dim = output_dim
  334. self.dense = nn.Linear(self.input_dim, self.output_dim)
  335. self.act_fn = ACT2FN[hidden_act]
  336. self.LayerNorm = nn.LayerNorm(self.output_dim)
  337. def forward(self, inputs: torch.Tensor) -> torch.Tensor:
  338. hidden_states = self.dense(inputs)
  339. hidden_states = self.act_fn(hidden_states)
  340. hidden_states = self.LayerNorm(hidden_states)
  341. return hidden_states
  342. class QuestionAwareSpanSelectionHead(nn.Module):
  343. """
  344. Implementation of Question-Aware Span Selection (QASS) head, described in Splinter's paper:
  345. """
  346. def __init__(self, config):
  347. super().__init__()
  348. self.query_start_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
  349. self.query_end_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
  350. self.start_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
  351. self.end_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
  352. self.start_classifier = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
  353. self.end_classifier = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
  354. def forward(self, inputs, positions):
  355. _, _, dim = inputs.size()
  356. index = positions.unsqueeze(-1).repeat(1, 1, dim) # [batch_size, num_positions, dim]
  357. gathered_reps = torch.gather(inputs, dim=1, index=index) # [batch_size, num_positions, dim]
  358. query_start_reps = self.query_start_transform(gathered_reps) # [batch_size, num_positions, dim]
  359. query_end_reps = self.query_end_transform(gathered_reps) # [batch_size, num_positions, dim]
  360. start_reps = self.start_transform(inputs) # [batch_size, seq_length, dim]
  361. end_reps = self.end_transform(inputs) # [batch_size, seq_length, dim]
  362. hidden_states = self.start_classifier(query_start_reps) # [batch_size, num_positions, dim]
  363. start_reps = start_reps.permute(0, 2, 1) # [batch_size, dim, seq_length]
  364. start_logits = torch.matmul(hidden_states, start_reps)
  365. hidden_states = self.end_classifier(query_end_reps)
  366. end_reps = end_reps.permute(0, 2, 1)
  367. end_logits = torch.matmul(hidden_states, end_reps)
  368. return start_logits, end_logits
  369. @auto_docstring
  370. class SplinterForQuestionAnswering(SplinterPreTrainedModel):
  371. def __init__(self, config):
  372. super().__init__(config)
  373. self.splinter = SplinterModel(config)
  374. self.splinter_qass = QuestionAwareSpanSelectionHead(config)
  375. self.question_token_id = config.question_token_id
  376. # Initialize weights and apply final processing
  377. self.post_init()
  378. @can_return_tuple
  379. @auto_docstring
  380. def forward(
  381. self,
  382. input_ids: torch.Tensor | None = None,
  383. attention_mask: torch.Tensor | None = None,
  384. token_type_ids: torch.Tensor | None = None,
  385. position_ids: torch.Tensor | None = None,
  386. inputs_embeds: torch.Tensor | None = None,
  387. start_positions: torch.LongTensor | None = None,
  388. end_positions: torch.LongTensor | None = None,
  389. question_positions: torch.LongTensor | None = None,
  390. **kwargs: Unpack[TransformersKwargs],
  391. ) -> tuple | QuestionAnsweringModelOutput:
  392. r"""
  393. token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  394. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  395. 1]`:
  396. - 0 corresponds to a *sentence A* token,
  397. - 1 corresponds to a *sentence B* token.
  398. [What are token type IDs?](../glossary#token-type-ids)
  399. position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  400. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  401. config.max_position_embeddings - 1]`.
  402. [What are position IDs?](../glossary#position-ids)
  403. question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
  404. The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size,
  405. num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be
  406. the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size,
  407. sequence_length)`.
  408. """
  409. question_positions_were_none = False
  410. if question_positions is None:
  411. if input_ids is not None:
  412. question_position_for_each_example = torch.argmax(
  413. (torch.eq(input_ids, self.question_token_id)).int(), dim=-1
  414. )
  415. else:
  416. question_position_for_each_example = torch.zeros(
  417. inputs_embeds.size(0), dtype=torch.long, layout=inputs_embeds.layout, device=inputs_embeds.device
  418. )
  419. question_positions = question_position_for_each_example.unsqueeze(-1)
  420. question_positions_were_none = True
  421. outputs = self.splinter(
  422. input_ids,
  423. attention_mask=attention_mask,
  424. token_type_ids=token_type_ids,
  425. position_ids=position_ids,
  426. inputs_embeds=inputs_embeds,
  427. **kwargs,
  428. )
  429. sequence_output = outputs[0]
  430. start_logits, end_logits = self.splinter_qass(sequence_output, question_positions)
  431. if question_positions_were_none:
  432. start_logits, end_logits = start_logits.squeeze(1), end_logits.squeeze(1)
  433. if attention_mask is not None:
  434. start_logits = start_logits + (1 - attention_mask) * torch.finfo(start_logits.dtype).min
  435. end_logits = end_logits + (1 - attention_mask) * torch.finfo(end_logits.dtype).min
  436. total_loss = None
  437. if start_positions is not None and end_positions is not None:
  438. # If we are on multi-GPU, split add a dimension
  439. if len(start_positions.size()) > 1:
  440. start_positions = start_positions.squeeze(-1)
  441. if len(end_positions.size()) > 1:
  442. end_positions = end_positions.squeeze(-1)
  443. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  444. ignored_index = start_logits.size(1)
  445. start_positions.clamp_(0, ignored_index)
  446. end_positions.clamp_(0, ignored_index)
  447. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  448. start_loss = loss_fct(start_logits, start_positions)
  449. end_loss = loss_fct(end_logits, end_positions)
  450. total_loss = (start_loss + end_loss) / 2
  451. return QuestionAnsweringModelOutput(
  452. loss=total_loss,
  453. start_logits=start_logits,
  454. end_logits=end_logits,
  455. hidden_states=outputs.hidden_states,
  456. attentions=outputs.attentions,
  457. )
  458. @dataclass
  459. @auto_docstring(
  460. custom_intro="""
  461. Class for outputs of Splinter as a span selection model.
  462. """
  463. )
  464. class SplinterForPreTrainingOutput(ModelOutput):
  465. r"""
  466. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when start and end positions are provided):
  467. Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
  468. start_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
  469. Span-start scores (before SoftMax).
  470. end_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
  471. Span-end scores (before SoftMax).
  472. """
  473. loss: torch.FloatTensor | None = None
  474. start_logits: torch.FloatTensor | None = None
  475. end_logits: torch.FloatTensor | None = None
  476. hidden_states: tuple[torch.FloatTensor] | None = None
  477. attentions: tuple[torch.FloatTensor] | None = None
  478. @auto_docstring(
  479. custom_intro="""
  480. Splinter Model for the recurring span selection task as done during the pretraining. The difference to the QA task
  481. is that we do not have a question, but multiple question tokens that replace the occurrences of recurring spans
  482. instead.
  483. """
  484. )
  485. class SplinterForPreTraining(SplinterPreTrainedModel):
  486. def __init__(self, config):
  487. super().__init__(config)
  488. self.splinter = SplinterModel(config)
  489. self.splinter_qass = QuestionAwareSpanSelectionHead(config)
  490. self.question_token_id = config.question_token_id
  491. # Initialize weights and apply final processing
  492. self.post_init()
  493. @can_return_tuple
  494. @auto_docstring
  495. def forward(
  496. self,
  497. input_ids: torch.Tensor | None = None,
  498. attention_mask: torch.Tensor | None = None,
  499. token_type_ids: torch.Tensor | None = None,
  500. position_ids: torch.Tensor | None = None,
  501. inputs_embeds: torch.Tensor | None = None,
  502. start_positions: torch.LongTensor | None = None,
  503. end_positions: torch.LongTensor | None = None,
  504. question_positions: torch.LongTensor | None = None,
  505. **kwargs: Unpack[TransformersKwargs],
  506. ) -> tuple | SplinterForPreTrainingOutput:
  507. r"""
  508. input_ids (`torch.LongTensor` of shape `(batch_size, num_questions, sequence_length)`):
  509. Indices of input sequence tokens in the vocabulary.
  510. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  511. [`PreTrainedTokenizer.__call__`] for details.
  512. [What are input IDs?](../glossary#input-ids)
  513. token_type_ids (`torch.LongTensor` of shape `batch_size, num_questions, sequence_length`, *optional*):
  514. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  515. 1]`:
  516. - 0 corresponds to a *sentence A* token,
  517. - 1 corresponds to a *sentence B* token.
  518. [What are token type IDs?](../glossary#token-type-ids)
  519. position_ids (`torch.LongTensor` of shape `batch_size, num_questions, sequence_length`, *optional*):
  520. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  521. config.max_position_embeddings - 1]`.
  522. [What are position IDs?](../glossary#position-ids)
  523. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length, hidden_size)`, *optional*):
  524. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  525. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  526. model's internal embedding lookup matrix.
  527. start_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
  528. Labels for position (index) of the start of the labelled span for computing the token classification loss.
  529. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  530. are not taken into account for computing the loss.
  531. end_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
  532. Labels for position (index) of the end of the labelled span for computing the token classification loss.
  533. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  534. are not taken into account for computing the loss.
  535. question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
  536. The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size,
  537. num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be
  538. the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size,
  539. sequence_length)`.
  540. """
  541. if question_positions is None and start_positions is not None and end_positions is not None:
  542. raise TypeError("question_positions must be specified in order to calculate the loss")
  543. elif question_positions is None and input_ids is None:
  544. raise TypeError("question_positions must be specified when inputs_embeds is used")
  545. elif question_positions is None:
  546. question_positions = self._prepare_question_positions(input_ids)
  547. outputs = self.splinter(
  548. input_ids,
  549. attention_mask=attention_mask,
  550. token_type_ids=token_type_ids,
  551. position_ids=position_ids,
  552. inputs_embeds=inputs_embeds,
  553. **kwargs,
  554. )
  555. sequence_output = outputs[0]
  556. batch_size, sequence_length, dim = sequence_output.size()
  557. # [batch_size, num_questions, sequence_length]
  558. start_logits, end_logits = self.splinter_qass(sequence_output, question_positions)
  559. num_questions = question_positions.size(1)
  560. if attention_mask is not None:
  561. attention_mask_for_each_question = attention_mask.unsqueeze(1).expand(
  562. batch_size, num_questions, sequence_length
  563. )
  564. start_logits = start_logits + (1 - attention_mask_for_each_question) * torch.finfo(start_logits.dtype).min
  565. end_logits = end_logits + (1 - attention_mask_for_each_question) * torch.finfo(end_logits.dtype).min
  566. total_loss = None
  567. # [batch_size, num_questions, sequence_length]
  568. if start_positions is not None and end_positions is not None:
  569. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  570. start_positions.clamp_(0, max(0, sequence_length - 1))
  571. end_positions.clamp_(0, max(0, sequence_length - 1))
  572. # Ignore zero positions in the loss. Splinter never predicts zero
  573. # during pretraining and zero is used for padding question
  574. # tokens as well as for start and end positions of padded
  575. # question tokens.
  576. loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)
  577. start_loss = loss_fct(
  578. start_logits.view(batch_size * num_questions, sequence_length),
  579. start_positions.view(batch_size * num_questions),
  580. )
  581. end_loss = loss_fct(
  582. end_logits.view(batch_size * num_questions, sequence_length),
  583. end_positions.view(batch_size * num_questions),
  584. )
  585. total_loss = (start_loss + end_loss) / 2
  586. return SplinterForPreTrainingOutput(
  587. loss=total_loss,
  588. start_logits=start_logits,
  589. end_logits=end_logits,
  590. hidden_states=outputs.hidden_states,
  591. attentions=outputs.attentions,
  592. )
  593. def _prepare_question_positions(self, input_ids: torch.Tensor) -> torch.Tensor:
  594. rows, flat_positions = torch.where(input_ids == self.config.question_token_id)
  595. num_questions = torch.bincount(rows)
  596. positions = torch.full(
  597. (input_ids.size(0), num_questions.max()),
  598. self.config.pad_token_id,
  599. dtype=torch.long,
  600. device=input_ids.device,
  601. )
  602. torch_compilable_check(
  603. num_questions.size(0) == input_ids.size(0),
  604. "All samples in the batch must have at least one question token.",
  605. )
  606. cols = torch.cat([torch.arange(n) for n in num_questions])
  607. positions[rows, cols] = flat_positions
  608. return positions
  609. __all__ = [
  610. "SplinterForQuestionAnswering",
  611. "SplinterForPreTraining",
  612. "SplinterLayer",
  613. "SplinterModel",
  614. "SplinterPreTrainedModel",
  615. ]