modular_evolla.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952
  1. # Copyright 2025 Westlake Representational Learning Lab (Fajie Yuan Lab) team and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from dataclasses import dataclass
  15. import torch
  16. from torch import nn
  17. from ... import initialization as init
  18. from ...cache_utils import Cache, DynamicCache
  19. from ...generation import GenerationMixin
  20. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  21. from ...modeling_outputs import (
  22. BaseModelOutputWithPast,
  23. BaseModelOutputWithPoolingAndCrossAttentions,
  24. CausalLMOutputWithPast,
  25. ModelOutput,
  26. )
  27. from ...modeling_utils import PreTrainedModel
  28. from ...utils import (
  29. auto_docstring,
  30. can_return_tuple,
  31. logging,
  32. )
  33. from ...utils.generic import merge_with_config_defaults
  34. from ...utils.output_capturing import OutputRecorder, capture_outputs
  35. from ..esm.modeling_esm import (
  36. EsmAttention,
  37. EsmEmbeddings,
  38. EsmEncoder,
  39. EsmIntermediate,
  40. EsmLayer,
  41. EsmOutput,
  42. EsmPooler,
  43. EsmSelfAttention,
  44. EsmSelfOutput,
  45. )
  46. from ..llama.modeling_llama import (
  47. LlamaAttention,
  48. LlamaDecoderLayer,
  49. LlamaMLP,
  50. LlamaPreTrainedModel,
  51. LlamaRMSNorm,
  52. LlamaRotaryEmbedding,
  53. )
  54. from .configuration_evolla import EvollaConfig, SaProtConfig
