modeling_opt.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751
  1. # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch OPT model."""
  15. from collections.abc import Callable
  16. import torch
  17. from torch import nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache, DynamicCache
  21. from ...generation import GenerationMixin
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BaseModelOutputWithPast,
  27. CausalLMOutputWithPast,
  28. QuestionAnsweringModelOutput,
  29. SequenceClassifierOutputWithPast,
  30. )
  31. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  32. from ...processing_utils import Unpack
  33. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  34. from ...utils.generic import merge_with_config_defaults
  35. from ...utils.output_capturing import capture_outputs
  36. from .configuration_opt import OPTConfig
# Module-level logger, used below for one-time warnings (`logger.warning_once`).
logger = logging.get_logger(__name__)
  38. class OPTLearnedPositionalEmbedding(nn.Embedding):
  39. """
  40. This module learns positional embeddings up to a fixed maximum size.
  41. """
  42. def __init__(self, num_embeddings: int, embedding_dim: int):
  43. # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
  44. # and adjust num_embeddings appropriately. Other models don't have this hack
  45. self.offset = 2
  46. super().__init__(num_embeddings + self.offset, embedding_dim)
  47. def forward(
  48. self,
  49. attention_mask: torch.LongTensor,
  50. past_key_values_length: int = 0,
  51. position_ids: torch.LongTensor | None = None,
  52. ):
  53. """`input_ids_shape` is expected to be [bsz x seqlen]."""
  54. if position_ids is None:
  55. position_ids = torch.cumsum(attention_mask, dim=1)
  56. position_ids = (position_ids * attention_mask - 1).long()
  57. # cut positions if `past_key_values_length` is > 0
  58. position_ids = position_ids[:, past_key_values_length:]
  59. return super().forward(position_ids + self.offset)
  60. # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
  61. def eager_attention_forward(
  62. module: nn.Module,
  63. query: torch.Tensor,
  64. key: torch.Tensor,
  65. value: torch.Tensor,
  66. attention_mask: torch.Tensor | None,
  67. scaling: float,
  68. dropout: float = 0.0,
  69. **kwargs,
  70. ):
  71. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  72. if attention_mask is not None:
  73. attn_weights = attn_weights + attention_mask
  74. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  75. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  76. attn_output = torch.matmul(attn_weights, value)
  77. attn_output = attn_output.transpose(1, 2).contiguous()
  78. return attn_output, attn_weights
  79. class OPTAttention(nn.Module):
  80. """Multi-headed attention from 'Attention Is All You Need' paper"""
  81. def __init__(
  82. self,
  83. config: OPTConfig,
  84. layer_idx: int | None = None,
  85. **kwargs,
  86. ):
  87. super().__init__()
  88. self.config = config
  89. self.embed_dim = config.hidden_size
  90. self.num_heads = config.num_attention_heads
  91. self.dropout = config.attention_dropout
  92. self.enable_bias = config.enable_bias
  93. self.layer_idx = layer_idx
  94. if layer_idx is None:
  95. logger.warning_once(
  96. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  97. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  98. "when creating this class."
  99. )
  100. self.head_dim = self.embed_dim // self.num_heads
  101. self.is_causal = True
  102. if (self.head_dim * self.num_heads) != self.embed_dim:
  103. raise ValueError(
  104. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  105. f" and `num_heads`: {self.num_heads})."
  106. )
  107. self.scaling = self.head_dim**-0.5
  108. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
  109. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
  110. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
  111. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
  112. def forward(
  113. self,
  114. hidden_states: torch.Tensor,
  115. past_key_values: Cache | None = None,
  116. attention_mask: torch.Tensor | None = None,
  117. output_attentions: bool = False,
  118. **kwargs,
  119. ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
  120. """Input shape: Batch x Time x Channel"""
  121. bsz, tgt_len, _ = hidden_states.size()
  122. # Scaling is susceptible to floating point arithmetics' inprecisions
  123. # which can lead to different results (this is dependent from model
  124. # to model, e.g. whisper is one such case). We therefore keep the
  125. # original order of scaling to follow the original implementation
  126. # and enforce no scaling (1.0) in the attention call below.
  127. query_states = self.q_proj(hidden_states) * self.scaling
  128. query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  129. key_states = self.k_proj(hidden_states)
  130. value_states = self.v_proj(hidden_states)
  131. key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  132. value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  133. if past_key_values is not None:
  134. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  135. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  136. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  137. self.config._attn_implementation, eager_attention_forward
  138. )
  139. attn_output, attn_weights = attention_interface(
  140. self,
  141. query_states,
  142. key_states,
  143. value_states,
  144. attention_mask,
  145. dropout=0.0 if not self.training else self.dropout,
  146. scaling=1.0,
  147. **kwargs,
  148. )
  149. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  150. attn_output = self.out_proj(attn_output)
  151. return attn_output, attn_weights
  152. class OPTDecoderLayer(GradientCheckpointingLayer):
  153. def __init__(self, config: OPTConfig, layer_idx: int | None = None):
  154. super().__init__()
  155. self.embed_dim = config.hidden_size
  156. self.self_attn = OPTAttention(config=config, layer_idx=layer_idx)
  157. self.do_layer_norm_before = config.do_layer_norm_before
  158. self.dropout = config.dropout
  159. self.activation_fn = ACT2FN[config.activation_function]
  160. self.self_attn_layer_norm = nn.LayerNorm(
  161. self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
  162. )
  163. self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias)
  164. self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
  165. self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
  166. def forward(
  167. self,
  168. hidden_states: torch.Tensor,
  169. attention_mask: torch.Tensor | None = None,
  170. past_key_values: Cache | None = None,
  171. use_cache: bool | None = False,
  172. position_ids: torch.LongTensor | None = None,
  173. **kwargs: Unpack[FlashAttentionKwargs],
  174. ) -> torch.Tensor:
  175. residual = hidden_states
  176. # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
  177. if self.do_layer_norm_before:
  178. hidden_states = self.self_attn_layer_norm(hidden_states)
  179. # Self Attention
  180. hidden_states, _ = self.self_attn(
  181. hidden_states=hidden_states,
  182. past_key_values=past_key_values,
  183. position_ids=position_ids,
  184. attention_mask=attention_mask,
  185. **kwargs,
  186. )
  187. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  188. hidden_states = residual + hidden_states
  189. # 350m applies layer norm AFTER attention
  190. if not self.do_layer_norm_before:
  191. hidden_states = self.self_attn_layer_norm(hidden_states)
  192. # Fully Connected
  193. hidden_states_shape = hidden_states.shape
  194. hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
  195. residual = hidden_states
  196. # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
  197. if self.do_layer_norm_before:
  198. hidden_states = self.final_layer_norm(hidden_states)
  199. hidden_states = self.fc1(hidden_states)
  200. hidden_states = self.activation_fn(hidden_states)
  201. hidden_states = self.fc2(hidden_states)
  202. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  203. hidden_states = (residual + hidden_states).view(hidden_states_shape)
  204. # 350m applies layer norm AFTER attention
  205. if not self.do_layer_norm_before:
  206. hidden_states = self.final_layer_norm(hidden_states)
  207. return hidden_states
