modeling_mpt.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761
  1. # Copyright 2023 HuggingFace Inc. team and MosaicML NLP team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch MPT model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
  19. from torch.nn import functional as F
  20. from ...cache_utils import Cache, DynamicCache
  21. from ...generation import GenerationMixin
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import (
  25. BaseModelOutputWithPastAndCrossAttentions,
  26. CausalLMOutputWithCrossAttentions,
  27. QuestionAnsweringModelOutput,
  28. SequenceClassifierOutputWithPast,
  29. TokenClassifierOutput,
  30. )
  31. from ...modeling_utils import PreTrainedModel
  32. from ...utils import auto_docstring, logging
  33. from .configuration_mpt import MptConfig
# Module-level logger from the package's own logging utilities (`...utils.logging`, not the stdlib module);
# used below for one-time warnings such as the gradient-checkpointing `use_cache` notice.
logger = logging.get_logger(__name__)
  35. def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device=None):
  36. r"""
  37. Link to paper: https://huggingface.co/papers/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
  38. relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
  39. the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
  40. https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
  41. """
  42. alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length)
  43. num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads))
  44. base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.int64, device=device).float()
  45. base = base * (alibi_bias_max / num_heads_power_of_2)
  46. slopes = 1.0 / torch.pow(2, base)
  47. slopes = slopes.view(1, num_heads_power_of_2, 1, 1)
  48. if num_heads_power_of_2 != num_heads:
  49. slopes = torch.concat([slopes[:, 1::2, ...], slopes[:, ::2, ...]], dim=1)[:, :num_heads, ...]
  50. alibi = alibi * slopes
  51. return alibi.squeeze(0)
  52. class MptAttention(nn.Module):
  53. """Multi-head self attention.
  54. Using torch or triton attention implementation enables user to also use additive bias.
  55. """
  56. def __init__(self, config: MptConfig, layer_idx: int | None = None):
  57. super().__init__()
  58. self.hidden_size = config.hidden_size
  59. self.n_heads = config.n_heads
  60. self.max_seq_length = config.max_seq_len
  61. self.head_dim = self.hidden_size // self.n_heads
  62. self.softmax_scale = config.attn_config.softmax_scale
  63. if self.softmax_scale is None:
  64. self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads)
  65. self.attn_dropout_p = config.attn_config.attn_pdrop
  66. self.clip_qkv = config.attn_config.clip_qkv
  67. self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
  68. self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
  69. self.layer_idx = layer_idx
  70. def forward(
  71. self,
  72. hidden_states: torch.Tensor,
  73. position_bias: torch.Tensor,
  74. past_key_values: Cache | None = None,
  75. attention_mask: torch.Tensor | None = None,
  76. **kwargs,
  77. ):
  78. batch_size, seq_length = hidden_states.shape[:2]
  79. mixed_qkv = self.Wqkv(hidden_states)
  80. if self.clip_qkv:
  81. mixed_qkv = mixed_qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
  82. query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
  83. query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
  84. key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
  85. value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
  86. if past_key_values is not None:
  87. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  88. attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
  89. query_length = seq_length if past_key_values is None else seq_length + past_key_values.get_seq_length()
  90. if position_bias is not None:
  91. if len(position_bias.shape) != 3:
  92. raise ValueError(f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}")
  93. key_length = key_states.shape[-2]
  94. position_bias_query_index = max(0, position_bias.size(1) - query_length)
  95. position_bias_key_index = max(0, position_bias.size(2) - key_length)
  96. position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]
  97. attention_scores = attention_scores + position_bias
  98. if attention_mask is not None:
  99. attention_scores = attention_scores.masked_fill(attention_mask, torch.finfo(query_states.dtype).min)
  100. # (batch_size, n_heads, seq_length, key_length)
  101. attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).to(value_states.dtype)
  102. attn_weights = nn.functional.dropout(attn_weights, p=self.attn_dropout_p, training=self.training)
  103. context_states = torch.matmul(attn_weights, value_states)
  104. context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
  105. attn_output = self.out_proj(context_states)
  106. return attn_output, attn_weights
  107. class MptMLP(nn.Module):
  108. def __init__(self, config: MptConfig):
  109. super().__init__()
  110. hidden_size = config.hidden_size
  111. self.up_proj = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
  112. self.act = nn.GELU(approximate="none")
  113. self.down_proj = nn.Linear(4 * hidden_size, hidden_size, bias=False)
  114. self.hidden_dropout = config.attn_config.attn_pdrop
  115. def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
  116. hidden_states = self.act(self.up_proj(hidden_states))
  117. intermediate_output = self.down_proj(hidden_states)
  118. output = F.dropout(intermediate_output, p=self.hidden_dropout, training=self.training)
  119. output = output + residual
  120. return output
  121. class MptBlock(GradientCheckpointingLayer):
  122. def __init__(self, config: MptConfig, layer_idx: int | None = None):
  123. super().__init__()
  124. hidden_size = config.hidden_size
  125. self.norm_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  126. # backward compatibility with weights on the Hub
  127. self.norm_1.bias = None
  128. self.num_heads = config.n_heads
  129. self.attn = MptAttention(config, layer_idx)
  130. self.norm_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  131. # backward compatibility with weights on the Hub
  132. self.norm_2.bias = None
  133. self.ffn = MptMLP(config)
  134. self.dropout_rate = config.attn_config.attn_pdrop
  135. self.resid_attn_dropout = nn.Dropout(self.dropout_rate)
  136. def forward(
  137. self,
  138. hidden_states: torch.Tensor,
  139. position_bias: torch.Tensor,
  140. attention_mask: torch.Tensor,
  141. layer_past: Cache | None = None,
  142. use_cache: bool = False,
  143. output_attentions: bool = False,
  144. **kwargs,
  145. ):
  146. # hidden_states: [batch_size, seq_length, hidden_size]
  147. # Layer norm at the beginning of the transformer layer.
  148. layernorm_output = self.norm_1(hidden_states)
  149. residual = hidden_states
  150. # Self attention.
  151. attn_outputs, attn_weights = self.attn(
  152. layernorm_output,
  153. position_bias=position_bias,
  154. attention_mask=attention_mask,
  155. past_key_values=layer_past,
  156. )
  157. hidden_states = self.resid_attn_dropout(attn_outputs) + residual
  158. layernorm_output = self.norm_2(hidden_states)
  159. # Get residual
  160. residual = hidden_states
  161. # MLP.
  162. output = self.ffn(layernorm_output, residual)
  163. return output, attn_weights
  164. @auto_docstring
  165. class MptPreTrainedModel(PreTrainedModel):
  166. config: MptConfig
  167. base_model_prefix = "transformer"
  168. supports_gradient_checkpointing = True
  169. _no_split_modules = ["MptBlock"]
  170. @auto_docstring
  171. class MptModel(MptPreTrainedModel):
  172. def __init__(self, config: MptConfig):
  173. super().__init__(config)
  174. self.hidden_size = config.hidden_size
  175. self.num_heads = config.n_heads
  176. # Embedding + LN Embedding
  177. self.wte = nn.Embedding(config.vocab_size, self.hidden_size)
  178. # Transformer blocks
  179. self.blocks = nn.ModuleList([MptBlock(config, layer_idx=i) for i in range(config.n_layers)])
  180. # Final Layer Norm
  181. self.norm_f = LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
  182. # backward compatibility with weights on the Hub
  183. self.norm_f.bias = None
  184. self.gradient_checkpointing = False
  185. # Initialize weights and apply final processing
  186. self.post_init()
  187. def get_input_embeddings(self):
  188. return self.wte
  189. def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None):
  190. return build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max, device)
  191. def set_input_embeddings(self, new_embeddings: torch.Tensor):
  192. self.wte = new_embeddings
  193. @auto_docstring
  194. def forward(
  195. self,
  196. input_ids: torch.LongTensor | None = None,
  197. past_key_values: Cache | None = None,
  198. attention_mask: torch.Tensor | None = None,
  199. inputs_embeds: torch.LongTensor | None = None,
  200. use_cache: bool | None = None,
  201. output_attentions: bool | None = None,
  202. output_hidden_states: bool | None = None,
  203. return_dict: bool | None = None,
  204. **kwargs, # NOOP kwargs, for now
  205. ) -> tuple[torch.Tensor, ...] | BaseModelOutputWithPastAndCrossAttentions:
  206. r"""
  207. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  208. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  209. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  210. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  211. `input_ids`.
  212. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  213. [`PreTrainedTokenizer.__call__`] for details.
  214. [What are input IDs?](../glossary#input-ids)
  215. """
  216. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  217. output_hidden_states = (
  218. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  219. )
  220. use_cache = use_cache if use_cache is not None else self.config.use_cache
  221. return_dict = return_dict if return_dict is not None else self.config.return_dict
  222. if input_ids is not None and inputs_embeds is not None:
  223. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  224. elif input_ids is not None:
  225. batch_size, seq_length = input_ids.shape
  226. elif inputs_embeds is not None:
  227. batch_size, seq_length, _ = inputs_embeds.shape
  228. else:
  229. raise ValueError("You have to specify either input_ids or inputs_embeds")
  230. if self.gradient_checkpointing and self.training:
  231. if use_cache:
  232. logger.warning_once(
  233. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  234. )
  235. use_cache = False
  236. if inputs_embeds is None:
  237. inputs_embeds = self.wte(input_ids)
  238. if use_cache and past_key_values is None:
  239. past_key_values = DynamicCache(config=self.config)
  240. hidden_states = inputs_embeds
  241. all_self_attentions = () if output_attentions else None
  242. all_hidden_states = () if output_hidden_states else None
  243. # Compute alibi tensor: check build_alibi_tensor documentation
  244. alibi = self.build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device)
  245. causal_mask = create_causal_mask(
  246. config=self.config,
  247. inputs_embeds=inputs_embeds,
  248. attention_mask=attention_mask,
  249. past_key_values=past_key_values,
  250. ).to(torch.bool)
  251. for block in self.blocks:
  252. if output_hidden_states:
  253. all_hidden_states = all_hidden_states + (hidden_states,)
  254. outputs = block(
  255. hidden_states,
  256. layer_past=past_key_values,
  257. attention_mask=causal_mask,
  258. use_cache=use_cache,
  259. output_attentions=output_attentions,
  260. position_bias=alibi,
  261. )
  262. hidden_states = outputs[0]
  263. if output_attentions:
  264. all_self_attentions = all_self_attentions + (outputs[1],)
  265. # Add last hidden state
  266. hidden_states = self.norm_f(hidden_states)
  267. if output_hidden_states:
  268. all_hidden_states = all_hidden_states + (hidden_states,)
  269. if not return_dict:
  270. return tuple(
  271. v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
  272. )
  273. return BaseModelOutputWithPastAndCrossAttentions(
  274. last_hidden_state=hidden_states,
  275. past_key_values=past_key_values,
  276. hidden_states=all_hidden_states,
  277. attentions=all_self_attentions,
  278. )
  279. @auto_docstring(
  280. custom_intro="""
  281. The MPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
  282. embeddings).
  283. """
  284. )
  285. class MptForCausalLM(MptPreTrainedModel, GenerationMixin):
  286. _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
  287. def __init__(self, config: MptConfig):
  288. super().__init__(config)
  289. self.transformer = MptModel(config)
  290. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  291. # Initialize weights and apply final processing
  292. self.post_init()
  293. def set_output_embeddings(self, new_embeddings: torch.Tensor):
  294. self.lm_head = new_embeddings
  295. @auto_docstring
  296. def forward(
  297. self,
  298. input_ids: torch.LongTensor | None = None,
  299. past_key_values: Cache | None = None,
  300. attention_mask: torch.Tensor | None = None,
  301. inputs_embeds: torch.Tensor | None = None,
  302. labels: torch.Tensor | None = None,
  303. use_cache: bool | None = None,
  304. output_attentions: bool | None = None,
  305. output_hidden_states: bool | None = None,
  306. return_dict: bool | None = None,
  307. logits_to_keep: int | torch.Tensor = 0,
  308. **kwargs,
  309. ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
  310. r"""
  311. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  312. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  313. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  314. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  315. `input_ids`.
  316. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  317. [`PreTrainedTokenizer.__call__`] for details.
  318. [What are input IDs?](../glossary#input-ids)
  319. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  320. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  321. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  322. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  323. """
  324. return_dict = return_dict if return_dict is not None else self.config.return_dict
  325. transformer_outputs = self.transformer(
  326. input_ids,
  327. past_key_values=past_key_values,
  328. attention_mask=attention_mask,
  329. inputs_embeds=inputs_embeds,
  330. use_cache=use_cache,
  331. output_attentions=output_attentions,
  332. output_hidden_states=output_hidden_states,
  333. return_dict=return_dict,
  334. )
  335. hidden_states = transformer_outputs[0]
  336. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  337. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  338. logits = self.lm_head(hidden_states[:, slice_indices, :])
  339. loss = None
  340. if labels is not None:
  341. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  342. if not return_dict:
  343. output = (logits,) + transformer_outputs[1:]
  344. return ((loss,) + output) if loss is not None else output
  345. return CausalLMOutputWithCrossAttentions(
  346. loss=loss,
  347. logits=logits,
  348. past_key_values=transformer_outputs.past_key_values,
  349. hidden_states=transformer_outputs.hidden_states,
  350. attentions=transformer_outputs.attentions,
  351. )
  352. @auto_docstring(
  353. custom_intro="""
  354. The MPT Model transformer with a sequence classification head on top (linear layer).
  355. [`MptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  356. (e.g. GPT-1) do.
  357. Since it does classification on the last token, it requires to know the position of the last token. If a
  358. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  359. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  360. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  361. each row of the batch).
  362. """
  363. )
  364. class MptForSequenceClassification(MptPreTrainedModel):
  365. def __init__(self, config: MptConfig):
  366. super().__init__(config)
  367. self.num_labels = config.num_labels
  368. self.transformer = MptModel(config)
  369. self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
  370. # Initialize weights and apply final processing
  371. self.post_init()
  372. def set_output_embeddings(self, new_embeddings: torch.Tensor):
  373. self.score = new_embeddings
  374. @auto_docstring
  375. def forward(
  376. self,
  377. input_ids: torch.LongTensor | None = None,
  378. past_key_values: Cache | None = None,
  379. attention_mask: torch.Tensor | None = None,
  380. inputs_embeds: torch.Tensor | None = None,
  381. labels: torch.Tensor | None = None,
  382. use_cache: bool | None = None,
  383. output_attentions: bool | None = None,
  384. output_hidden_states: bool | None = None,
  385. return_dict: bool | None = None,
  386. **kwargs,
  387. ) -> tuple[torch.Tensor] | SequenceClassifierOutputWithPast:
  388. r"""
  389. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  390. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  391. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  392. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  393. `input_ids`.
  394. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  395. [`PreTrainedTokenizer.__call__`] for details.
  396. [What are input IDs?](../glossary#input-ids)
  397. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  398. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  399. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  400. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  401. """
  402. return_dict = return_dict if return_dict is not None else self.config.return_dict
  403. transformer_outputs = self.transformer(
  404. input_ids,
  405. past_key_values=past_key_values,
  406. attention_mask=attention_mask,
  407. inputs_embeds=inputs_embeds,
  408. use_cache=use_cache,
  409. output_attentions=output_attentions,
  410. output_hidden_states=output_hidden_states,
  411. return_dict=return_dict,
  412. )
  413. hidden_states = transformer_outputs[0]
  414. logits = self.score(hidden_states)
  415. if input_ids is not None:
  416. batch_size = input_ids.shape[0]
  417. else:
  418. batch_size = inputs_embeds.shape[0]
  419. if self.config.pad_token_id is None and batch_size != 1:
  420. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  421. if self.config.pad_token_id is None:
  422. last_non_pad_token = -1
  423. elif input_ids is not None:
  424. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  425. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  426. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  427. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  428. else:
  429. last_non_pad_token = -1
  430. logger.warning_once(
  431. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  432. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  433. )
  434. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  435. loss = None
  436. if labels is not None:
  437. if self.config.problem_type is None:
  438. if self.num_labels == 1:
  439. self.config.problem_type = "regression"
  440. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  441. self.config.problem_type = "single_label_classification"
  442. else:
  443. self.config.problem_type = "multi_label_classification"
  444. if self.config.problem_type == "regression":
  445. loss_fct = MSELoss()
  446. if self.num_labels == 1:
  447. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  448. else:
  449. loss = loss_fct(pooled_logits, labels)
  450. elif self.config.problem_type == "single_label_classification":
  451. loss_fct = CrossEntropyLoss()
  452. loss = loss_fct(pooled_logits, labels)
  453. elif self.config.problem_type == "multi_label_classification":
  454. loss_fct = BCEWithLogitsLoss()
  455. loss = loss_fct(pooled_logits, labels)
  456. if not return_dict:
  457. output = (pooled_logits,) + transformer_outputs[1:]
  458. return ((loss,) + output) if loss is not None else output
  459. return SequenceClassifierOutputWithPast(
  460. loss=loss,
  461. logits=pooled_logits,
  462. past_key_values=transformer_outputs.past_key_values,
  463. hidden_states=transformer_outputs.hidden_states,
  464. attentions=transformer_outputs.attentions,
  465. )
  466. @auto_docstring
  467. class MptForTokenClassification(MptPreTrainedModel):
  468. def __init__(self, config: MptConfig):
  469. super().__init__(config)
  470. self.num_labels = config.num_labels
  471. self.transformer = MptModel(config)
  472. if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
  473. classifier_dropout = config.classifier_dropout
  474. elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
  475. classifier_dropout = config.hidden_dropout
  476. else:
  477. classifier_dropout = 0.1
  478. self.dropout = nn.Dropout(classifier_dropout)
  479. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  480. # Initialize weights and apply final processing
  481. self.post_init()
  482. @auto_docstring
  483. def forward(
  484. self,
  485. input_ids: torch.LongTensor | None = None,
  486. past_key_values: Cache | None = None,
  487. attention_mask: torch.Tensor | None = None,
  488. inputs_embeds: torch.Tensor | None = None,
  489. labels: torch.Tensor | None = None,
  490. use_cache: bool | None = None,
  491. output_attentions: bool | None = None,
  492. output_hidden_states: bool | None = None,
  493. return_dict: bool | None = None,
  494. **deprecated_arguments,
  495. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  496. r"""
  497. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  498. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  499. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  500. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  501. `input_ids`.
  502. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  503. [`PreTrainedTokenizer.__call__`] for details.
  504. [What are input IDs?](../glossary#input-ids)
  505. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  506. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  507. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  508. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  509. """
  510. return_dict = return_dict if return_dict is not None else self.config.return_dict
  511. transformer_outputs = self.transformer(
  512. input_ids,
  513. past_key_values=past_key_values,
  514. attention_mask=attention_mask,
  515. inputs_embeds=inputs_embeds,
  516. use_cache=use_cache,
  517. output_attentions=output_attentions,
  518. output_hidden_states=output_hidden_states,
  519. return_dict=return_dict,
  520. )
  521. hidden_states = transformer_outputs[0]
  522. hidden_states = self.dropout(hidden_states)
  523. logits = self.classifier(hidden_states)
  524. loss = None
  525. if labels is not None:
  526. # move labels to correct device
  527. labels = labels.to(logits.device)
  528. batch_size, seq_length = labels.shape
  529. loss_fct = CrossEntropyLoss()
  530. loss = loss_fct(
  531. logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
  532. )
  533. if not return_dict:
  534. output = (logits,) + transformer_outputs[2:]
  535. return ((loss,) + output) if loss is not None else output
  536. return TokenClassifierOutput(
  537. loss=loss,
  538. logits=logits,
  539. hidden_states=transformer_outputs.hidden_states,
  540. attentions=transformer_outputs.attentions,
  541. )
  542. @auto_docstring
  543. class MptForQuestionAnswering(MptPreTrainedModel):
  544. def __init__(self, config):
  545. super().__init__(config)
  546. self.transformer = MptModel(config)
  547. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  548. # Initialize weights and apply final processing
  549. self.post_init()
  550. @auto_docstring
  551. def forward(
  552. self,
  553. input_ids: torch.LongTensor | None = None,
  554. attention_mask: torch.FloatTensor | None = None,
  555. inputs_embeds: torch.FloatTensor | None = None,
  556. start_positions: torch.LongTensor | None = None,
  557. end_positions: torch.LongTensor | None = None,
  558. output_attentions: bool | None = None,
  559. output_hidden_states: bool | None = None,
  560. return_dict: bool | None = None,
  561. **kwargs,
  562. ) -> tuple | QuestionAnsweringModelOutput:
  563. r"""
  564. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  565. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  566. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  567. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  568. `input_ids`.
  569. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  570. [`PreTrainedTokenizer.__call__`] for details.
  571. [What are input IDs?](../glossary#input-ids)
  572. """
  573. return_dict = return_dict if return_dict is not None else self.config.return_dict
  574. outputs = self.transformer(
  575. input_ids,
  576. attention_mask=attention_mask,
  577. inputs_embeds=inputs_embeds,
  578. output_attentions=output_attentions,
  579. output_hidden_states=output_hidden_states,
  580. return_dict=return_dict,
  581. )
  582. sequence_output = outputs[0]
  583. logits = self.qa_outputs(sequence_output)
  584. start_logits, end_logits = logits.split(1, dim=-1)
  585. start_logits = start_logits.squeeze(-1).contiguous()
  586. end_logits = end_logits.squeeze(-1).contiguous()
  587. total_loss = None
  588. if start_positions is not None and end_positions is not None:
  589. # If we are on multi-GPU, split add a dimension
  590. if len(start_positions.size()) > 1:
  591. start_positions = start_positions.squeeze(-1)
  592. if len(end_positions.size()) > 1:
  593. end_positions = end_positions.squeeze(-1)
  594. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  595. ignored_index = start_logits.size(1)
  596. start_positions = start_positions.clamp(0, ignored_index)
  597. end_positions = end_positions.clamp(0, ignored_index)
  598. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  599. start_loss = loss_fct(start_logits, start_positions)
  600. end_loss = loss_fct(end_logits, end_positions)
  601. total_loss = (start_loss + end_loss) / 2
  602. if not return_dict:
  603. output = (start_logits, end_logits) + outputs[2:]
  604. return ((total_loss,) + output) if total_loss is not None else output
  605. return QuestionAnsweringModelOutput(
  606. loss=total_loss,
  607. start_logits=start_logits,
  608. end_logits=end_logits,
  609. hidden_states=outputs.hidden_states,
  610. attentions=outputs.attentions,
  611. )
  612. __all__ = [
  613. "MptForCausalLM",
  614. "MptModel",
  615. "MptPreTrainedModel",
  616. "MptForSequenceClassification",
  617. "MptForTokenClassification",
  618. "MptForQuestionAnswering",
  619. ]