modeling_biogpt.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/biogpt/modular_biogpt.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_biogpt.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. import torch
  23. import torch.nn as nn
  24. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  27. from ...generation import GenerationMixin
  28. from ...masking_utils import create_causal_mask
  29. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import (
  32. BaseModelOutputWithPastAndCrossAttentions,
  33. CausalLMOutputWithCrossAttentions,
  34. SequenceClassifierOutputWithPast,
  35. TokenClassifierOutput,
  36. )
  37. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  38. from ...processing_utils import Unpack
  39. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  40. from ...utils.generic import merge_with_config_defaults
  41. from ...utils.output_capturing import capture_outputs
  42. from .configuration_biogpt import BioGptConfig
# Module-level logger, keyed on this module's `__name__`; used below for one-time warnings.
logger = logging.get_logger(__name__)
  44. class BioGptLearnedPositionalEmbedding(nn.Embedding):
  45. """
  46. This module learns positional embeddings up to a fixed maximum size.
  47. """
  48. def __init__(self, num_embeddings: int, embedding_dim: int):
  49. # BIOGPT is set up so that if padding_idx is specified then offset the embedding ids by 2
  50. # and adjust num_embeddings appropriately. Other models don't have this hack
  51. self.offset = 2
  52. super().__init__(num_embeddings + self.offset, embedding_dim)
  53. def forward(
  54. self,
  55. attention_mask: torch.LongTensor,
  56. past_key_values_length: int = 0,
  57. position_ids: torch.LongTensor | None = None,
  58. ):
  59. """`input_ids_shape` is expected to be [bsz x seqlen]."""
  60. if position_ids is None:
  61. position_ids = torch.cumsum(attention_mask, dim=1)
  62. position_ids = (position_ids * attention_mask - 1).long()
  63. # cut positions if `past_key_values_length` is > 0
  64. position_ids = position_ids[:, past_key_values_length:]
  65. return super().forward(position_ids + self.offset)
  66. class BioGptScaledWordEmbedding(nn.Embedding):
  67. """
  68. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  69. """
  70. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float | None = 1.0):
  71. super().__init__(num_embeddings, embedding_dim, padding_idx)
  72. self.embed_scale = embed_scale
  73. def forward(self, input_ids: torch.Tensor):
  74. return super().forward(input_ids) * self.embed_scale
  75. def eager_attention_forward(
  76. module: nn.Module,
  77. query: torch.Tensor,
  78. key: torch.Tensor,
  79. value: torch.Tensor,
  80. attention_mask: torch.Tensor | None,
  81. scaling: float | None = None,
  82. dropout: float = 0.0,
  83. **kwargs: Unpack[TransformersKwargs],
  84. ):
  85. if scaling is None:
  86. scaling = query.size(-1) ** -0.5
  87. # Take the dot product between "query" and "key" to get the raw attention scores.
  88. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  89. if attention_mask is not None:
  90. attn_weights = attn_weights + attention_mask
  91. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  92. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  93. attn_output = torch.matmul(attn_weights, value)
  94. attn_output = attn_output.transpose(1, 2).contiguous()
  95. return attn_output, attn_weights
  96. class BioGptAttention(nn.Module):
  97. """Multi-headed attention from 'Attention Is All You Need' paper"""
  98. def __init__(
  99. self,
  100. embed_dim: int,
  101. num_heads: int,
  102. dropout: float = 0.0,
  103. is_decoder: bool = False,
  104. bias: bool = True,
  105. is_causal: bool = False,
  106. config: BioGptConfig | None = None,
  107. layer_idx: int | None = None,
  108. ):
  109. super().__init__()
  110. self.embed_dim = embed_dim
  111. self.num_heads = num_heads
  112. self.dropout = dropout
  113. self.head_dim = embed_dim // num_heads
  114. self.config = config
  115. if (self.head_dim * num_heads) != self.embed_dim:
  116. raise ValueError(
  117. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  118. f" and `num_heads`: {num_heads})."
  119. )
  120. self.scaling = self.head_dim**-0.5
  121. self.is_decoder = is_decoder
  122. self.is_causal = is_causal
  123. self.layer_idx = layer_idx
  124. if layer_idx is None and self.is_decoder:
  125. logger.warning_once(
  126. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  127. "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  128. "when creating this class."
  129. )
  130. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  131. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  132. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  133. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  134. def forward(
  135. self,
  136. hidden_states: torch.Tensor,
  137. key_value_states: torch.Tensor | None = None,
  138. past_key_values: Cache | None = None,
  139. attention_mask: torch.Tensor | None = None,
  140. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  141. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  142. **kwargs: Unpack[FlashAttentionKwargs],
  143. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  144. """Input shape: Batch x Time x Channel"""
  145. # if key_value_states are provided this layer is used as a cross-attention layer
  146. # for the decoder
  147. is_cross_attention = key_value_states is not None
  148. # determine input shapes
  149. input_shape = hidden_states.shape[:-1]
  150. hidden_shape = (*input_shape, -1, self.head_dim)
  151. # get query proj
  152. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  153. is_updated = False
  154. if past_key_values is not None:
  155. if isinstance(past_key_values, EncoderDecoderCache):
  156. is_updated = past_key_values.is_updated.get(self.layer_idx)
  157. if is_cross_attention:
  158. # after the first generated id, we can subsequently re-use all key/value_states from cache
  159. curr_past_key_values = past_key_values.cross_attention_cache
  160. else:
  161. curr_past_key_values = past_key_values.self_attention_cache
  162. else:
  163. curr_past_key_values = past_key_values
  164. current_states = key_value_states if is_cross_attention else hidden_states
  165. if is_cross_attention and past_key_values is not None and is_updated:
  166. # reuse k,v, cross_attentions
  167. key_states = curr_past_key_values.layers[self.layer_idx].keys
  168. value_states = curr_past_key_values.layers[self.layer_idx].values
  169. else:
  170. key_states = self.k_proj(current_states)
  171. value_states = self.v_proj(current_states)
  172. kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
  173. key_states = key_states.view(kv_shape).transpose(1, 2)
  174. value_states = value_states.view(kv_shape).transpose(1, 2)
  175. if past_key_values is not None:
  176. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  177. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  178. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  179. past_key_values.is_updated[self.layer_idx] = True
  180. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  181. self.config._attn_implementation, eager_attention_forward
  182. )
  183. attn_output, attn_weights = attention_interface(
  184. self,
  185. query_states,
  186. key_states,
  187. value_states,
  188. attention_mask,
  189. dropout=0.0 if not self.training else self.dropout,
  190. scaling=self.scaling,
  191. **kwargs,
  192. )
  193. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  194. attn_output = self.out_proj(attn_output)
  195. return attn_output, attn_weights
  196. class BioGptDecoderLayer(GradientCheckpointingLayer):
  197. def __init__(self, config: BioGptConfig, layer_idx: int | None = None):
  198. super().__init__()
  199. self.embed_dim = config.hidden_size
  200. self.self_attn = BioGptAttention(
  201. embed_dim=self.embed_dim,
  202. num_heads=config.num_attention_heads,
  203. dropout=config.attention_probs_dropout_prob,
  204. is_decoder=True,
  205. is_causal=True,
  206. config=config,
  207. layer_idx=layer_idx,
  208. )
  209. self.dropout = config.hidden_dropout_prob
  210. self.activation_fn = ACT2FN[config.hidden_act]
  211. self.activation_dropout = config.activation_dropout
  212. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  213. self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size)
  214. self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim)
  215. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  216. def forward(
  217. self,
  218. hidden_states: torch.Tensor,
  219. attention_mask: torch.Tensor | None = None,
  220. past_key_values: Cache | None = None,
  221. use_cache: bool | None = True,
  222. position_ids: torch.LongTensor | None = None,
  223. **kwargs: Unpack[TransformersKwargs],
  224. ) -> torch.Tensor:
  225. """
  226. Args:
  227. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  228. attention_mask (`torch.FloatTensor`): attention mask of size
  229. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  230. past_key_values (`Cache`): cached past key and value projection states
  231. """
  232. residual = hidden_states
  233. hidden_states = self.self_attn_layer_norm(hidden_states)
  234. # Self Attention
  235. hidden_states, _ = self.self_attn(
  236. hidden_states=hidden_states,
  237. past_key_values=past_key_values,
  238. attention_mask=attention_mask,
  239. position_ids=position_ids,
  240. **kwargs,
  241. )
  242. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  243. hidden_states = residual + hidden_states
  244. # Fully Connected
  245. residual = hidden_states
  246. hidden_states = self.final_layer_norm(hidden_states)
  247. hidden_states = self.fc1(hidden_states)
  248. hidden_states = self.activation_fn(hidden_states)
  249. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  250. hidden_states = self.fc2(hidden_states)
  251. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  252. hidden_states = residual + hidden_states
  253. return hidden_states
