modeling_esm.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993
  1. # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
  2. # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ESM model."""
  16. import math
  17. from collections.abc import Callable
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ... import initialization as init
  22. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import (
  25. BaseModelOutputWithCrossAttentions,
  26. BaseModelOutputWithPoolingAndCrossAttentions,
  27. MaskedLMOutput,
  28. SequenceClassifierOutput,
  29. TokenClassifierOutput,
  30. )
  31. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  32. from ...processing_utils import Unpack
  33. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  34. from ...utils.generic import merge_with_config_defaults
  35. from ...utils.output_capturing import OutputRecorder, capture_outputs
  36. from .configuration_esm import EsmConfig
# Module-level logger from the package's `logging` utilities (used e.g. by EsmForMaskedLM for config warnings).
logger = logging.get_logger(__name__)
  38. def rotate_half(x):
  39. x1, x2 = x.chunk(2, dim=-1)
  40. return torch.cat((-x2, x1), dim=-1)
  41. def apply_rotary_pos_emb(x, cos, sin):
  42. cos = cos[:, :, : x.shape[-2], :]
  43. sin = sin[:, :, : x.shape[-2], :]
  44. return (x * cos) + (rotate_half(x) * sin)
  45. def gelu(x):
  46. """
  47. This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results.
  48. """
  49. return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
  50. def symmetrize(x):
  51. "Make layer symmetric in final two dimensions, used for contact prediction."
  52. return x + x.transpose(-1, -2)
  53. def average_product_correct(x):
  54. "Perform average product correct, used for contact prediction."
  55. a1 = x.sum(-1, keepdims=True)
  56. a2 = x.sum(-2, keepdims=True)
  57. a12 = x.sum((-1, -2), keepdims=True)
  58. avg = a1 * a2
  59. avg.div_(a12) # in-place to reduce memory
  60. normalized = x - avg
  61. return normalized
  62. class RotaryEmbedding(torch.nn.Module):
  63. """
  64. Rotary position embeddings based on those in
  65. [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
  66. matrices which depend on their relative positions.
  67. """
  68. inv_freq: torch.Tensor # fix linting for `register_buffer`
  69. def __init__(self, dim: int):
  70. super().__init__()
  71. self.dim = dim
  72. # Generate and save the inverse frequency buffer (non trainable)
  73. inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
  74. self.register_buffer("inv_freq", inv_freq)
  75. self._seq_len_cached = None
  76. self._cos_cached = None
  77. self._sin_cached = None
  78. def _update_cos_sin_tables(self, x, seq_dimension=2):
  79. seq_len = x.shape[seq_dimension]
  80. # Reset the tables if the sequence length has changed,
  81. # or if we're on a new device (possibly due to tracing for instance)
  82. if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
  83. self._seq_len_cached = seq_len
  84. t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
  85. freqs = torch.outer(t, self.inv_freq)
  86. emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
  87. self._cos_cached = emb.cos()[None, None, :, :]
  88. self._sin_cached = emb.sin()[None, None, :, :]
  89. return self._cos_cached, self._sin_cached
  90. def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  91. self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
  92. return (
  93. apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
  94. apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
  95. )
  96. class EsmContactPredictionHead(nn.Module):
  97. """Performs symmetrization, apc, and computes a logistic regression on the output features"""
  98. def __init__(
  99. self,
  100. in_features: int,
  101. bias=True,
  102. eos_idx: int = 2,
  103. ):
  104. super().__init__()
  105. self.in_features = in_features
  106. self.eos_idx = eos_idx
  107. self.regression = nn.Linear(in_features, 1, bias)
  108. self.activation = nn.Sigmoid()
  109. def forward(self, tokens, attentions):
  110. # remove eos token attentions
  111. eos_mask = tokens.ne(self.eos_idx).to(attentions)
  112. eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
  113. attentions = attentions * eos_mask[:, None, None, :, :]
  114. attentions = attentions[..., :-1, :-1]
  115. # remove cls token attentions
  116. attentions = attentions[..., 1:, 1:]
  117. batch_size, layers, heads, seqlen, _ = attentions.size()
  118. attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
  119. # features: batch x channels x tokens x tokens (symmetric)
  120. attentions = attentions.to(
  121. self.regression.weight.device
  122. ) # attentions always float32, may need to convert to float16
  123. attentions = average_product_correct(symmetrize(attentions))
  124. attentions = attentions.permute(0, 2, 3, 1)
  125. return self.activation(self.regression(attentions).squeeze(3))
  126. class EsmEmbeddings(nn.Module):
  127. """
  128. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  129. """
  130. def __init__(self, config):
  131. super().__init__()
  132. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  133. if config.emb_layer_norm_before:
  134. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  135. else:
  136. self.layer_norm = None
  137. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  138. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  139. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  140. self.register_buffer(
  141. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  142. )
  143. self.padding_idx = config.pad_token_id
  144. if self.position_embedding_type == "absolute":
  145. self.position_embeddings = nn.Embedding(
  146. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  147. )
  148. self.token_dropout = config.token_dropout
  149. self.mask_token_id = config.mask_token_id
  150. def forward(
  151. self,
  152. input_ids=None,
  153. attention_mask=None,
  154. position_ids=None,
  155. inputs_embeds=None,
  156. ):
  157. if position_ids is None:
  158. if input_ids is not None:
  159. # Create the position ids from the input token ids. Any padded tokens remain padded.
  160. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
  161. else:
  162. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  163. if inputs_embeds is None:
  164. inputs_embeds = self.word_embeddings(input_ids)
  165. # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
  166. # embedding_scale factor here.
  167. embeddings = inputs_embeds
  168. # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
  169. # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however,
  170. # masked tokens are treated as if they were selected for input dropout and zeroed out.
  171. # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
  172. # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
  173. # This is analogous to the way that dropout layers scale down outputs during evaluation when not
  174. # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
  175. if self.token_dropout and input_ids is not None:
  176. embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
  177. mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs
  178. src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
  179. mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
  180. embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
  181. embeddings.dtype
  182. )
  183. if self.position_embedding_type == "absolute":
  184. position_embeddings = self.position_embeddings(position_ids)
  185. embeddings = embeddings + position_embeddings
  186. if self.layer_norm is not None:
  187. embeddings = self.layer_norm(embeddings)
  188. if attention_mask is not None:
  189. embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
  190. # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
  191. # embeddings = self.dropout(embeddings)
  192. return embeddings
  193. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  194. """
  195. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  196. Args:
  197. inputs_embeds: torch.Tensor
  198. Returns: torch.Tensor
  199. """
  200. input_shape = inputs_embeds.size()[:-1]
  201. sequence_length = input_shape[1]
  202. position_ids = torch.arange(
  203. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  204. )
  205. return position_ids.unsqueeze(0).expand(input_shape)
  206. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  207. def eager_attention_forward(
  208. module: nn.Module,
  209. query: torch.Tensor,
  210. key: torch.Tensor,
  211. value: torch.Tensor,
  212. attention_mask: torch.Tensor | None,
  213. scaling: float | None = None,
  214. dropout: float = 0.0,
  215. **kwargs: Unpack[TransformersKwargs],
  216. ):
  217. if scaling is None:
  218. scaling = query.size(-1) ** -0.5
  219. # Take the dot product between "query" and "key" to get the raw attention scores.
  220. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  221. if attention_mask is not None:
  222. attn_weights = attn_weights + attention_mask
  223. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  224. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  225. attn_output = torch.matmul(attn_weights, value)
  226. attn_output = attn_output.transpose(1, 2).contiguous()
  227. return attn_output, attn_weights
  228. class EsmSelfAttention(nn.Module):
  229. def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False):
  230. super().__init__()
  231. self.config = config
  232. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  233. raise ValueError(
  234. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  235. f"heads ({config.num_attention_heads})"
  236. )
  237. self.num_attention_heads = config.num_attention_heads
  238. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  239. self.all_head_size = self.num_attention_heads * self.attention_head_size
  240. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  241. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  242. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  243. self.dropout = config.attention_probs_dropout_prob
  244. self.rotary_embeddings = None
  245. self.position_embedding_type = position_embedding_type or getattr(
  246. config, "position_embedding_type", "absolute"
  247. )
  248. if self.position_embedding_type == "rotary":
  249. self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
  250. self.scaling = 1.0 # For BC we apply scaling before RoPE
  251. self.is_decoder = config.is_decoder
  252. self.layer_idx = layer_idx
  253. self.is_causal = self.is_decoder and not is_cross_attention
  254. def forward(
  255. self,
  256. hidden_states: torch.Tensor,
  257. attention_mask: torch.FloatTensor | None = None,
  258. encoder_hidden_states: torch.FloatTensor | None = None,
  259. encoder_attention_mask: torch.FloatTensor | None = None,
  260. **kwargs: Unpack[TransformersKwargs],
  261. ) -> tuple[torch.Tensor]:
  262. input_shape = hidden_states.shape[:-1]
  263. hidden_shape = (*input_shape, -1, self.attention_head_size)
  264. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  265. is_cross_attention = encoder_hidden_states is not None
  266. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  267. attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
  268. key_layer = self.key(current_states).view(hidden_shape).transpose(1, 2)
  269. value_layer = self.value(current_states).view(hidden_shape).transpose(1, 2)
  270. # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
  271. # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
  272. # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
  273. # ESM code and fix rotary embeddings.
  274. query_layer = query_layer * self.attention_head_size**-0.5
  275. if self.position_embedding_type == "rotary":
  276. query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
  277. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  278. self.config._attn_implementation, eager_attention_forward
  279. )
  280. attn_output, attn_weights = attention_interface(
  281. self,
  282. query_layer,
  283. key_layer,
  284. value_layer,
  285. attention_mask,
  286. dropout=0.0 if not self.training else self.dropout,
  287. scaling=self.scaling,
  288. **kwargs,
  289. )
  290. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  291. return attn_output, attn_weights
  292. class EsmSelfOutput(nn.Module):
  293. def __init__(self, config):
  294. super().__init__()
  295. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  296. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  297. def forward(self, hidden_states, input_tensor):
  298. hidden_states = self.dense(hidden_states)
  299. hidden_states = self.dropout(hidden_states)
  300. hidden_states = hidden_states + input_tensor
  301. return hidden_states
  302. class EsmAttention(nn.Module):
  303. def __init__(self, config, layer_idx=None, is_cross_attention=False):
  304. super().__init__()
  305. self.self = EsmSelfAttention(config, layer_idx=layer_idx, is_cross_attention=is_cross_attention)
  306. self.output = EsmSelfOutput(config)
  307. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  308. def forward(
  309. self,
  310. hidden_states,
  311. attention_mask=None,
  312. encoder_hidden_states=None,
  313. encoder_attention_mask=None,
  314. **kwargs: Unpack[TransformersKwargs],
  315. ):
  316. hidden_states_ln = self.LayerNorm(hidden_states)
  317. attn_output, _ = self.self(
  318. hidden_states_ln,
  319. attention_mask=attention_mask,
  320. encoder_hidden_states=encoder_hidden_states,
  321. encoder_attention_mask=encoder_attention_mask,
  322. **kwargs,
  323. )
  324. attn_output = self.output(attn_output, hidden_states)
  325. return attn_output
  326. class EsmIntermediate(nn.Module):
  327. def __init__(self, config):
  328. super().__init__()
  329. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  330. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  331. hidden_states = self.dense(hidden_states)
  332. hidden_states = gelu(hidden_states)
  333. return hidden_states
  334. class EsmOutput(nn.Module):
  335. def __init__(self, config):
  336. super().__init__()
  337. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  338. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  339. def forward(self, hidden_states, input_tensor):
  340. hidden_states = self.dense(hidden_states)
  341. hidden_states = self.dropout(hidden_states)
  342. hidden_states = hidden_states + input_tensor
  343. return hidden_states
  344. class EsmLayer(GradientCheckpointingLayer):
  345. def __init__(self, config):
  346. super().__init__()
  347. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  348. self.seq_len_dim = 1
  349. self.attention = EsmAttention(config)
  350. self.is_decoder = config.is_decoder
  351. self.add_cross_attention = config.add_cross_attention
  352. if self.add_cross_attention:
  353. if not self.is_decoder:
  354. raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
  355. self.crossattention = EsmAttention(config, is_cross_attention=True)
  356. self.intermediate = EsmIntermediate(config)
  357. self.output = EsmOutput(config)
  358. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  359. def forward(
  360. self,
  361. hidden_states,
  362. attention_mask=None,
  363. encoder_hidden_states=None,
  364. encoder_attention_mask=None,
  365. **kwargs: Unpack[TransformersKwargs],
  366. ):
  367. attention_output = self.attention(
  368. hidden_states,
  369. attention_mask=attention_mask,
  370. **kwargs,
  371. )
  372. if self.is_decoder and encoder_hidden_states is not None:
  373. if not hasattr(self, "crossattention"):
  374. raise AttributeError(
  375. f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
  376. " with cross-attention layers by setting `config.add_cross_attention=True`"
  377. )
  378. attention_output = self.crossattention(
  379. attention_output,
  380. attention_mask=attention_mask,
  381. encoder_hidden_states=encoder_hidden_states,
  382. encoder_attention_mask=encoder_attention_mask,
  383. **kwargs,
  384. )
  385. layer_output = self.feed_forward_chunk(attention_output)
  386. return layer_output
  387. def feed_forward_chunk(self, attention_output):
  388. attention_output_ln = self.LayerNorm(attention_output)
  389. intermediate_output = self.intermediate(attention_output_ln)
  390. layer_output = self.output(intermediate_output, attention_output)
  391. return layer_output
  392. class EsmEncoder(nn.Module):
  393. def __init__(self, config):
  394. super().__init__()
  395. self.config = config
  396. self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
  397. self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  398. self.gradient_checkpointing = False
  399. @can_return_tuple
  400. def forward(
  401. self,
  402. hidden_states,
  403. attention_mask=None,
  404. encoder_hidden_states=None,
  405. encoder_attention_mask=None,
  406. **kwargs: Unpack[TransformersKwargs],
  407. ):
  408. for i, layer_module in enumerate(self.layer):
  409. hidden_states = layer_module(
  410. hidden_states,
  411. attention_mask=attention_mask,
  412. encoder_hidden_states=encoder_hidden_states,
  413. encoder_attention_mask=encoder_attention_mask,
  414. **kwargs,
  415. )
  416. if self.emb_layer_norm_after:
  417. hidden_states = self.emb_layer_norm_after(hidden_states)
  418. return BaseModelOutputWithCrossAttentions(last_hidden_state=hidden_states)
  419. # Copied from transformers.models.bert.modeling_bert.BertPooler
  420. class EsmPooler(nn.Module):
  421. def __init__(self, config):
  422. super().__init__()
  423. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  424. self.activation = nn.Tanh()
  425. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  426. # We "pool" the model by simply taking the hidden state corresponding
  427. # to the first token.
  428. first_token_tensor = hidden_states[:, 0]
  429. pooled_output = self.dense(first_token_tensor)
  430. pooled_output = self.activation(pooled_output)
  431. return pooled_output
  432. @auto_docstring
  433. class EsmPreTrainedModel(PreTrainedModel):
  434. config: EsmConfig
  435. base_model_prefix = "esm"
  436. supports_gradient_checkpointing = True
  437. accepts_loss_kwargs = False
  438. _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
  439. _keys_to_ignore_on_load_unexpected = ["position_embeddings.weight"]
  440. _supports_flash_attn = True
  441. _supports_sdpa = True
  442. _supports_flex_attn = True
  443. _supports_attention_backend = True
  444. _can_record_outputs = {
  445. "hidden_states": EsmLayer,
  446. "attentions": [OutputRecorder(EsmSelfAttention, index=1, layer_name="attention")],
  447. "cross_attentions": [
  448. OutputRecorder(EsmSelfAttention, index=1, layer_name="crossattention"),
  449. ],
  450. }
  451. @torch.no_grad()
  452. def _init_weights(self, module):
  453. """Initialize the weights"""
  454. super()._init_weights(module)
  455. if isinstance(module, EsmLMHead):
  456. init.zeros_(module.bias)
  457. elif isinstance(module, EsmEmbeddings):
  458. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  459. elif isinstance(module, RotaryEmbedding):
  460. inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
  461. init.copy_(module.inv_freq, inv_freq)
  462. def get_output_embeddings(self):
  463. # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
  464. # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
  465. return None
  466. @auto_docstring
  467. class EsmModel(EsmPreTrainedModel):
  468. """
  469. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  470. cross-attention is added between the self-attention layers, following the architecture described in [Attention is
  471. all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  472. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  473. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
  474. to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
  475. `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
  476. """
  477. def __init__(self, config, add_pooling_layer=True):
  478. r"""
  479. add_pooling_layer (bool, *optional*, defaults to `True`):
  480. Whether to add a pooling layer
  481. """
  482. super().__init__(config)
  483. self.config = config
  484. self.embeddings = EsmEmbeddings(config)
  485. self.encoder = EsmEncoder(config)
  486. self.pooler = EsmPooler(config) if add_pooling_layer else None
  487. self.contact_head = EsmContactPredictionHead(
  488. in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
  489. )
  490. # Initialize weights and apply final processing
  491. self.post_init()
  492. def get_input_embeddings(self):
  493. return self.embeddings.word_embeddings
  494. def set_input_embeddings(self, value):
  495. self.embeddings.word_embeddings = value
  496. @merge_with_config_defaults
  497. @capture_outputs
  498. @auto_docstring
  499. def forward(
  500. self,
  501. input_ids: torch.Tensor | None = None,
  502. attention_mask: torch.Tensor | None = None,
  503. position_ids: torch.Tensor | None = None,
  504. inputs_embeds: torch.Tensor | None = None,
  505. encoder_hidden_states: torch.Tensor | None = None,
  506. encoder_attention_mask: torch.Tensor | None = None,
  507. **kwargs: Unpack[TransformersKwargs],
  508. ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
  509. r"""
  510. input_ids (`torch.LongTensor` of shape `((batch_size, sequence_length))`):
  511. Indices of input sequence tokens in the vocabulary.
  512. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  513. [`PreTrainedTokenizer.__call__`] for details.
  514. [What are input IDs?](../glossary#input-ids)
  515. position_ids (`torch.LongTensor` of shape `((batch_size, sequence_length))`, *optional*):
  516. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  517. config.max_position_embeddings - 1]`.
  518. [What are position IDs?](../glossary#position-ids)
  519. inputs_embeds (`torch.FloatTensor` of shape `((batch_size, sequence_length), hidden_size)`, *optional*):
  520. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  521. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  522. model's internal embedding lookup matrix.
  523. """
  524. if (input_ids is None) ^ (inputs_embeds is not None):
  525. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  526. if inputs_embeds is None:
  527. # Important, attention_mask must be passed to the embedding class
  528. # This effects how the token_dropout is calculated
  529. inputs_embeds = self.embeddings(
  530. input_ids=input_ids,
  531. attention_mask=attention_mask,
  532. position_ids=position_ids,
  533. )
  534. attention_mask, encoder_attention_mask = self._create_attention_masks(
  535. attention_mask=attention_mask,
  536. encoder_attention_mask=encoder_attention_mask,
  537. embedding_output=inputs_embeds,
  538. encoder_hidden_states=encoder_hidden_states,
  539. past_key_values=None,
  540. )
  541. encoder_outputs = self.encoder(
  542. inputs_embeds,
  543. attention_mask=attention_mask,
  544. encoder_hidden_states=encoder_hidden_states,
  545. encoder_attention_mask=encoder_attention_mask,
  546. **kwargs,
  547. )
  548. sequence_output = encoder_outputs[0]
  549. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  550. return BaseModelOutputWithPoolingAndCrossAttentions(
  551. last_hidden_state=sequence_output,
  552. pooler_output=pooled_output,
  553. )
  554. # Copied from transformers.models.bert.modeling_bert.BertModel._create_attention_masks
  555. def _create_attention_masks(
  556. self,
  557. attention_mask,
  558. encoder_attention_mask,
  559. embedding_output,
  560. encoder_hidden_states,
  561. past_key_values,
  562. ):
  563. if self.config.is_decoder:
  564. attention_mask = create_causal_mask(
  565. config=self.config,
  566. inputs_embeds=embedding_output,
  567. attention_mask=attention_mask,
  568. past_key_values=past_key_values,
  569. )
  570. else:
  571. attention_mask = create_bidirectional_mask(
  572. config=self.config,
  573. inputs_embeds=embedding_output,
  574. attention_mask=attention_mask,
  575. )
  576. if encoder_attention_mask is not None:
  577. encoder_attention_mask = create_bidirectional_mask(
  578. config=self.config,
  579. inputs_embeds=embedding_output,
  580. attention_mask=encoder_attention_mask,
  581. encoder_hidden_states=encoder_hidden_states,
  582. )
  583. return attention_mask, encoder_attention_mask
  584. def predict_contacts(self, tokens, attention_mask):
  585. attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
  586. attns = torch.stack(attns, dim=1) # Matches the original model layout
  587. # In the original model, attentions for padding tokens are completely zeroed out.
  588. # This makes no difference most of the time because the other tokens won't attend to them,
  589. # but it does for the contact prediction task, which takes attentions as input,
  590. # so we have to mimic that here.
  591. attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
  592. attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
  593. return self.contact_head(tokens, attns)
  594. @auto_docstring
  595. class EsmForMaskedLM(EsmPreTrainedModel):
  596. _tied_weights_keys = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"}
  597. def __init__(self, config):
  598. super().__init__(config)
  599. if config.is_decoder:
  600. logger.warning(
  601. "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
  602. "bi-directional self-attention."
  603. )
  604. self.esm = EsmModel(config, add_pooling_layer=False)
  605. self.lm_head = EsmLMHead(config)
  606. self.post_init()
  607. def get_output_embeddings(self):
  608. return self.lm_head.decoder
  609. def set_output_embeddings(self, new_embeddings):
  610. self.lm_head.decoder = new_embeddings
  611. @can_return_tuple
  612. @auto_docstring
  613. def forward(
  614. self,
  615. input_ids: torch.LongTensor | None = None,
  616. attention_mask: torch.Tensor | None = None,
  617. position_ids: torch.LongTensor | None = None,
  618. inputs_embeds: torch.FloatTensor | None = None,
  619. encoder_hidden_states: torch.FloatTensor | None = None,
  620. encoder_attention_mask: torch.Tensor | None = None,
  621. labels: torch.LongTensor | None = None,
  622. **kwargs: Unpack[TransformersKwargs],
  623. ) -> tuple | MaskedLMOutput:
  624. r"""
  625. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  626. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  627. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  628. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  629. """
  630. outputs = self.esm(
  631. input_ids,
  632. attention_mask=attention_mask,
  633. position_ids=position_ids,
  634. inputs_embeds=inputs_embeds,
  635. encoder_hidden_states=encoder_hidden_states,
  636. encoder_attention_mask=encoder_attention_mask,
  637. **kwargs,
  638. )
  639. sequence_output = outputs[0]
  640. prediction_scores = self.lm_head(sequence_output)
  641. masked_lm_loss = None
  642. if labels is not None:
  643. loss_fct = CrossEntropyLoss()
  644. labels = labels.to(prediction_scores.device)
  645. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  646. return MaskedLMOutput(
  647. loss=masked_lm_loss,
  648. logits=prediction_scores,
  649. hidden_states=outputs.hidden_states,
  650. attentions=outputs.attentions,
  651. )
  652. def predict_contacts(self, tokens, attention_mask):
  653. return self.esm.predict_contacts(tokens, attention_mask=attention_mask)
  654. class EsmLMHead(nn.Module):
  655. """ESM Head for masked language modeling."""
  656. def __init__(self, config):
  657. super().__init__()
  658. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  659. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  660. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  661. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  662. def forward(self, features, **kwargs):
  663. x = self.dense(features)
  664. x = gelu(x)
  665. x = self.layer_norm(x)
  666. # project back to size of vocabulary with bias
  667. x = self.decoder(x) + self.bias
  668. return x
  669. @auto_docstring(
  670. custom_intro="""
  671. ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  672. output) e.g. for GLUE tasks.
  673. """
  674. )
  675. class EsmForSequenceClassification(EsmPreTrainedModel):
  676. def __init__(self, config):
  677. super().__init__(config)
  678. self.num_labels = config.num_labels
  679. self.config = config
  680. self.esm = EsmModel(config, add_pooling_layer=False)
  681. self.classifier = EsmClassificationHead(config)
  682. self.post_init()
  683. @can_return_tuple
  684. @auto_docstring
  685. def forward(
  686. self,
  687. input_ids: torch.LongTensor | None = None,
  688. attention_mask: torch.Tensor | None = None,
  689. position_ids: torch.LongTensor | None = None,
  690. inputs_embeds: torch.FloatTensor | None = None,
  691. labels: torch.LongTensor | None = None,
  692. **kwargs: Unpack[TransformersKwargs],
  693. ) -> tuple | SequenceClassifierOutput:
  694. r"""
  695. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  696. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  697. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  698. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  699. """
  700. outputs = self.esm(
  701. input_ids,
  702. attention_mask=attention_mask,
  703. position_ids=position_ids,
  704. inputs_embeds=inputs_embeds,
  705. **kwargs,
  706. )
  707. sequence_output = outputs[0]
  708. logits = self.classifier(sequence_output)
  709. loss = None
  710. if labels is not None:
  711. labels = labels.to(logits.device)
  712. if self.config.problem_type is None:
  713. if self.num_labels == 1:
  714. self.config.problem_type = "regression"
  715. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  716. self.config.problem_type = "single_label_classification"
  717. else:
  718. self.config.problem_type = "multi_label_classification"
  719. if self.config.problem_type == "regression":
  720. loss_fct = MSELoss()
  721. if self.num_labels == 1:
  722. loss = loss_fct(logits.squeeze(), labels.squeeze())
  723. else:
  724. loss = loss_fct(logits, labels)
  725. elif self.config.problem_type == "single_label_classification":
  726. loss_fct = CrossEntropyLoss()
  727. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  728. elif self.config.problem_type == "multi_label_classification":
  729. loss_fct = BCEWithLogitsLoss()
  730. loss = loss_fct(logits, labels)
  731. return SequenceClassifierOutput(
  732. loss=loss,
  733. logits=logits,
  734. hidden_states=outputs.hidden_states,
  735. attentions=outputs.attentions,
  736. )
  737. @auto_docstring
  738. class EsmForTokenClassification(EsmPreTrainedModel):
  739. def __init__(self, config):
  740. super().__init__(config)
  741. self.num_labels = config.num_labels
  742. self.esm = EsmModel(config, add_pooling_layer=False)
  743. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  744. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  745. self.post_init()
  746. @can_return_tuple
  747. @auto_docstring
  748. def forward(
  749. self,
  750. input_ids: torch.LongTensor | None = None,
  751. attention_mask: torch.Tensor | None = None,
  752. position_ids: torch.LongTensor | None = None,
  753. inputs_embeds: torch.FloatTensor | None = None,
  754. labels: torch.LongTensor | None = None,
  755. **kwargs: Unpack[TransformersKwargs],
  756. ) -> tuple | TokenClassifierOutput:
  757. r"""
  758. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  759. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  760. """
  761. outputs = self.esm(
  762. input_ids,
  763. attention_mask=attention_mask,
  764. position_ids=position_ids,
  765. inputs_embeds=inputs_embeds,
  766. **kwargs,
  767. )
  768. sequence_output = outputs[0]
  769. sequence_output = self.dropout(sequence_output)
  770. logits = self.classifier(sequence_output)
  771. loss = None
  772. if labels is not None:
  773. loss_fct = CrossEntropyLoss()
  774. labels = labels.to(logits.device)
  775. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  776. return TokenClassifierOutput(
  777. loss=loss,
  778. logits=logits,
  779. hidden_states=outputs.hidden_states,
  780. attentions=outputs.attentions,
  781. )
  782. class EsmClassificationHead(nn.Module):
  783. """Head for sentence-level classification tasks."""
  784. def __init__(self, config):
  785. super().__init__()
  786. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  787. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  788. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  789. def forward(self, features, **kwargs):
  790. x = features[:, 0, :] # take <s> token (equiv. to [CLS])
  791. x = self.dropout(x)
  792. x = self.dense(x)
  793. x = torch.tanh(x)
  794. x = self.dropout(x)
  795. x = self.out_proj(x)
  796. return x
  797. def create_position_ids_from_input_ids(input_ids, padding_idx):
  798. """
  799. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  800. are ignored. This is modified from fairseq's `utils.make_positions`.
  801. Args:
  802. x: torch.Tensor x:
  803. Returns: torch.Tensor
  804. """
  805. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  806. mask = input_ids.ne(padding_idx).int()
  807. incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
  808. return incremental_indices.long() + padding_idx
# Public API of this module: the names exported by `from <module> import *`.
# NOTE(review): transformers' lazy-import tooling appears to parse this list textually; keep it a plain
# list literal of strings (confirm before restyling).
__all__ = [
    "EsmForMaskedLM",
    "EsmForSequenceClassification",
    "EsmForTokenClassification",
    "EsmModel",
    "EsmPreTrainedModel",
]