@auto_docstring
class OPTPreTrainedModel(PreTrainedModel):
    # Shared base class for every OPT model in this file. It only declares class-level settings;
    # presumably these are consumed by `PreTrainedModel` machinery (loading, device placement,
    # attention-backend selection) — verify against `modeling_utils`.
    config: OPTConfig
    # Attribute name under which the task heads below store the base model (`self.model = OPTModel(...)`).
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Presumably: module classes that automatic device-map placement must not split across devices.
    _no_split_modules = ["OPTDecoderLayer"]
    # Capability flags for the attention backends this model can run with.
    _supports_attention_backend = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    # Output name -> module class whose outputs are recorded for it. Presumably used by
    # `capture_outputs` (applied to `OPTDecoder.forward`) to collect hidden states / attentions.
    _can_record_outputs = {
        "hidden_states": OPTDecoderLayer,
        "attentions": OPTAttention,
    }
  223. class OPTDecoder(OPTPreTrainedModel):
  224. """
  225. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
  226. Args:
  227. config: OPTConfig
  228. """
  229. def __init__(self, config: OPTConfig):
  230. super().__init__(config)
  231. self.dropout = config.dropout
  232. self.layerdrop = config.layerdrop
  233. self.padding_idx = config.pad_token_id
  234. self.max_target_positions = config.max_position_embeddings
  235. self.vocab_size = config.vocab_size
  236. self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
  237. self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
  238. if config.word_embed_proj_dim != config.hidden_size:
  239. self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
  240. else:
  241. self.project_out = None
  242. if config.word_embed_proj_dim != config.hidden_size:
  243. self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
  244. else:
  245. self.project_in = None
  246. # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
  247. # with checkpoints that have been fine-tuned before transformers v4.20.1
  248. # see https://github.com/facebookresearch/metaseq/pull/164
  249. if config.do_layer_norm_before and not config._remove_final_layer_norm:
  250. self.final_layer_norm = nn.LayerNorm(
  251. config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
  252. )
  253. else:
  254. self.final_layer_norm = None
  255. self.layers = nn.ModuleList([OPTDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  256. self.gradient_checkpointing = False
  257. # Initialize weights and apply final processing
  258. self.post_init()
  259. @merge_with_config_defaults
  260. @capture_outputs
  261. @auto_docstring
  262. def forward(
  263. self,
  264. input_ids: torch.LongTensor | None = None,
  265. attention_mask: torch.Tensor | None = None,
  266. past_key_values: Cache | None = None,
  267. inputs_embeds: torch.FloatTensor | None = None,
  268. use_cache: bool | None = None,
  269. position_ids: torch.LongTensor | None = None,
  270. **kwargs: Unpack[TransformersKwargs],
  271. ) -> BaseModelOutputWithPast:
  272. if (input_ids is None) ^ (inputs_embeds is not None):
  273. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  274. if input_ids is not None:
  275. input_ids = input_ids.view(-1, input_ids.shape[-1])
  276. if inputs_embeds is None:
  277. inputs_embeds = self.embed_tokens(input_ids)
  278. if use_cache and past_key_values is None:
  279. past_key_values = DynamicCache(config=self.config)
  280. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  281. if attention_mask is None:
  282. seq_length = past_seen_tokens + inputs_embeds.shape[1]
  283. attention_mask = torch.ones(inputs_embeds.shape[0], seq_length, device=inputs_embeds.device)
  284. # embed positions
  285. if position_ids is None:
  286. position_ids = torch.cumsum(attention_mask, dim=1)
  287. position_ids = (position_ids * attention_mask - 1).long()
  288. # cut positions if `past_seen_tokens` is > 0
  289. position_ids = position_ids[:, past_seen_tokens:]
  290. causal_mask = create_causal_mask(
  291. config=self.config,
  292. inputs_embeds=inputs_embeds,
  293. attention_mask=attention_mask,
  294. past_key_values=past_key_values,
  295. )
  296. pos_embeds = self.embed_positions(attention_mask, past_seen_tokens, position_ids=position_ids)
  297. if self.project_in is not None:
  298. inputs_embeds = self.project_in(inputs_embeds)
  299. hidden_states = inputs_embeds + pos_embeds.to(inputs_embeds.device)
  300. # decoder layers
  301. for idx, decoder_layer in enumerate(self.layers):
  302. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  303. if self.training:
  304. dropout_probability = torch.rand([])
  305. if dropout_probability < self.layerdrop:
  306. continue
  307. hidden_states = decoder_layer(
  308. hidden_states,
  309. attention_mask=causal_mask,
  310. position_ids=position_ids,
  311. past_key_values=past_key_values,
  312. use_cache=use_cache,
  313. **kwargs,
  314. )
  315. if self.final_layer_norm is not None:
  316. hidden_states = self.final_layer_norm(hidden_states)
  317. if self.project_out is not None:
  318. hidden_states = self.project_out(hidden_states)
  319. return BaseModelOutputWithPast(
  320. last_hidden_state=hidden_states,
  321. past_key_values=past_key_values,
  322. )
  323. @auto_docstring
  324. class OPTModel(OPTPreTrainedModel):
  325. def __init__(self, config: OPTConfig):
  326. super().__init__(config)
  327. self.decoder = OPTDecoder(config)
  328. # Initialize weights and apply final processing
  329. self.post_init()
  330. def get_input_embeddings(self):
  331. return self.decoder.embed_tokens
  332. def set_input_embeddings(self, value):
  333. self.decoder.embed_tokens = value
  334. @can_return_tuple
  335. @auto_docstring
  336. def forward(
  337. self,
  338. input_ids: torch.LongTensor | None = None,
  339. attention_mask: torch.Tensor | None = None,
  340. past_key_values: Cache | None = None,
  341. inputs_embeds: torch.FloatTensor | None = None,
  342. use_cache: bool | None = None,
  343. position_ids: torch.LongTensor | None = None,
  344. **kwargs: Unpack[TransformersKwargs],
  345. ) -> BaseModelOutputWithPast:
  346. decoder_outputs: BaseModelOutputWithPast = self.decoder(
  347. input_ids=input_ids,
  348. attention_mask=attention_mask,
  349. position_ids=position_ids,
  350. past_key_values=past_key_values,
  351. inputs_embeds=inputs_embeds,
  352. use_cache=use_cache,
  353. **kwargs,
  354. )
  355. return BaseModelOutputWithPast(
  356. last_hidden_state=decoder_outputs.last_hidden_state,
  357. past_key_values=decoder_outputs.past_key_values,
  358. hidden_states=decoder_outputs.hidden_states,
  359. attentions=decoder_outputs.attentions,
  360. )
  361. class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
  362. _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"}
  363. def __init__(self, config):
  364. super().__init__(config)
  365. self.model = OPTModel(config)
  366. # the lm_head weight is automatically tied to the embed tokens weight
  367. self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
  368. # Initialize weights and apply final processing
  369. self.post_init()
  370. def get_input_embeddings(self):
  371. return self.model.decoder.embed_tokens
  372. def set_input_embeddings(self, value):
  373. self.model.decoder.embed_tokens = value
  374. @can_return_tuple
  375. @auto_docstring
  376. def forward(
  377. self,
  378. input_ids: torch.LongTensor | None = None,
  379. attention_mask: torch.Tensor | None = None,
  380. past_key_values: Cache | None = None,
  381. inputs_embeds: torch.FloatTensor | None = None,
  382. labels: torch.LongTensor | None = None,
  383. use_cache: bool | None = None,
  384. position_ids: torch.LongTensor | None = None,
  385. logits_to_keep: int | torch.Tensor = 0,
  386. **kwargs: Unpack[TransformersKwargs],
  387. ) -> tuple | CausalLMOutputWithPast:
  388. r"""
  389. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  390. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  391. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  392. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  393. Example:
  394. ```python
  395. >>> from transformers import AutoTokenizer, OPTForCausalLM
  396. >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
  397. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
  398. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  399. >>> inputs = tokenizer(prompt, return_tensors="pt")
  400. >>> # Generate
  401. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  402. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  403. "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
  404. ```"""
  405. outputs: BaseModelOutputWithPast = self.model.decoder(
  406. input_ids=input_ids,
  407. attention_mask=attention_mask,
  408. position_ids=position_ids,
  409. past_key_values=past_key_values,
  410. inputs_embeds=inputs_embeds,
  411. use_cache=use_cache,
  412. **kwargs,
  413. )
  414. hidden_states = outputs.last_hidden_state
  415. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  416. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  417. logits = self.lm_head(hidden_states[:, slice_indices, :]).contiguous()
  418. loss = None
  419. if labels is not None:
  420. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  421. return CausalLMOutputWithPast(
  422. loss=loss,
  423. logits=logits,
  424. past_key_values=outputs.past_key_values,
  425. hidden_states=outputs.hidden_states,
  426. attentions=outputs.attentions,
  427. )
  428. @auto_docstring(
  429. custom_intro="""
  430. The OPT Model transformer with a sequence classification head on top (linear layer).
  431. [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  432. (e.g. GPT-2) do.
  433. Since it does classification on the last token, it requires to know the position of the last token. If a
  434. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  435. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  436. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  437. each row of the batch).
  438. """
  439. )
  440. class OPTForSequenceClassification(OPTPreTrainedModel):
  441. def __init__(self, config: OPTConfig):
  442. super().__init__(config)
  443. self.num_labels = config.num_labels
  444. self.model = OPTModel(config)
  445. self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)
  446. # Initialize weights and apply final processing
  447. self.post_init()
  448. @can_return_tuple
  449. @auto_docstring
  450. def forward(
  451. self,
  452. input_ids: torch.LongTensor | None = None,
  453. attention_mask: torch.FloatTensor | None = None,
  454. past_key_values: Cache | None = None,
  455. inputs_embeds: torch.FloatTensor | None = None,
  456. labels: torch.LongTensor | None = None,
  457. use_cache: bool | None = None,
  458. position_ids: torch.LongTensor | None = None,
  459. **kwargs: Unpack[TransformersKwargs],
  460. ) -> tuple | SequenceClassifierOutputWithPast:
  461. r"""
  462. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  463. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  464. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  465. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  466. """
  467. transformer_outputs: BaseModelOutputWithPast = self.model(
  468. input_ids,
  469. past_key_values=past_key_values,
  470. attention_mask=attention_mask,
  471. position_ids=position_ids,
  472. inputs_embeds=inputs_embeds,
  473. use_cache=use_cache,
  474. **kwargs,
  475. )
  476. hidden_states = transformer_outputs.last_hidden_state
  477. logits = self.score(hidden_states)
  478. if input_ids is not None:
  479. batch_size, sequence_length = input_ids.shape[:2]
  480. else:
  481. batch_size, sequence_length = inputs_embeds.shape[:2]
  482. if self.config.pad_token_id is None and batch_size != 1:
  483. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  484. if self.config.pad_token_id is None:
  485. last_non_pad_token = -1
  486. elif input_ids is not None:
  487. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  488. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  489. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  490. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  491. else:
  492. last_non_pad_token = -1
  493. logger.warning_once(
  494. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  495. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  496. )
  497. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  498. loss = None
  499. if labels is not None:
  500. if self.config.problem_type is None:
  501. if self.num_labels == 1:
  502. self.config.problem_type = "regression"
  503. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  504. self.config.problem_type = "single_label_classification"
  505. else:
  506. self.config.problem_type = "multi_label_classification"
  507. if self.config.problem_type == "regression":
  508. loss_fct = MSELoss()
  509. if self.num_labels == 1:
  510. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  511. else:
  512. loss = loss_fct(pooled_logits, labels)
  513. elif self.config.problem_type == "single_label_classification":
  514. loss_fct = CrossEntropyLoss()
  515. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  516. elif self.config.problem_type == "multi_label_classification":
  517. loss_fct = BCEWithLogitsLoss()
  518. loss = loss_fct(pooled_logits, labels)
  519. return SequenceClassifierOutputWithPast(
  520. loss=loss,
  521. logits=pooled_logits,
  522. past_key_values=transformer_outputs.past_key_values,
  523. hidden_states=transformer_outputs.hidden_states,
  524. attentions=transformer_outputs.attentions,
  525. )
  526. def get_input_embeddings(self):
  527. return self.model.decoder.embed_tokens
  528. def set_input_embeddings(self, value):
  529. self.model.decoder.embed_tokens = value
  530. @auto_docstring
  531. class OPTForQuestionAnswering(OPTPreTrainedModel):
  532. def __init__(self, config: OPTConfig):
  533. super().__init__(config)
  534. self.model = OPTModel(config)
  535. self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2)
  536. # Initialize weights and apply final processing
  537. self.post_init()
  538. @can_return_tuple
  539. @auto_docstring
  540. def forward(
  541. self,
  542. input_ids: torch.LongTensor | None = None,
  543. attention_mask: torch.FloatTensor | None = None,
  544. past_key_values: Cache | None = None,
  545. inputs_embeds: torch.FloatTensor | None = None,
  546. start_positions: torch.LongTensor | None = None,
  547. end_positions: torch.LongTensor | None = None,
  548. use_cache: bool | None = None,
  549. position_ids: torch.LongTensor | None = None,
  550. **kwargs: Unpack[TransformersKwargs],
  551. ) -> tuple | QuestionAnsweringModelOutput:
  552. r"""
  553. Example:
  554. ```python
  555. >>> from transformers import AutoTokenizer, OPTForQuestionAnswering
  556. >>> import torch
  557. >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT
  558. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
  559. >>> # note: we are loading a OPTForQuestionAnswering from the hub here,
  560. >>> # so the head will be randomly initialized, hence the predictions will be random
  561. >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")
  562. >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
  563. >>> inputs = tokenizer(question, text, return_tensors="pt")
  564. >>> with torch.no_grad():
  565. ... outputs = model(**inputs)
  566. >>> answer_start_index = outputs.start_logits.argmax()
  567. >>> answer_end_index = outputs.end_logits.argmax()
  568. >>> answer_offset = len(tokenizer(question)[0])
  569. >>> predict_answer_tokens = inputs.input_ids[
  570. ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1
  571. ... ]
  572. >>> predicted = tokenizer.decode(predict_answer_tokens)
  573. >>> predicted
  574. ' a nice puppet'
  575. ```"""
  576. transformer_outputs: BaseModelOutputWithPast = self.model(
  577. input_ids,
  578. past_key_values=past_key_values,
  579. attention_mask=attention_mask,
  580. position_ids=position_ids,
  581. inputs_embeds=inputs_embeds,
  582. use_cache=use_cache,
  583. **kwargs,
  584. )
  585. hidden_states = transformer_outputs.last_hidden_state
  586. logits = self.qa_outputs(hidden_states)
  587. start_logits, end_logits = logits.split(1, dim=-1)
  588. start_logits = start_logits.squeeze(-1).contiguous()
  589. end_logits = end_logits.squeeze(-1).contiguous()
  590. total_loss = None
  591. if start_positions is not None and end_positions is not None:
  592. if len(start_positions.size()) > 1:
  593. start_positions = start_positions.squeeze(-1)
  594. if len(end_positions.size()) > 1:
  595. end_positions = end_positions.squeeze(-1)
  596. ignored_index = start_logits.size(1)
  597. start_positions = start_positions.clamp(0, ignored_index).to(logits.device)
  598. end_positions = end_positions.clamp(0, ignored_index).to(logits.device)
  599. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  600. start_loss = loss_fct(start_logits, start_positions)
  601. end_loss = loss_fct(end_logits, end_positions)
  602. total_loss = (start_loss + end_loss) / 2
  603. return QuestionAnsweringModelOutput(
  604. loss=total_loss,
  605. start_logits=start_logits,
  606. end_logits=end_logits,
  607. hidden_states=transformer_outputs.hidden_states,
  608. attentions=transformer_outputs.attentions,
  609. )
  610. def get_input_embeddings(self):
  611. return self.model.decoder.embed_tokens
  612. def set_input_embeddings(self, value):
  613. self.model.decoder.embed_tokens = value
# Public names exported by `from ... import *` on this module.
__all__ = [
    "OPTForCausalLM",
    "OPTModel",
    "OPTPreTrainedModel",
    "OPTForSequenceClassification",
    "OPTForQuestionAnswering",
]