modular_biogpt.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509
  1. # Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch BioGPT model."""
  15. import math
  16. import torch
  17. import torch.nn as nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache, DynamicCache
  21. from ...generation import GenerationMixin
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_outputs import (
  24. BaseModelOutputWithPastAndCrossAttentions,
  25. CausalLMOutputWithCrossAttentions,
  26. SequenceClassifierOutputWithPast,
  27. TokenClassifierOutput,
  28. )
  29. from ...modeling_utils import PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import (
  32. TransformersKwargs,
  33. auto_docstring,
  34. can_return_tuple,
  35. logger,
  36. )
  37. from ...utils.generic import merge_with_config_defaults
  38. from ...utils.output_capturing import capture_outputs
  39. from ..bart.modeling_bart import (
  40. BartAttention,
  41. BartDecoderLayer,
  42. BartScaledWordEmbedding,
  43. )
  44. from ..opt.modeling_opt import OPTLearnedPositionalEmbedding
  45. from .configuration_biogpt import BioGptConfig
  46. class BioGptLearnedPositionalEmbedding(OPTLearnedPositionalEmbedding):
  47. def forward(
  48. self,
  49. attention_mask: torch.LongTensor,
  50. past_key_values_length: int = 0,
  51. position_ids: torch.LongTensor | None = None,
  52. ):
  53. """`input_ids_shape` is expected to be [bsz x seqlen]."""
  54. return super().forward(attention_mask, past_key_values_length, position_ids)
  55. class BioGptScaledWordEmbedding(BartScaledWordEmbedding):
  56. pass
  57. class BioGptAttention(BartAttention):
  58. pass
  59. class BioGptDecoderLayer(BartDecoderLayer):
  60. def __init__(self, config: BioGptConfig, layer_idx: int | None = None):
  61. super().__init__(config)
  62. self.embed_dim = config.hidden_size
  63. self.self_attn = BioGptAttention(
  64. embed_dim=self.embed_dim,
  65. num_heads=config.num_attention_heads,
  66. dropout=config.attention_probs_dropout_prob,
  67. is_decoder=True,
  68. is_causal=True,
  69. config=config,
  70. layer_idx=layer_idx,
  71. )
  72. self.dropout = config.hidden_dropout_prob
  73. self.activation_fn = ACT2FN[config.hidden_act]
  74. self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size)
  75. self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim)
  76. del self.encoder_attn
  77. del self.encoder_attn_layer_norm
  78. def forward(
  79. self,
  80. hidden_states: torch.Tensor,
  81. attention_mask: torch.Tensor | None = None,
  82. past_key_values: Cache | None = None,
  83. use_cache: bool | None = True,
  84. position_ids: torch.LongTensor | None = None,
  85. **kwargs: Unpack[TransformersKwargs],
  86. ) -> torch.Tensor:
  87. """
  88. Args:
  89. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  90. attention_mask (`torch.FloatTensor`): attention mask of size
  91. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  92. past_key_values (`Cache`): cached past key and value projection states
  93. """
  94. residual = hidden_states
  95. hidden_states = self.self_attn_layer_norm(hidden_states)
  96. # Self Attention
  97. hidden_states, _ = self.self_attn(
  98. hidden_states=hidden_states,
  99. past_key_values=past_key_values,
  100. attention_mask=attention_mask,
  101. position_ids=position_ids,
  102. **kwargs,
  103. )
  104. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  105. hidden_states = residual + hidden_states
  106. # Fully Connected
  107. residual = hidden_states
  108. hidden_states = self.final_layer_norm(hidden_states)
  109. hidden_states = self.fc1(hidden_states)
  110. hidden_states = self.activation_fn(hidden_states)
  111. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  112. hidden_states = self.fc2(hidden_states)
  113. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  114. hidden_states = residual + hidden_states
  115. return hidden_states
  116. @auto_docstring
  117. class BioGptPreTrainedModel(PreTrainedModel):
  118. config: BioGptConfig
  119. base_model_prefix = "biogpt"
  120. supports_gradient_checkpointing = True
  121. _supports_flash_attn = True
  122. _supports_sdpa = True
  123. _supports_flex_attn = True
  124. _can_compile_fullgraph = True
  125. _can_record_outputs = {
  126. "hidden_states": BioGptDecoderLayer,
  127. "attentions": BioGptAttention,
  128. }
  129. @auto_docstring
  130. class BioGptModel(BioGptPreTrainedModel):
  131. def __init__(self, config: BioGptConfig):
  132. super().__init__(config)
  133. self.config = config
  134. self.layerdrop = config.layerdrop
  135. self.dropout = config.hidden_dropout_prob
  136. self.embed_dim = config.hidden_size
  137. self.padding_idx = config.pad_token_id
  138. embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
  139. self.embed_tokens = BioGptScaledWordEmbedding(
  140. config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
  141. )
  142. self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
  143. self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  144. self.layer_norm = nn.LayerNorm(self.embed_dim)
  145. self.gradient_checkpointing = False
  146. # Initialize weights and apply final processing
  147. self.post_init()
  148. @merge_with_config_defaults
  149. @capture_outputs
  150. @auto_docstring
  151. def forward(
  152. self,
  153. input_ids: torch.LongTensor | None = None,
  154. attention_mask: torch.FloatTensor | None = None,
  155. inputs_embeds: torch.FloatTensor | None = None,
  156. past_key_values: Cache | None = None,
  157. use_cache: bool | None = None,
  158. position_ids: torch.LongTensor | None = None,
  159. **kwargs: Unpack[TransformersKwargs],
  160. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  161. if (input_ids is None) ^ (inputs_embeds is not None):
  162. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  163. if inputs_embeds is None:
  164. inputs_embeds = self.embed_tokens(input_ids)
  165. # initialize past_key_values
  166. if use_cache and past_key_values is None:
  167. past_key_values = DynamicCache(config=self.config)
  168. batch_size, seq_length = inputs_embeds.size()[:-1]
  169. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  170. if attention_mask is None:
  171. # required mask seq length can be calculated via length of past cache
  172. mask_seq_length = past_key_values_length + seq_length
  173. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  174. self_attn_cache = past_key_values
  175. causal_mask = create_causal_mask(
  176. config=self.config,
  177. input_embeds=inputs_embeds,
  178. attention_mask=attention_mask,
  179. past_key_values=self_attn_cache,
  180. )
  181. # embed positions
  182. if position_ids is None:
  183. position_ids = torch.arange(seq_length, device=inputs_embeds.device) + past_key_values_length
  184. position_ids = position_ids.unsqueeze(0)
  185. positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
  186. hidden_states = inputs_embeds + positions
  187. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  188. for idx, decoder_layer in enumerate(self.layers):
  189. if self.training:
  190. dropout_probability = torch.rand([])
  191. if dropout_probability < self.layerdrop:
  192. continue
  193. hidden_states = decoder_layer(
  194. hidden_states,
  195. attention_mask=causal_mask,
  196. past_key_values=past_key_values,
  197. use_cache=use_cache,
  198. position_ids=position_ids,
  199. **kwargs,
  200. )
  201. hidden_states = self.layer_norm(hidden_states)
  202. return BaseModelOutputWithPastAndCrossAttentions(
  203. last_hidden_state=hidden_states,
  204. past_key_values=past_key_values,
  205. )
  206. @auto_docstring(
  207. custom_intro="""
  208. BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
  209. """
  210. )
  211. class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
  212. _tied_weights_keys = {"output_projection.weight": "biogpt.embed_tokens.weight"}
  213. def __init__(self, config):
  214. super().__init__(config)
  215. self.biogpt = BioGptModel(config)
  216. self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  217. # Initialize weights and apply final processing
  218. self.post_init()
  219. def get_output_embeddings(self):
  220. return self.output_projection
  221. def set_output_embeddings(self, new_embeddings):
  222. self.output_projection = new_embeddings
  223. @can_return_tuple
  224. @auto_docstring
  225. def forward(
  226. self,
  227. input_ids: torch.LongTensor | None = None,
  228. attention_mask: torch.FloatTensor | None = None,
  229. inputs_embeds: torch.FloatTensor | None = None,
  230. past_key_values: Cache | None = None,
  231. labels: torch.LongTensor | None = None,
  232. use_cache: bool | None = None,
  233. position_ids: torch.LongTensor | None = None,
  234. logits_to_keep: int | torch.Tensor = 0,
  235. **kwargs: Unpack[TransformersKwargs],
  236. ) -> tuple | CausalLMOutputWithCrossAttentions:
  237. r"""
  238. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  239. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  240. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  241. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  242. """
  243. outputs = self.biogpt(
  244. input_ids,
  245. attention_mask=attention_mask,
  246. inputs_embeds=inputs_embeds,
  247. past_key_values=past_key_values,
  248. use_cache=use_cache,
  249. position_ids=position_ids,
  250. **kwargs,
  251. )
  252. hidden_states = outputs[0]
  253. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  254. logits = self.output_projection(hidden_states[:, slice_indices, :])
  255. loss = None
  256. if labels is not None:
  257. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  258. return CausalLMOutputWithCrossAttentions(
  259. loss=loss,
  260. logits=logits,
  261. past_key_values=outputs.past_key_values,
  262. hidden_states=outputs.hidden_states,
  263. attentions=outputs.attentions,
  264. cross_attentions=outputs.cross_attentions,
  265. )
  266. @auto_docstring
  267. class BioGptForTokenClassification(BioGptPreTrainedModel):
  268. def __init__(self, config):
  269. super().__init__(config)
  270. self.num_labels = config.num_labels
  271. self.biogpt = BioGptModel(config)
  272. if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
  273. classifier_dropout = config.classifier_dropout
  274. else:
  275. classifier_dropout = config.hidden_dropout_prob
  276. self.dropout = nn.Dropout(classifier_dropout)
  277. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  278. self.post_init()
  279. @can_return_tuple
  280. @auto_docstring
  281. def forward(
  282. self,
  283. input_ids: torch.LongTensor | None = None,
  284. token_type_ids: torch.LongTensor | None = None,
  285. attention_mask: torch.FloatTensor | None = None,
  286. past_key_values: Cache | None = None,
  287. inputs_embeds: torch.FloatTensor | None = None,
  288. labels: torch.LongTensor | None = None,
  289. use_cache: bool | None = None,
  290. position_ids: torch.LongTensor | None = None,
  291. **kwargs,
  292. ) -> tuple | TokenClassifierOutput:
  293. r"""
  294. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  295. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  296. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  297. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  298. """
  299. transformer_outputs = self.biogpt(
  300. input_ids,
  301. past_key_values=past_key_values,
  302. attention_mask=attention_mask,
  303. inputs_embeds=inputs_embeds,
  304. use_cache=use_cache,
  305. position_ids=position_ids,
  306. **kwargs,
  307. )
  308. hidden_states = transformer_outputs[0]
  309. hidden_states = self.dropout(hidden_states)
  310. logits = self.classifier(hidden_states)
  311. loss = None
  312. if labels is not None:
  313. loss_fct = CrossEntropyLoss()
  314. if attention_mask is not None:
  315. active_loss = attention_mask.view(-1) == 1
  316. active_logits = logits.view(-1, self.num_labels)
  317. active_labels = torch.where(
  318. active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
  319. )
  320. loss = loss_fct(active_logits, active_labels)
  321. else:
  322. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  323. return TokenClassifierOutput(
  324. loss=loss,
  325. logits=logits,
  326. hidden_states=transformer_outputs.hidden_states,
  327. attentions=transformer_outputs.attentions,
  328. )
  329. @auto_docstring(
  330. custom_intro="""
  331. The BioGpt Model transformer with a sequence classification head on top (linear layer).
  332. [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  333. (e.g. GPT-2) do.
  334. Since it does classification on the last token, it is required to know the position of the last token. If a
  335. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  336. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  337. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  338. each row of the batch).
  339. """
  340. )
  341. class BioGptForSequenceClassification(BioGptPreTrainedModel):
  342. def __init__(self, config: BioGptConfig):
  343. super().__init__(config)
  344. self.num_labels = config.num_labels
  345. self.biogpt = BioGptModel(config)
  346. self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  347. # Initialize weights and apply final processing
  348. self.post_init()
  349. @can_return_tuple
  350. @auto_docstring
  351. def forward(
  352. self,
  353. input_ids: torch.LongTensor | None = None,
  354. attention_mask: torch.FloatTensor | None = None,
  355. past_key_values: Cache | None = None,
  356. inputs_embeds: torch.FloatTensor | None = None,
  357. labels: torch.LongTensor | None = None,
  358. use_cache: bool | None = None,
  359. position_ids: torch.LongTensor | None = None,
  360. logits_to_keep: int | torch.Tensor = 0,
  361. **kwargs,
  362. ) -> tuple | SequenceClassifierOutputWithPast:
  363. r"""
  364. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  365. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  366. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  367. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  368. """
  369. transformer_outputs = self.biogpt(
  370. input_ids,
  371. past_key_values=past_key_values,
  372. attention_mask=attention_mask,
  373. inputs_embeds=inputs_embeds,
  374. use_cache=use_cache,
  375. position_ids=position_ids,
  376. **kwargs,
  377. )
  378. hidden_states = transformer_outputs[0]
  379. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  380. logits = self.score(hidden_states[:, slice_indices, :])
  381. if input_ids is not None:
  382. batch_size, sequence_length = input_ids.shape[:2]
  383. else:
  384. batch_size, sequence_length = inputs_embeds.shape[:2]
  385. if self.config.pad_token_id is None:
  386. sequence_length = -1
  387. else:
  388. if input_ids is not None:
  389. sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
  390. else:
  391. sequence_length = -1
  392. logger.warning_once(
  393. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  394. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  395. )
  396. pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]
  397. loss = None
  398. if labels is not None:
  399. if self.config.problem_type is None:
  400. if self.num_labels == 1:
  401. self.config.problem_type = "regression"
  402. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  403. self.config.problem_type = "single_label_classification"
  404. else:
  405. self.config.problem_type = "multi_label_classification"
  406. if self.config.problem_type == "regression":
  407. loss_fct = MSELoss()
  408. if self.num_labels == 1:
  409. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  410. else:
  411. loss = loss_fct(pooled_logits, labels)
  412. elif self.config.problem_type == "single_label_classification":
  413. loss_fct = CrossEntropyLoss()
  414. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  415. elif self.config.problem_type == "multi_label_classification":
  416. loss_fct = BCEWithLogitsLoss()
  417. loss = loss_fct(pooled_logits, labels)
  418. return SequenceClassifierOutputWithPast(
  419. loss=loss,
  420. logits=pooled_logits,
  421. past_key_values=transformer_outputs.past_key_values,
  422. hidden_states=transformer_outputs.hidden_states,
  423. attentions=transformer_outputs.attentions,
  424. )
  425. def get_input_embeddings(self):
  426. return self.biogpt.embed_tokens
  427. def set_input_embeddings(self, value):
  428. self.biogpt.embed_tokens = value
  429. __all__ = [
  430. "BioGptForCausalLM",
  431. "BioGptForTokenClassification",
  432. "BioGptForSequenceClassification",
  433. "BioGptModel",
  434. "BioGptPreTrainedModel",
  435. ]