@auto_docstring
class BioGptPreTrainedModel(PreTrainedModel):
    # Shared base class for every BioGPT model in this file. It declares class-level settings only;
    # all behaviour comes from `PreTrainedModel`.
    config: BioGptConfig
    # Matches the attribute name (`self.biogpt`) under which the head models below store the base model.
    base_model_prefix = "biogpt"
    supports_gradient_checkpointing = True
    # Capability flags; presumably read by `PreTrainedModel` to decide which attention backends and
    # compilation modes may be enabled -- confirm against the base class.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    # Maps an output field name to the module class whose outputs populate it
    # (presumably consumed by the `capture_outputs` decorator -- TODO confirm).
    _can_record_outputs = {
        "hidden_states": BioGptDecoderLayer,
        "attentions": BioGptAttention,
    }
  267. @auto_docstring
  268. class BioGptModel(BioGptPreTrainedModel):
  269. def __init__(self, config: BioGptConfig):
  270. super().__init__(config)
  271. self.config = config
  272. self.layerdrop = config.layerdrop
  273. self.dropout = config.hidden_dropout_prob
  274. self.embed_dim = config.hidden_size
  275. self.padding_idx = config.pad_token_id
  276. embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
  277. self.embed_tokens = BioGptScaledWordEmbedding(
  278. config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
  279. )
  280. self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
  281. self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  282. self.layer_norm = nn.LayerNorm(self.embed_dim)
  283. self.gradient_checkpointing = False
  284. # Initialize weights and apply final processing
  285. self.post_init()
  286. @merge_with_config_defaults
  287. @capture_outputs
  288. @auto_docstring
  289. def forward(
  290. self,
  291. input_ids: torch.LongTensor | None = None,
  292. attention_mask: torch.FloatTensor | None = None,
  293. inputs_embeds: torch.FloatTensor | None = None,
  294. past_key_values: Cache | None = None,
  295. use_cache: bool | None = None,
  296. position_ids: torch.LongTensor | None = None,
  297. **kwargs: Unpack[TransformersKwargs],
  298. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  299. if (input_ids is None) ^ (inputs_embeds is not None):
  300. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  301. if inputs_embeds is None:
  302. inputs_embeds = self.embed_tokens(input_ids)
  303. # initialize past_key_values
  304. if use_cache and past_key_values is None:
  305. past_key_values = DynamicCache(config=self.config)
  306. batch_size, seq_length = inputs_embeds.size()[:-1]
  307. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  308. if attention_mask is None:
  309. # required mask seq length can be calculated via length of past cache
  310. mask_seq_length = past_key_values_length + seq_length
  311. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  312. self_attn_cache = past_key_values
  313. causal_mask = create_causal_mask(
  314. config=self.config,
  315. input_embeds=inputs_embeds,
  316. attention_mask=attention_mask,
  317. past_key_values=self_attn_cache,
  318. )
  319. # embed positions
  320. if position_ids is None:
  321. position_ids = torch.arange(seq_length, device=inputs_embeds.device) + past_key_values_length
  322. position_ids = position_ids.unsqueeze(0)
  323. positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
  324. hidden_states = inputs_embeds + positions
  325. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  326. for idx, decoder_layer in enumerate(self.layers):
  327. if self.training:
  328. dropout_probability = torch.rand([])
  329. if dropout_probability < self.layerdrop:
  330. continue
  331. hidden_states = decoder_layer(
  332. hidden_states,
  333. attention_mask=causal_mask,
  334. past_key_values=past_key_values,
  335. use_cache=use_cache,
  336. position_ids=position_ids,
  337. **kwargs,
  338. )
  339. hidden_states = self.layer_norm(hidden_states)
  340. return BaseModelOutputWithPastAndCrossAttentions(
  341. last_hidden_state=hidden_states,
  342. past_key_values=past_key_values,
  343. )
  344. @auto_docstring(
  345. custom_intro="""
  346. BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
  347. """
  348. )
  349. class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
  350. _tied_weights_keys = {"output_projection.weight": "biogpt.embed_tokens.weight"}
  351. def __init__(self, config):
  352. super().__init__(config)
  353. self.biogpt = BioGptModel(config)
  354. self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  355. # Initialize weights and apply final processing
  356. self.post_init()
  357. def get_output_embeddings(self):
  358. return self.output_projection
  359. def set_output_embeddings(self, new_embeddings):
  360. self.output_projection = new_embeddings
  361. @can_return_tuple
  362. @auto_docstring
  363. def forward(
  364. self,
  365. input_ids: torch.LongTensor | None = None,
  366. attention_mask: torch.FloatTensor | None = None,
  367. inputs_embeds: torch.FloatTensor | None = None,
  368. past_key_values: Cache | None = None,
  369. labels: torch.LongTensor | None = None,
  370. use_cache: bool | None = None,
  371. position_ids: torch.LongTensor | None = None,
  372. logits_to_keep: int | torch.Tensor = 0,
  373. **kwargs: Unpack[TransformersKwargs],
  374. ) -> tuple | CausalLMOutputWithCrossAttentions:
  375. r"""
  376. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  377. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  378. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  379. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  380. """
  381. outputs = self.biogpt(
  382. input_ids,
  383. attention_mask=attention_mask,
  384. inputs_embeds=inputs_embeds,
  385. past_key_values=past_key_values,
  386. use_cache=use_cache,
  387. position_ids=position_ids,
  388. **kwargs,
  389. )
  390. hidden_states = outputs[0]
  391. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  392. logits = self.output_projection(hidden_states[:, slice_indices, :])
  393. loss = None
  394. if labels is not None:
  395. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  396. return CausalLMOutputWithCrossAttentions(
  397. loss=loss,
  398. logits=logits,
  399. past_key_values=outputs.past_key_values,
  400. hidden_states=outputs.hidden_states,
  401. attentions=outputs.attentions,
  402. cross_attentions=outputs.cross_attentions,
  403. )
  404. @auto_docstring
  405. class BioGptForTokenClassification(BioGptPreTrainedModel):
  406. def __init__(self, config):
  407. super().__init__(config)
  408. self.num_labels = config.num_labels
  409. self.biogpt = BioGptModel(config)
  410. if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
  411. classifier_dropout = config.classifier_dropout
  412. else:
  413. classifier_dropout = config.hidden_dropout_prob
  414. self.dropout = nn.Dropout(classifier_dropout)
  415. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  416. self.post_init()
  417. @can_return_tuple
  418. @auto_docstring
  419. def forward(
  420. self,
  421. input_ids: torch.LongTensor | None = None,
  422. token_type_ids: torch.LongTensor | None = None,
  423. attention_mask: torch.FloatTensor | None = None,
  424. past_key_values: Cache | None = None,
  425. inputs_embeds: torch.FloatTensor | None = None,
  426. labels: torch.LongTensor | None = None,
  427. use_cache: bool | None = None,
  428. position_ids: torch.LongTensor | None = None,
  429. **kwargs,
  430. ) -> tuple | TokenClassifierOutput:
  431. r"""
  432. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  433. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  434. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  435. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  436. """
  437. transformer_outputs = self.biogpt(
  438. input_ids,
  439. past_key_values=past_key_values,
  440. attention_mask=attention_mask,
  441. inputs_embeds=inputs_embeds,
  442. use_cache=use_cache,
  443. position_ids=position_ids,
  444. **kwargs,
  445. )
  446. hidden_states = transformer_outputs[0]
  447. hidden_states = self.dropout(hidden_states)
  448. logits = self.classifier(hidden_states)
  449. loss = None
  450. if labels is not None:
  451. loss_fct = CrossEntropyLoss()
  452. if attention_mask is not None:
  453. active_loss = attention_mask.view(-1) == 1
  454. active_logits = logits.view(-1, self.num_labels)
  455. active_labels = torch.where(
  456. active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
  457. )
  458. loss = loss_fct(active_logits, active_labels)
  459. else:
  460. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  461. return TokenClassifierOutput(
  462. loss=loss,
  463. logits=logits,
  464. hidden_states=transformer_outputs.hidden_states,
  465. attentions=transformer_outputs.attentions,
  466. )
  467. @auto_docstring(
  468. custom_intro="""
  469. The BioGpt Model transformer with a sequence classification head on top (linear layer).
  470. [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  471. (e.g. GPT-2) do.
  472. Since it does classification on the last token, it is required to know the position of the last token. If a
  473. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  474. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  475. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  476. each row of the batch).
  477. """
  478. )
  479. class BioGptForSequenceClassification(BioGptPreTrainedModel):
  480. def __init__(self, config: BioGptConfig):
  481. super().__init__(config)
  482. self.num_labels = config.num_labels
  483. self.biogpt = BioGptModel(config)
  484. self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  485. # Initialize weights and apply final processing
  486. self.post_init()
  487. @can_return_tuple
  488. @auto_docstring
  489. def forward(
  490. self,
  491. input_ids: torch.LongTensor | None = None,
  492. attention_mask: torch.FloatTensor | None = None,
  493. past_key_values: Cache | None = None,
  494. inputs_embeds: torch.FloatTensor | None = None,
  495. labels: torch.LongTensor | None = None,
  496. use_cache: bool | None = None,
  497. position_ids: torch.LongTensor | None = None,
  498. logits_to_keep: int | torch.Tensor = 0,
  499. **kwargs,
  500. ) -> tuple | SequenceClassifierOutputWithPast:
  501. r"""
  502. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  503. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  504. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  505. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  506. """
  507. transformer_outputs = self.biogpt(
  508. input_ids,
  509. past_key_values=past_key_values,
  510. attention_mask=attention_mask,
  511. inputs_embeds=inputs_embeds,
  512. use_cache=use_cache,
  513. position_ids=position_ids,
  514. **kwargs,
  515. )
  516. hidden_states = transformer_outputs[0]
  517. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  518. logits = self.score(hidden_states[:, slice_indices, :])
  519. if input_ids is not None:
  520. batch_size, sequence_length = input_ids.shape[:2]
  521. else:
  522. batch_size, sequence_length = inputs_embeds.shape[:2]
  523. if self.config.pad_token_id is None:
  524. sequence_length = -1
  525. else:
  526. if input_ids is not None:
  527. sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
  528. else:
  529. sequence_length = -1
  530. logger.warning_once(
  531. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  532. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  533. )
  534. pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]
  535. loss = None
  536. if labels is not None:
  537. if self.config.problem_type is None:
  538. if self.num_labels == 1:
  539. self.config.problem_type = "regression"
  540. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  541. self.config.problem_type = "single_label_classification"
  542. else:
  543. self.config.problem_type = "multi_label_classification"
  544. if self.config.problem_type == "regression":
  545. loss_fct = MSELoss()
  546. if self.num_labels == 1:
  547. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  548. else:
  549. loss = loss_fct(pooled_logits, labels)
  550. elif self.config.problem_type == "single_label_classification":
  551. loss_fct = CrossEntropyLoss()
  552. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  553. elif self.config.problem_type == "multi_label_classification":
  554. loss_fct = BCEWithLogitsLoss()
  555. loss = loss_fct(pooled_logits, labels)
  556. return SequenceClassifierOutputWithPast(
  557. loss=loss,
  558. logits=pooled_logits,
  559. past_key_values=transformer_outputs.past_key_values,
  560. hidden_states=transformer_outputs.hidden_states,
  561. attentions=transformer_outputs.attentions,
  562. )
  563. def get_input_embeddings(self):
  564. return self.biogpt.embed_tokens
  565. def set_input_embeddings(self, value):
  566. self.biogpt.embed_tokens = value
# Names exported by `from <module> import *`; helper modules (embeddings, attention, decoder layer)
# are intentionally not listed.
__all__ = [
    "BioGptForCausalLM",
    "BioGptForTokenClassification",
    "BioGptForSequenceClassification",
    "BioGptModel",
    "BioGptPreTrainedModel",
]