# Module-level logger obtained from the transformers logging utilities (`...utils.logging`).
logger = logging.get_logger(__name__)
  56. class EvollaSaProtEmbeddings(EsmEmbeddings):
  57. def __init__(self, config):
  58. super().__init__(config)
  59. # remove the position_ids in EsmEmbeddings
  60. self.position_ids = None
  61. def rotate_half_esm(x):
  62. x1, x2 = x.chunk(2, dim=-1)
  63. return torch.cat((-x2, x1), dim=-1)
  64. def apply_rotary_pos_emb_esm(x, cos, sin):
  65. cos = cos[:, :, : x.shape[-2], :]
  66. sin = sin[:, :, : x.shape[-2], :]
  67. return (x * cos) + (rotate_half_esm(x) * sin)
  68. class EvollaSaProtRotaryEmbedding(nn.Module):
  69. """
  70. Rotary position embeddings based on those in
  71. [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
  72. matrices which depend on their relative positions.
  73. """
  74. inv_freq: torch.Tensor # fix linting for `register_buffer`
  75. def __init__(self, dim: int):
  76. super().__init__()
  77. self.dim = dim
  78. # Generate and save the inverse frequency buffer (non trainable)
  79. inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
  80. self.register_buffer("inv_freq", inv_freq)
  81. self._seq_len_cached = None
  82. self._cos_cached = None
  83. self._sin_cached = None
  84. def _update_cos_sin_tables(self, x, seq_dimension=2):
  85. seq_len = x.shape[seq_dimension]
  86. # Reset the tables if the sequence length has changed,
  87. # or if we're on a new device (possibly due to tracing for instance)
  88. if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
  89. self._seq_len_cached = seq_len
  90. t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
  91. freqs = torch.outer(t, self.inv_freq)
  92. emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
  93. self._cos_cached = emb.cos()[None, None, :, :]
  94. self._sin_cached = emb.sin()[None, None, :, :]
  95. return self._cos_cached, self._sin_cached
  96. def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  97. self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
  98. return (
  99. apply_rotary_pos_emb_esm(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
  100. apply_rotary_pos_emb_esm(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
  101. )
  102. class EvollaSaProtSelfAttention(EsmSelfAttention):
  103. def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False):
  104. nn.Module.__init__(self)
  105. self.config = config
  106. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  107. raise ValueError(
  108. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  109. f"heads ({config.num_attention_heads})"
  110. )
  111. self.num_attention_heads = config.num_attention_heads
  112. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  113. self.all_head_size = self.num_attention_heads * self.attention_head_size
  114. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  115. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  116. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  117. self.dropout = config.attention_probs_dropout_prob
  118. self.rotary_embeddings = None
  119. self.position_embedding_type = position_embedding_type or getattr(
  120. config, "position_embedding_type", "absolute"
  121. )
  122. if self.position_embedding_type == "rotary":
  123. self.rotary_embeddings = EvollaSaProtRotaryEmbedding(dim=self.attention_head_size)
  124. self.is_decoder = config.is_decoder
  125. self.layer_idx = layer_idx
  126. self.scaling = 1.0
  127. self.is_causal = self.is_decoder and not is_cross_attention
  128. class EvollaSaProtSelfOutput(EsmSelfOutput):
  129. pass
  130. class EvollaSaProtAttention(EsmAttention):
  131. pass
  132. class EvollaSaProtIntermediate(EsmIntermediate):
  133. pass
  134. class EvollaSaProtOutput(EsmOutput):
  135. pass
  136. class EvollaSaProtLayer(EsmLayer):
  137. pass
  138. class EvollaSaProtEncoder(EsmEncoder):
  139. pass
  140. class EvollaSaProtPooler(EsmPooler):
  141. pass
  142. @auto_docstring
  143. class EvollaSaProtPreTrainedModel(PreTrainedModel):
  144. config: SaProtConfig
  145. _no_split_modules = ["EvollaSaProtLayer"]
  146. _supports_flash_attn = True
  147. _supports_sdpa = True
  148. _supports_flex_attn = True
  149. _supports_attention_backend = True
  150. _can_record_outputs = {
  151. "hidden_states": EvollaSaProtLayer,
  152. "attentions": [OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="attention")],
  153. "cross_attentions": [
  154. OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="crossattention"),
  155. ],
  156. }
  157. def _init_weights(self, module):
  158. super()._init_weights(module)
  159. if isinstance(module, EvollaSaProtRotaryEmbedding):
  160. inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
  161. init.copy_(module.inv_freq, inv_freq)
  162. class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
  163. def __init__(self, config: SaProtConfig):
  164. super().__init__(config)
  165. self.embeddings = EvollaSaProtEmbeddings(config)
  166. self.encoder = EvollaSaProtEncoder(config)
  167. self.post_init()
  168. def get_input_embeddings(self):
  169. return self.embeddings.word_embeddings
  170. def set_input_embeddings(self, value):
  171. self.embeddings.word_embeddings = value
  172. @merge_with_config_defaults
  173. @capture_outputs
  174. def forward(
  175. self,
  176. input_ids: torch.Tensor | None,
  177. attention_mask: torch.Tensor | None = None,
  178. **kwargs,
  179. ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
  180. input_shape = input_ids.size()
  181. batch_size, seq_length = input_shape
  182. device = input_ids.device
  183. if attention_mask is None:
  184. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  185. inputs_embeds = self.embeddings(input_ids=input_ids, attention_mask=attention_mask)
  186. attention_mask = create_bidirectional_mask(
  187. config=self.config,
  188. inputs_embeds=inputs_embeds,
  189. attention_mask=attention_mask,
  190. )
  191. encoder_outputs = self.encoder(inputs_embeds, attention_mask=attention_mask, **kwargs)
  192. sequence_output = encoder_outputs[0]
  193. return BaseModelOutputWithPoolingAndCrossAttentions(
  194. last_hidden_state=sequence_output,
  195. hidden_states=encoder_outputs.hidden_states,
  196. attentions=encoder_outputs.attentions,
  197. cross_attentions=encoder_outputs.cross_attentions,
  198. )
  199. class EvollaSequenceCompressorAttention(nn.Module):
  200. def __init__(self, dim, dim_head=64, heads=8):
  201. super().__init__()
  202. self.scale = dim_head**-0.5
  203. self.heads = heads
  204. inner_dim = dim_head * heads
  205. self.norm_media = nn.LayerNorm(dim)
  206. self.norm_latents = nn.LayerNorm(dim)
  207. self.to_q = nn.Linear(dim, inner_dim, bias=False)
  208. self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
  209. self.to_out = nn.Linear(inner_dim, dim, bias=False)
  210. def forward(self, x, latents, mask):
  211. """
  212. Args:
  213. x (torch.Tensor): image features
  214. shape (b, n1, D)
  215. latent (torch.Tensor): latent features
  216. shape (b, n2, D); n2: num of latent tokens
  217. """
  218. x = self.norm_media(x)
  219. latents = self.norm_latents(latents)
  220. h = self.heads
  221. q = self.to_q(latents)
  222. kv_input = torch.cat((x, latents), dim=-2)
  223. k, v = self.to_kv(kv_input).chunk(
  224. 2, dim=-1
  225. ) # each: batch_size, max_protein_length+num_latents, dim_head*num_heads
  226. q = q.view(q.size(0), q.size(1), h, -1).permute(0, 2, 1, 3)
  227. k = k.view(k.size(0), k.size(1), h, -1).permute(0, 2, 1, 3)
  228. v = v.view(v.size(0), v.size(1), h, -1).permute(0, 2, 1, 3)
  229. q = q * self.scale # batch_size, num_heads, num_latents, dim_head
  230. # attention
  231. sim = torch.matmul(q, k.transpose(-1, -2))
  232. sim = sim - sim.amax(dim=-1, keepdim=True).detach()
  233. bs, nh, skd, okd = sim.shape
  234. ones = torch.ones(nh, skd).to(mask.device) # Create a tensor of ones with shape (nh, skd)
  235. mask_exp = mask[:, None, None, :]
  236. ones_exp = ones[None, :, :, None]
  237. mask = mask_exp * ones_exp
  238. sim = sim.masked_fill((1 - mask).bool(), -1e4)
  239. attn = sim.softmax(dim=-1)
  240. out = torch.matmul(attn, v)
  241. out = out.permute(0, 2, 1, 3)
  242. # [batch, seq, head, features] -> [batch, seq, head*features]
  243. out = out.reshape(out.size(0), out.size(1), -1)
  244. return self.to_out(out)
  245. class EvollaFeedForward(nn.Module):
  246. def __init__(self, dim, mult=4):
  247. super().__init__()
  248. inner_dim = int(dim * mult)
  249. self.norm = nn.LayerNorm(dim)
  250. self.fc1 = nn.Linear(dim, inner_dim, bias=False)
  251. self.activation = nn.GELU()
  252. self.fc2 = nn.Linear(inner_dim, dim, bias=False)
  253. def forward(self, x):
  254. return self.fc2(self.activation(self.fc1(self.norm(x))))
  255. class EvollaSequenceCompressorResampler(nn.Module):
  256. def __init__(self, config: EvollaConfig):
  257. super().__init__()
  258. protein_repr_dim = config.protein_encoder_config.hidden_size
  259. self.num_latents = config.resampler_num_latents
  260. self.latents = nn.Parameter(torch.randn(self.num_latents, protein_repr_dim), requires_grad=True)
  261. self.layers = nn.ModuleList([])
  262. for _ in range(config.resampler_depth):
  263. self.layers.append(
  264. nn.ModuleList(
  265. [
  266. EvollaSequenceCompressorAttention(
  267. dim=protein_repr_dim, dim_head=config.resampler_dim_head, heads=config.resampler_heads
  268. ),
  269. EvollaFeedForward(dim=protein_repr_dim, mult=config.resampler_ff_mult),
  270. ]
  271. )
  272. )
  273. self.norm = nn.LayerNorm(config.hidden_size)
  274. self.protein_projector = nn.Linear(protein_repr_dim, config.hidden_size)
  275. def forward(self, embeds, mask):
  276. b = embeds.shape[0]
  277. bs, _ = mask.shape # bs, max_protein_length
  278. latent_mask = torch.ones(bs, self.num_latents).to(mask.device)
  279. mask = torch.cat((mask, latent_mask), dim=1) # bs, max_protein_length + num_latents
  280. # blocks
  281. ones = torch.ones(b).to(self.latents.device)
  282. latents = self.latents[None] * ones.view(-1, 1, 1) # [b,n,d]
  283. latents = latents.to(embeds.dtype)
  284. for attn, ff in self.layers:
  285. latents = attn(embeds, latents, mask) + latents
  286. latents = ff(latents) + latents
  287. transformed_feature = self.protein_projector(latents)
  288. return self.norm(transformed_feature)
@dataclass
@auto_docstring
class EvollaProteinEncoderModelOutput(ModelOutput):
    r"""
    sequence_compressor_output (`torch.FloatTensor` of shape `(batch_size, num_latents, hidden_size)`):
        Fixed-length protein representation produced by the sequence-compressor resampler, projected to
        `config.hidden_size`.
    """

    sequence_compressor_output: torch.FloatTensor | None = None
    # Final hidden states of the SaProt protein encoder (one vector per protein token).
    last_hidden_state: torch.FloatTensor | None = None
    # NOTE(review): EvollaProteinEncoder.forward in this file never sets `hidden_states` or `attentions`,
    # so they stay None — confirm whether they should be forwarded from the protein encoder.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
  296. class EvollaProteinEncoder(nn.Module):
  297. def __init__(self, config: EvollaConfig):
  298. super().__init__()
  299. self.model = EvollaSaProtProteinEncoder(config=config.protein_encoder_config)
  300. self.sequence_compressor_resampler = EvollaSequenceCompressorResampler(config=config)
  301. @can_return_tuple
  302. def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor, **kwargs):
  303. protein_output = self.model(input_ids=input_ids, attention_mask=attention_mask)
  304. protein_embeds = protein_output.last_hidden_state
  305. sequence_repr = self.sequence_compressor_resampler(protein_embeds, attention_mask)
  306. return EvollaProteinEncoderModelOutput(
  307. sequence_compressor_output=sequence_repr,
  308. last_hidden_state=protein_output.last_hidden_state,
  309. )
  310. class EvollaSequenceAlignerCrossAttention(nn.Module):
  311. def __init__(
  312. self,
  313. config,
  314. protein_encoder_dim: int | None = None,
  315. structure_encoder_dim: int | None = None,
  316. msa_encoder_dim: int | None = None,
  317. ):
  318. super().__init__()
  319. self.hidden_size = config.hidden_size
  320. self.num_attention_heads = config.num_attention_heads
  321. self.scale = self.num_attention_heads**-0.5
  322. self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
  323. self.all_head_size = self.num_attention_heads * self.attention_head_size
  324. attention_probs_dropout_prob = config.aligner_attention_probs_dropout_prob
  325. enable_bias = config.aligner_enable_bias
  326. ffn_mult = config.aligner_ffn_mult
  327. self.query = nn.Linear(self.hidden_size, self.all_head_size)
  328. if protein_encoder_dim is not None:
  329. self.key_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
  330. self.value_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
  331. else:
  332. self.key_protein = None
  333. self.value_protein = None
  334. if structure_encoder_dim is not None:
  335. self.key_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
  336. self.value_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
  337. else:
  338. self.key_structure = None
  339. self.value_structure = None
  340. if msa_encoder_dim is not None:
  341. self.key_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
  342. self.value_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
  343. else:
  344. self.key_msa = None
  345. self.value_msa = None
  346. self.attention_norm = EvollaRMSNorm(self.hidden_size)
  347. self.dropout = nn.Dropout(attention_probs_dropout_prob)
  348. self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=enable_bias)
  349. self.ff = EvollaFeedForward(self.hidden_size, ffn_mult)
  350. self.gate_attention = nn.Parameter(torch.tensor([0.0]))
  351. self.gate_ffw = nn.Parameter(torch.tensor([0.0]))
  352. def cross_attention(
  353. self,
  354. query_states,
  355. protein_key_value_states,
  356. structure_key_value_states,
  357. msa_key_value_states,
  358. query_attn_mask,
  359. protein_kv_attn_mask,
  360. structure_kv_attn_mask,
  361. msa_kv_attn_mask,
  362. ):
  363. """
  364. query_states: text
  365. key_value_states: protein
  366. query_states: [bs, query_seq_len, dim]
  367. key_value_states: [bs, kv_seq_len, dim]
  368. query_attn_mask: [bs, query_seq_len]
  369. kv_attn_mask: [bs, kv_seq_len]
  370. """
  371. # Concatenate protein and structure
  372. kv_attn_mask = [protein_kv_attn_mask, structure_kv_attn_mask, msa_kv_attn_mask]
  373. kv_attn_mask = [_ for _ in kv_attn_mask if _ is not None]
  374. if not kv_attn_mask:
  375. raise ValueError("At least one modality should be provided for cross attention.")
  376. kv_attn_mask = torch.cat(kv_attn_mask, dim=1)
  377. query_layer = self.attention_norm(query_states)
  378. # Warning: This place might cause issues, refers to
  379. # https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214/13
  380. # Solution: add `DISABLE_ADDMM_CUDA_LT=1` as environment variable
  381. # Apply linear transformation to input_query, input_key, and input_value
  382. query_layer = self.query(query_layer) # [bs, querylength, dim]
  383. if self.key_protein is not None and self.value_protein is not None:
  384. protein_key_value_states = protein_key_value_states.to(query_states)
  385. key_layer_protein = self.key_protein(protein_key_value_states) # [bs, keylength, dim]
  386. value_layer_protein = self.value_protein(protein_key_value_states) # [bs, keylength, dim]
  387. else:
  388. key_layer_protein = None
  389. value_layer_protein = None
  390. if self.key_structure is not None and self.value_structure is not None:
  391. structure_key_value_states = structure_key_value_states.to(query_states)
  392. key_layer_structure = self.key_structure(structure_key_value_states) # [bs, keylength, dim]
  393. value_layer_structure = self.value_structure(structure_key_value_states) # [bs, keylength, dim]
  394. else:
  395. key_layer_structure = None
  396. value_layer_structure = None
  397. if self.key_msa is not None and self.value_msa is not None:
  398. msa_key_value_states = msa_key_value_states.to(query_states)
  399. key_layer_msa = self.key_msa(msa_key_value_states) # [bs, keylength, dim]
  400. value_layer_msa = self.value_msa(msa_key_value_states) # [bs, keylength, dim]
  401. else:
  402. key_layer_msa = None
  403. value_layer_msa = None
  404. key_layer = [key_layer_protein, key_layer_structure, key_layer_msa]
  405. key_layer = [_ for _ in key_layer if _ is not None]
  406. key_layer = torch.cat(key_layer, dim=1)
  407. value_layer = [value_layer_protein, value_layer_structure, value_layer_msa]
  408. value_layer = [_ for _ in value_layer if _ is not None]
  409. value_layer = torch.cat(value_layer, dim=1)
  410. new_query_layer_shape = query_layer.size()[:-1] + (
  411. self.num_attention_heads,
  412. self.attention_head_size,
  413. )
  414. query_layer = query_layer.view(*new_query_layer_shape).permute(0, 2, 1, 3)
  415. new_key_layer_shape = key_layer.size()[:-1] + (
  416. self.num_attention_heads,
  417. self.attention_head_size,
  418. )
  419. key_layer = key_layer.view(*new_key_layer_shape).permute(0, 2, 1, 3)
  420. new_value_layer_shape = value_layer.size()[:-1] + (
  421. self.num_attention_heads,
  422. self.attention_head_size,
  423. )
  424. value_layer = value_layer.view(*new_value_layer_shape).permute(0, 2, 1, 3)
  425. query_layer = query_layer * self.scale
  426. # attention_mask: [bs, 1, querylength, keylength]
  427. if query_attn_mask is None:
  428. query_attn_mask = torch.ones(query_states.size(0), query_states.size(1)).to(query_states.device)
  429. attention_mask = query_attn_mask[:, None, :, None] * kv_attn_mask[:, None, None, :]
  430. # Compute the scaled dot-product attention scores
  431. attn_weights = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [bs, numheads, querylength, keylength]
  432. attn_weights = attn_weights - attn_weights.amax(dim=-1, keepdim=True).detach() # To stabilize score
  433. attention_scores = attn_weights.masked_fill(
  434. (1 - attention_mask).bool(), torch.finfo(attn_weights.dtype).min
  435. ) # [bs, numheads, querylength, keylength]
  436. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  437. # attention_probs_dropped = self.dropout(attention_probs)
  438. context_layer = torch.matmul(attention_probs, value_layer) # [bs, numheads, querylength, dim/numheads]
  439. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  440. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  441. context_layer = context_layer.view(*new_context_layer_shape)
  442. context_layer = self.out_proj(context_layer)
  443. return context_layer
  444. def forward(
  445. self,
  446. query_states,
  447. protein_kv_states,
  448. structure_kv_states,
  449. msa_kv_states,
  450. query_attn_mask,
  451. protein_kv_attn_mask=None,
  452. structure_kv_attn_mask=None,
  453. msa_kv_attn_mask=None,
  454. protein_batch_mask=None,
  455. structure_batch_mask=None,
  456. msa_batch_mask=None,
  457. past_key_values=None,
  458. ):
  459. if protein_kv_states is not None:
  460. bs, protein_kv_seq_len, dim = protein_kv_states.shape
  461. if protein_kv_attn_mask is None:
  462. protein_kv_attn_mask = (
  463. torch.ones(bs, protein_kv_seq_len).to(protein_batch_mask.device)
  464. * protein_batch_mask.expand(size=(protein_kv_seq_len, bs)).T
  465. ).to(protein_kv_states.device)
  466. else:
  467. protein_kv_attn_mask = None
  468. if structure_kv_states is not None:
  469. bs, structure_kv_seq_len, dim = structure_kv_states.shape
  470. if structure_kv_attn_mask is None:
  471. structure_kv_attn_mask = (
  472. torch.ones(bs, structure_kv_seq_len).to(protein_batch_mask.device)
  473. * structure_batch_mask.expand(size=(structure_kv_seq_len, bs)).T
  474. ).to(structure_kv_states.device)
  475. else:
  476. structure_kv_attn_mask = None
  477. if msa_kv_states is not None:
  478. bs, msa_kv_seq_len, dim = msa_kv_states.shape
  479. if msa_kv_attn_mask is None:
  480. msa_kv_attn_mask = (
  481. torch.ones(bs, msa_kv_seq_len).to(protein_batch_mask.device)
  482. * msa_batch_mask.expand(size=(msa_kv_seq_len, bs)).T
  483. ).to(msa_kv_states.device)
  484. else:
  485. msa_kv_attn_mask = None
  486. hidden_states = query_states
  487. # only when there's at least one valid modality, crossattention will be performed
  488. if (
  489. (protein_kv_states is not None and protein_kv_attn_mask.any())
  490. or (structure_kv_states is not None and structure_kv_attn_mask.any())
  491. or (msa_kv_states is not None and msa_kv_attn_mask.any())
  492. ):
  493. residual = hidden_states
  494. hidden_states = self.cross_attention(
  495. query_states=hidden_states,
  496. protein_key_value_states=protein_kv_states,
  497. structure_key_value_states=structure_kv_states,
  498. msa_key_value_states=msa_kv_states,
  499. query_attn_mask=query_attn_mask,
  500. protein_kv_attn_mask=protein_kv_attn_mask,
  501. structure_kv_attn_mask=structure_kv_attn_mask,
  502. msa_kv_attn_mask=msa_kv_attn_mask,
  503. ) # [bs, query_seq_len, dim]
  504. # tanh gate
  505. hidden_states = torch.tanh(self.gate_attention) * hidden_states
  506. hidden_states = residual + hidden_states # input_query
  507. residual = hidden_states
  508. hidden_states = self.ff(hidden_states) * torch.tanh(self.gate_ffw)
  509. hidden_states = residual + hidden_states
  510. return hidden_states
  511. class EvollaRMSNorm(LlamaRMSNorm):
  512. pass
  513. class EvollaRotaryEmbedding(LlamaRotaryEmbedding):
  514. pass
  515. class EvollaMLP(LlamaMLP):
  516. pass
  517. class EvollaAttention(LlamaAttention):
  518. pass
  519. class EvollaDecoderLayer(LlamaDecoderLayer):
  520. def __init__(self, config: EvollaConfig, layer_idx: int):
  521. super().__init__(config, layer_idx)
  522. if (layer_idx + 1) % max(config.num_hidden_layers // config.aligner_num_add_layers, 1) == 0:
  523. self.adapter = EvollaSequenceAlignerCrossAttention(
  524. config,
  525. protein_encoder_dim=config.hidden_size,
  526. )
  527. def forward(
  528. self,
  529. hidden_states: torch.Tensor,
  530. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  531. attention_mask: torch.Tensor | None = None,
  532. position_ids: torch.LongTensor | None = None,
  533. past_key_values: Cache | None = None,
  534. use_cache: bool | None = False,
  535. protein_kv_states: torch.Tensor | None = None,
  536. structure_kv_states: torch.Tensor | None = None,
  537. msa_kv_states: torch.Tensor | None = None,
  538. protein_batch_mask: torch.Tensor | None = None,
  539. structure_batch_mask: torch.Tensor | None = None,
  540. msa_batch_mask: torch.Tensor | None = None,
  541. query_attn_mask: torch.Tensor | None = None,
  542. **kwargs,
  543. ):
  544. residual = hidden_states
  545. hidden_states = self.input_layernorm(hidden_states)
  546. # Self Attention
  547. hidden_states, _ = self.self_attn(
  548. hidden_states=hidden_states,
  549. attention_mask=attention_mask,
  550. position_ids=position_ids,
  551. past_key_values=past_key_values,
  552. use_cache=use_cache,
  553. position_embeddings=position_embeddings,
  554. **kwargs,
  555. )
  556. hidden_states = residual + hidden_states
  557. # Fully Connected
  558. residual = hidden_states
  559. hidden_states = self.post_attention_layernorm(hidden_states)
  560. hidden_states = self.mlp(hidden_states)
  561. hidden_states = residual + hidden_states
  562. if hasattr(self, "adapter"):
  563. hidden_states = self.adapter(
  564. query_states=hidden_states,
  565. protein_kv_states=protein_kv_states,
  566. structure_kv_states=structure_kv_states,
  567. msa_kv_states=msa_kv_states,
  568. query_attn_mask=query_attn_mask,
  569. protein_batch_mask=protein_batch_mask,
  570. structure_batch_mask=structure_batch_mask,
  571. msa_batch_mask=msa_batch_mask,
  572. )
  573. return hidden_states
  574. class EvollaPreTrainedModel(LlamaPreTrainedModel):
  575. _supports_flash_attn = False # see dependency on `EvollaSequenceCompressorResampler`
  576. _supports_flex_attn = False # see dependency on `EvollaSequenceCompressorResampler`
  577. _supports_attention_backend = False
  578. _no_split_modules = [
  579. "EvollaDecoderLayer",
  580. "EvollaSequenceCompressorResampler",
  581. "EvollaSequenceAlignerCrossAttention",
  582. ]
  583. @torch.no_grad()
  584. def _init_weights(self, module):
  585. std = self.config.initializer_range
  586. PreTrainedModel._init_weights(self, module)
  587. if isinstance(module, EvollaSequenceAlignerCrossAttention):
  588. init.zeros_(module.gate_attention)
  589. init.zeros_(module.gate_ffw)
  590. init.ones_(module.attention_norm.weight)
  591. elif isinstance(module, EvollaSequenceCompressorResampler):
  592. init.normal_(module.latents, mean=0.0, std=std)
  593. class EvollaModel(EvollaPreTrainedModel):
  594. def __init__(self, config: EvollaConfig):
  595. super().__init__(config)
  596. self.padding_idx = config.pad_token_id
  597. self.vocab_size = config.vocab_size
  598. self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
  599. self.protein_encoder = EvollaProteinEncoder(config=config)
  600. self.layers = nn.ModuleList(
  601. [
  602. EvollaDecoderLayer(
  603. config=config,
  604. layer_idx=layer_idx,
  605. )
  606. for layer_idx in range(config.num_hidden_layers)
  607. ]
  608. )
  609. self.norm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  610. self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)
  611. self.rotary_emb = EvollaRotaryEmbedding(config=config)
  612. self.post_init()
  613. def get_input_embeddings(self):
  614. return self.embed_tokens
  615. def set_input_embeddings(self, value):
  616. self.embed_tokens = value
  617. @auto_docstring
  618. @merge_with_config_defaults
  619. @capture_outputs
  620. def forward(
  621. self,
  622. input_ids: torch.LongTensor | None = None,
  623. attention_mask: torch.Tensor | None = None,
  624. position_ids: torch.LongTensor | None = None,
  625. past_key_values: Cache | None = None,
  626. inputs_embeds: torch.FloatTensor | None = None,
  627. use_cache: bool | None = None,
  628. protein_input_ids: torch.LongTensor | None = None,
  629. protein_attention_mask: torch.Tensor | None = None,
  630. structure_feats: torch.FloatTensor | None = None,
  631. msa_feats: torch.FloatTensor | None = None,
  632. structure_batch_mask: torch.Tensor | None = None,
  633. msa_batch_mask: torch.Tensor | None = None,
  634. **kwargs,
  635. ) -> tuple | BaseModelOutputWithPast:
  636. r"""
  637. protein_input_ids (torch.LongTensor):
  638. The input IDs for the protein sequence in structure-aware tokens. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
  639. protein_attention_mask (torch.Tensor):
  640. The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.
  641. structure_feats (torch.FloatTensor):
  642. The input IDs for purely structure-based features. Should be of shape `(batch_size, structure_seq_length, structure_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
  643. msa_feats (torch.FloatTensor):
  644. The input IDs for purely MSA-based features. Should be of shape `(batch_size, msa_seq_length, msa_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
  645. structure_batch_mask (torch.Tensor):
  646. The batch mask to decide which protein sequences are purely structure-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `structure_feats`. Dummpy input for now.
  647. msa_batch_mask (torch.Tensor):
  648. The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummpy input for now.
  649. """
  650. if (input_ids is None) ^ (inputs_embeds is not None):
  651. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  652. if inputs_embeds is None:
  653. inputs_embeds = self.embed_tokens(input_ids)
  654. if use_cache and past_key_values is None:
  655. past_key_values = DynamicCache(config=self.config)
  656. if position_ids is None:
  657. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  658. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  659. position_ids = position_ids.unsqueeze(0)
  660. protein_feats = None
  661. protein_batch_mask = None
  662. # If provided, actually compute them
  663. if protein_input_ids is not None and protein_attention_mask is not None:
  664. protein_outputs = self.protein_encoder(
  665. input_ids=protein_input_ids,
  666. attention_mask=protein_attention_mask,
  667. )
  668. protein_feats = protein_outputs.sequence_compressor_output
  669. protein_batch_mask = torch.ones(
  670. protein_input_ids.shape[0],
  671. device=protein_input_ids.device,
  672. dtype=torch.bool,
  673. )
  674. causal_mask = create_causal_mask(
  675. config=self.config,
  676. inputs_embeds=inputs_embeds,
  677. attention_mask=attention_mask,
  678. past_key_values=past_key_values,
  679. )
  680. hidden_states = inputs_embeds
  681. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  682. for decoder_layer in self.layers:
  683. hidden_states = decoder_layer(
  684. hidden_states,
  685. attention_mask=causal_mask,
  686. position_ids=position_ids,
  687. past_key_values=past_key_values,
  688. use_cache=use_cache,
  689. protein_kv_states=protein_feats,
  690. structure_kv_states=structure_feats,
  691. msa_kv_states=msa_feats,
  692. protein_batch_mask=protein_batch_mask,
  693. structure_batch_mask=structure_batch_mask,
  694. msa_batch_mask=msa_batch_mask,
  695. query_attn_mask=attention_mask,
  696. position_embeddings=position_embeddings,
  697. **kwargs,
  698. )
  699. hidden_states = self.norm(hidden_states)
  700. output = BaseModelOutputWithPast(
  701. last_hidden_state=hidden_states,
  702. past_key_values=past_key_values,
  703. )
  704. return output
  705. class EvollaForProteinText2Text(EvollaPreTrainedModel, GenerationMixin):
  706. def __init__(self, config):
  707. super().__init__(config)
  708. self.model = EvollaModel(config)
  709. self.vocab_size = config.vocab_size
  710. self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
  711. self.post_init()
  712. def get_input_embeddings(self):
  713. return self.model.get_input_embeddings()
  714. def set_input_embeddings(self, value):
  715. return self.model.set_input_embeddings(value)
  716. @can_return_tuple
  717. @auto_docstring
  718. def forward(
  719. self,
  720. input_ids: torch.LongTensor | None = None, # text input ids
  721. attention_mask: torch.Tensor | None = None, # text attention mask
  722. inputs_embeds: torch.FloatTensor | None = None, # text input embeddings
  723. labels: torch.LongTensor | None = None,
  724. protein_input_ids: torch.LongTensor | None = None,
  725. protein_attention_mask: torch.Tensor | None = None,
  726. use_cache: bool | None = None,
  727. logits_to_keep: int | torch.Tensor = 0,
  728. **kwargs,
  729. ):
  730. r"""
  731. protein_input_ids (torch.LongTensor):
  732. The input IDs for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
  733. protein_attention_mask (torch.Tensor):
  734. The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.
  735. Example:
  736. ```python
  737. >>> from transformers import EvollaProcessor, EvollaForProteinText2Text
  738. >>> model = EvollaForProteinText2Text.from_pretrained("westlake/Evolla-10B-hf")
  739. >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf")
  740. >>> protein_information = {
  741. "aa_seq": "your amino acid sequence",
  742. "foldseek": "your foldseek sequence",
  743. }
  744. >>> question = "What is the function of this protein?"
  745. >>> message = [
  746. {"role": "system", "content": "You are an AI expert that can answer any questions about protein."},
  747. {"role": "user", "content": question},
  748. ]
  749. >>> inputs = processor(proteins=[protein_information], messages_list=[message], return_tensors="pt", padding="longest")
  750. >>> outputs = model.generate(**inputs)
  751. >>> print(processor.batch_decode(outputs, skip_special_tokens=True))
  752. ```"""
  753. outputs: BaseModelOutputWithPast = self.model(
  754. input_ids=input_ids,
  755. attention_mask=attention_mask,
  756. inputs_embeds=inputs_embeds,
  757. protein_input_ids=protein_input_ids,
  758. protein_attention_mask=protein_attention_mask,
  759. use_cache=use_cache,
  760. **kwargs,
  761. )
  762. hidden_states = outputs.last_hidden_state
  763. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  764. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  765. logits = self.lm_head(hidden_states[:, slice_indices, :])
  766. loss = None
  767. if labels is not None:
  768. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
  769. lm_outputs = CausalLMOutputWithPast(
  770. loss=loss,
  771. logits=logits,
  772. past_key_values=outputs.past_key_values,
  773. hidden_states=outputs.hidden_states,
  774. attentions=outputs.attentions,
  775. )
  776. return lm_outputs
  777. __all__ = ["EvollaForProteinText2Text", "EvollaModel", "EvollaPreTrainedModel"]