modeling_esmfold.py 83 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293
  1. # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import sys
  16. from collections.abc import Callable, Sequence
  17. from dataclasses import dataclass
  18. from functools import partial
  19. import numpy as np
  20. import torch
  21. import torch.nn as nn
  22. from torch.nn import LayerNorm
  23. from ... import initialization as init
  24. from ...integrations.deepspeed import is_deepspeed_available
  25. from ...modeling_outputs import ModelOutput
  26. from ...utils import (
  27. ContextManagers,
  28. auto_docstring,
  29. logging,
  30. )
  31. from ...utils.generic import maybe_autocast
  32. from .modeling_esm import EsmModel, EsmPreTrainedModel
  33. from .openfold_utils import (
  34. OFProtein,
  35. Rigid,
  36. Rotation,
  37. atom14_to_atom37,
  38. chunk_layer,
  39. compute_predicted_aligned_error,
  40. compute_tm,
  41. frames_and_literature_positions_to_atom14_pos,
  42. make_atom14_masks,
  43. residue_constants,
  44. to_pdb,
  45. torsion_angles_to_frames,
  46. )
  47. logger = logging.get_logger(__name__)
  48. @dataclass
  49. @auto_docstring(
  50. custom_intro="""
  51. Output type of [`EsmForProteinFoldingOutput`].
  52. """
  53. )
  54. class EsmForProteinFoldingOutput(ModelOutput):
  55. r"""
  56. frames (`torch.FloatTensor`):
  57. Output frames.
  58. sidechain_frames (`torch.FloatTensor`):
  59. Output sidechain frames.
  60. unnormalized_angles (`torch.FloatTensor`):
  61. Predicted unnormalized backbone and side chain torsion angles.
  62. angles (`torch.FloatTensor`):
  63. Predicted backbone and side chain torsion angles.
  64. positions (`torch.FloatTensor`):
  65. Predicted positions of the backbone and side chain atoms.
  66. states (`torch.FloatTensor`):
  67. Hidden states from the protein folding trunk.
  68. s_s (`torch.FloatTensor`):
  69. Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem.
  70. s_z (`torch.FloatTensor`):
  71. Pairwise residue embeddings.
  72. distogram_logits (`torch.FloatTensor`):
  73. Input logits to the distogram used to compute residue distances.
  74. lm_logits (`torch.FloatTensor`):
  75. Logits output by the ESM-2 protein language model stem.
  76. aatype (`torch.FloatTensor`):
  77. Input amino acids (AlphaFold2 indices).
  78. atom14_atom_exists (`torch.FloatTensor`):
  79. Whether each atom exists in the atom14 representation.
  80. residx_atom14_to_atom37 (`torch.FloatTensor`):
  81. Mapping between atoms in the atom14 and atom37 representations.
  82. residx_atom37_to_atom14 (`torch.FloatTensor`):
  83. Mapping between atoms in the atom37 and atom14 representations.
  84. atom37_atom_exists (`torch.FloatTensor`):
  85. Whether each atom exists in the atom37 representation.
  86. residue_index (`torch.FloatTensor`):
  87. The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be
  88. a sequence of integers from 0 to `sequence_length`.
  89. lddt_head (`torch.FloatTensor`):
  90. Raw outputs from the lddt head used to compute plddt.
  91. plddt (`torch.FloatTensor`):
  92. Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction is
  93. uncertain, or where the protein structure is disordered.
  94. ptm_logits (`torch.FloatTensor`):
  95. Raw logits used for computing ptm.
  96. ptm (`torch.FloatTensor`):
  97. TM-score output representing the model's high-level confidence in the overall structure.
  98. aligned_confidence_probs (`torch.FloatTensor`):
  99. Per-residue confidence scores for the aligned structure.
  100. predicted_aligned_error (`torch.FloatTensor`):
  101. Predicted error between the model's prediction and the ground truth.
  102. max_predicted_aligned_error (`torch.FloatTensor`):
  103. Per-sample maximum predicted error.
  104. """
  105. frames: torch.FloatTensor | None = None
  106. sidechain_frames: torch.FloatTensor | None = None
  107. unnormalized_angles: torch.FloatTensor | None = None
  108. angles: torch.FloatTensor | None = None
  109. positions: torch.FloatTensor | None = None
  110. states: torch.FloatTensor | None = None
  111. s_s: torch.FloatTensor | None = None
  112. s_z: torch.FloatTensor | None = None
  113. distogram_logits: torch.FloatTensor | None = None
  114. lm_logits: torch.FloatTensor | None = None
  115. aatype: torch.FloatTensor | None = None
  116. atom14_atom_exists: torch.FloatTensor | None = None
  117. residx_atom14_to_atom37: torch.FloatTensor | None = None
  118. residx_atom37_to_atom14: torch.FloatTensor | None = None
  119. atom37_atom_exists: torch.FloatTensor | None = None
  120. residue_index: torch.FloatTensor | None = None
  121. lddt_head: torch.FloatTensor | None = None
  122. plddt: torch.FloatTensor | None = None
  123. ptm_logits: torch.FloatTensor | None = None
  124. ptm: torch.FloatTensor | None = None
  125. aligned_confidence_probs: torch.FloatTensor | None = None
  126. predicted_aligned_error: torch.FloatTensor | None = None
  127. max_predicted_aligned_error: torch.FloatTensor | None = None
  128. def is_fp16_enabled(device_type):
  129. # Autocast world
  130. autocast_dtype = torch.get_autocast_dtype(device_type)
  131. fp16_enabled = autocast_dtype == torch.float16
  132. fp16_enabled = fp16_enabled and torch.is_autocast_enabled(device_type)
  133. return fp16_enabled
  134. def is_deepspeed_initialized():
  135. if is_deepspeed_available():
  136. return False
  137. else:
  138. try:
  139. import deepspeed
  140. # This is not available in all DeepSpeed versions.
  141. return deepspeed.utils.is_initialized()
  142. except Exception:
  143. return False
  144. def collate_dense_tensors(samples: list[torch.Tensor], pad_v: float = 0) -> torch.Tensor:
  145. """
  146. Takes a list of tensors with the following dimensions:
  147. [(d_11, ..., d_1K),
  148. (d_21, ..., d_2K), ..., (d_N1, ..., d_NK)]
  149. and stack + pads them into a single tensor of:
  150. (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})
  151. """
  152. if len(samples) == 0:
  153. return torch.Tensor()
  154. if len({x.dim() for x in samples}) != 1:
  155. raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}")
  156. (device,) = tuple({x.device for x in samples}) # assumes all on same device
  157. max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
  158. result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device)
  159. result.fill_(pad_v)
  160. for i in range(len(samples)):
  161. result_i = result[i]
  162. t = samples[i]
  163. result_i[tuple(slice(0, k) for k in t.shape)] = t
  164. return result
  165. def flatten_final_dims(t: torch.Tensor, no_dims: int):
  166. return t.reshape(t.shape[:-no_dims] + (-1,))
  167. def permute_final_dims(tensor: torch.Tensor, inds: list[int]):
  168. zero_index = -1 * len(inds)
  169. first_inds = list(range(len(tensor.shape[:zero_index])))
  170. return tensor.permute(first_inds + [zero_index + i for i in inds])
  171. def dict_multimap(fn, dicts):
  172. first = dicts[0]
  173. new_dict = {}
  174. for k, v in first.items():
  175. all_v = [d[k] for d in dicts]
  176. if isinstance(v, dict):
  177. new_dict[k] = dict_multimap(fn, all_v)
  178. else:
  179. new_dict[k] = fn(all_v)
  180. return new_dict
  181. class EsmFoldLinear(nn.Linear):
  182. """
  183. A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear.
  184. Implements the initializers in 1.11.4, plus some additional ones found in the code.
  185. """
  186. def __init__(
  187. self,
  188. in_dim: int,
  189. out_dim: int,
  190. bias: bool = True,
  191. init: str = "default",
  192. init_fn: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
  193. ):
  194. """
  195. Args:
  196. in_dim:
  197. The final dimension of inputs to the layer
  198. out_dim:
  199. The final dimension of layer outputs
  200. bias:
  201. Whether to learn an additive bias. True by default
  202. init:
  203. The initializer to use. Choose from:
  204. "default": LeCun fan-in truncated normal initialization "relu": He initialization w/ truncated normal
  205. distribution "glorot": Fan-average Glorot uniform initialization "gating": Weights=0, Bias=1 "normal":
  206. Normal initialization with std=1/sqrt(fan_in) "final": Weights=0, Bias=0
  207. Overridden by init_fn if the latter is not None.
  208. init_fn:
  209. A custom initializer taking weight and bias as inputs. Overrides init if not None.
  210. """
  211. super().__init__(in_dim, out_dim, bias=bias)
  212. if bias:
  213. with torch.no_grad():
  214. self.bias.fill_(0)
  215. self.init = init
  216. self.init_fn = init_fn
  217. if init not in ["default", "relu", "glorot", "gating", "normal", "final"]:
  218. raise ValueError("Invalid init string.")
  219. class EsmFoldLayerNorm(nn.Module):
  220. def __init__(self, c_in, eps=1e-5):
  221. super().__init__()
  222. self.c_in = (c_in,)
  223. self.eps = eps
  224. self.weight = nn.Parameter(torch.ones(c_in))
  225. self.bias = nn.Parameter(torch.zeros(c_in))
  226. def forward(self, x):
  227. d = x.dtype
  228. if d is torch.bfloat16 and not is_deepspeed_initialized():
  229. with maybe_autocast(device_type="cuda", enabled=False):
  230. out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps)
  231. else:
  232. out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps)
  233. return out
  234. @torch.jit.ignore
  235. def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
  236. """
  237. Softmax, but without automatic casting to fp32 when the input is of type bfloat16
  238. """
  239. d = t.dtype
  240. if d is torch.bfloat16 and not is_deepspeed_initialized():
  241. with maybe_autocast(device_type="cuda", enabled=False):
  242. s = torch.nn.functional.softmax(t, dim=dim)
  243. else:
  244. s = torch.nn.functional.softmax(t, dim=dim)
  245. return s
  246. class EsmFoldAttention(nn.Module):
  247. """
  248. Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors.
  249. """
  250. def __init__(
  251. self,
  252. c_q: int,
  253. c_k: int,
  254. c_v: int,
  255. c_hidden: int,
  256. no_heads: int,
  257. gating: bool = True,
  258. ):
  259. """
  260. Args:
  261. c_q:
  262. Input dimension of query data
  263. c_k:
  264. Input dimension of key data
  265. c_v:
  266. Input dimension of value data
  267. c_hidden:
  268. Per-head hidden dimension
  269. no_heads:
  270. Number of attention heads
  271. gating:
  272. Whether the output should be gated using query data
  273. """
  274. super().__init__()
  275. self.c_q = c_q
  276. self.c_k = c_k
  277. self.c_v = c_v
  278. self.c_hidden = c_hidden
  279. self.no_heads = no_heads
  280. self.gating = gating
  281. # DISCREPANCY: c_hidden is not the per-head channel dimension, as
  282. # stated in the supplement, but the overall channel dimension.
  283. self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
  284. self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
  285. self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
  286. self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final")
  287. self.linear_g = None
  288. if self.gating:
  289. self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating")
  290. self.sigmoid = nn.Sigmoid()
  291. def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  292. # [*, Q/K/V, H * C_hidden]
  293. q = self.linear_q(q_x)
  294. k = self.linear_k(kv_x)
  295. v = self.linear_v(kv_x)
  296. # [*, Q/K, H, C_hidden]
  297. q = q.view(q.shape[:-1] + (self.no_heads, -1))
  298. k = k.view(k.shape[:-1] + (self.no_heads, -1))
  299. v = v.view(v.shape[:-1] + (self.no_heads, -1))
  300. # [*, H, Q/K, C_hidden]
  301. q = q.transpose(-2, -3)
  302. k = k.transpose(-2, -3)
  303. v = v.transpose(-2, -3)
  304. q /= math.sqrt(self.c_hidden)
  305. return q, k, v
  306. def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
  307. if self.linear_g is not None:
  308. g = self.sigmoid(self.linear_g(q_x))
  309. # [*, Q, H, C_hidden]
  310. g = g.view(g.shape[:-1] + (self.no_heads, -1))
  311. o = o * g
  312. # [*, Q, H * C_hidden]
  313. o = flatten_final_dims(o, 2)
  314. # [*, Q, C_q]
  315. o = self.linear_o(o)
  316. return o
  317. def forward(
  318. self,
  319. q_x: torch.Tensor,
  320. kv_x: torch.Tensor,
  321. biases: list[torch.Tensor] | None = None,
  322. use_memory_efficient_kernel: bool = False,
  323. use_lma: bool = False,
  324. lma_q_chunk_size: int = 1024,
  325. lma_kv_chunk_size: int = 4096,
  326. use_flash: bool = False,
  327. flash_mask: torch.Tensor | None = None,
  328. ) -> torch.Tensor:
  329. """
  330. Args:
  331. q_x:
  332. [*, Q, C_q] query data
  333. kv_x:
  334. [*, K, C_k] key data
  335. biases:
  336. List of biases that broadcast to [*, H, Q, K]
  337. use_memory_efficient_kernel:
  338. Whether to use a custom memory-efficient attention kernel. This should be the default choice for most.
  339. If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead
  340. use_lma:
  341. Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a
  342. stock PyTorch implementation is used instead
  343. lma_q_chunk_size:
  344. Query chunk size (for LMA)
  345. lma_kv_chunk_size:
  346. Key/Value chunk size (for LMA)
  347. Returns
  348. [*, Q, C_q] attention update
  349. """
  350. if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
  351. raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided")
  352. if use_flash and biases is not None:
  353. raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead")
  354. attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
  355. if sum(attn_options) > 1:
  356. raise ValueError("Choose at most one alternative attention algorithm")
  357. if biases is None:
  358. biases = []
  359. # [*, H, Q/K, C_hidden]
  360. query, key, value = self._prep_qkv(q_x, kv_x)
  361. key = permute_final_dims(key, (1, 0))
  362. # [*, H, Q, K]
  363. output = torch.matmul(query, key)
  364. for b in biases:
  365. output += b
  366. output = softmax_no_cast(output, -1)
  367. # [*, H, Q, C_hidden]
  368. output = torch.matmul(output, value)
  369. output = output.transpose(-2, -3)
  370. output = self._wrap_up(output, q_x)
  371. return output
  372. class EsmFoldTriangleAttention(nn.Module):
  373. def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
  374. """
  375. Args:
  376. c_in:
  377. Input channel dimension
  378. c_hidden:
  379. Overall hidden channel dimension (not per-head)
  380. no_heads:
  381. Number of attention heads
  382. """
  383. super().__init__()
  384. self.c_in = c_in
  385. self.c_hidden = c_hidden
  386. self.no_heads = no_heads
  387. self.starting = starting
  388. self.inf = inf
  389. self.layer_norm = LayerNorm(self.c_in)
  390. self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal")
  391. self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)
  392. @torch.jit.ignore
  393. def _chunk(
  394. self,
  395. x: torch.Tensor,
  396. biases: list[torch.Tensor],
  397. chunk_size: int,
  398. use_memory_efficient_kernel: bool = False,
  399. use_lma: bool = False,
  400. inplace_safe: bool = False,
  401. ) -> torch.Tensor:
  402. "triangle! triangle!"
  403. mha_inputs = {
  404. "q_x": x,
  405. "kv_x": x,
  406. "biases": biases,
  407. }
  408. return chunk_layer(
  409. partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma),
  410. mha_inputs,
  411. chunk_size=chunk_size,
  412. no_batch_dims=len(x.shape[:-2]),
  413. _out=x if inplace_safe else None,
  414. )
  415. def forward(
  416. self,
  417. x: torch.Tensor,
  418. mask: torch.Tensor | None = None,
  419. chunk_size: int | None = None,
  420. use_memory_efficient_kernel: bool = False,
  421. use_lma: bool = False,
  422. inplace_safe: bool = False,
  423. ) -> torch.Tensor:
  424. """
  425. Args:
  426. x:
  427. [*, I, J, C_in] input tensor (e.g. the pair representation)
  428. Returns:
  429. [*, I, J, C_in] output tensor
  430. """
  431. if mask is None:
  432. # [*, I, J]
  433. mask = x.new_ones(
  434. x.shape[:-1],
  435. )
  436. if not self.starting:
  437. x = x.transpose(-2, -3)
  438. mask = mask.transpose(-1, -2)
  439. # [*, I, J, C_in]
  440. x = self.layer_norm(x)
  441. # [*, I, 1, 1, J]
  442. mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
  443. # [*, H, I, J]
  444. triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
  445. # [*, 1, H, I, J]
  446. triangle_bias = triangle_bias.unsqueeze(-4)
  447. biases = [mask_bias, triangle_bias]
  448. if chunk_size is not None:
  449. x = self._chunk(
  450. x,
  451. biases,
  452. chunk_size,
  453. use_memory_efficient_kernel=use_memory_efficient_kernel,
  454. use_lma=use_lma,
  455. inplace_safe=inplace_safe,
  456. )
  457. else:
  458. x = self.mha(
  459. q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma
  460. )
  461. if not self.starting:
  462. x = x.transpose(-2, -3)
  463. return x
  464. class EsmFoldTriangleMultiplicativeUpdate(nn.Module):
  465. """
  466. Implements Algorithms 11 and 12.
  467. """
  468. def __init__(self, config, _outgoing=True):
  469. super().__init__()
  470. c_hidden = config.pairwise_state_dim
  471. self._outgoing = _outgoing
  472. self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden)
  473. self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
  474. self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden)
  475. self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
  476. self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
  477. self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final")
  478. self.layer_norm_in = LayerNorm(c_hidden)
  479. self.layer_norm_out = LayerNorm(c_hidden)
  480. self.sigmoid = nn.Sigmoid()
  481. def _combine_projections(
  482. self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: int | None = None
  483. ) -> torch.Tensor:
  484. if self._outgoing:
  485. a = permute_final_dims(a, (2, 0, 1))
  486. b = permute_final_dims(b, (2, 1, 0))
  487. else:
  488. a = permute_final_dims(a, (2, 1, 0))
  489. b = permute_final_dims(b, (2, 0, 1))
  490. if _inplace_chunk_size is not None:
  491. # To be replaced by torch vmap
  492. for i in range(0, a.shape[-3], _inplace_chunk_size):
  493. a_chunk = a[..., i : i + _inplace_chunk_size, :, :]
  494. b_chunk = b[..., i : i + _inplace_chunk_size, :, :]
  495. a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul(
  496. a_chunk,
  497. b_chunk,
  498. )
  499. p = a
  500. else:
  501. p = torch.matmul(a, b)
  502. return permute_final_dims(p, (1, 2, 0))
  503. def _inference_forward(
  504. self,
  505. z: torch.Tensor,
  506. mask: torch.Tensor | None = None,
  507. inplace_chunk_size: int | None = None,
  508. with_add: bool = True,
  509. ):
  510. """
  511. Args:
  512. z:
  513. A [*, N, N, C_z] pair representation
  514. mask:
  515. A [*, N, N] pair mask
  516. inplace_chunk_size:
  517. Size of chunks used in the main computation. Increase to trade memory for speed.
  518. with_add:
  519. If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update).
  520. Returns:
  521. A reference to the overwritten z
  522. More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the
  523. addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of overwritten
  524. values to lower peak memory consumption of this module from 5x the size of the input tensor z to 2.5x its size.
  525. Useful for inference on extremely long sequences.
  526. It works as follows. We will make reference to variables used in the default forward implementation below.
  527. Naively, triangle multiplication attention requires the manifestation of 5 tensors the size of z: 1) z, the
  528. "square" input tensor, 2) a, the first projection of z, 3) b, the second projection of b, 4) g, a z-sized mask,
  529. and 5) a z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for
  530. N=4000, for example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate
  531. tensors in small chunks, noting that the chunks required to compute a chunk of the output depend only on the
  532. tensor a and corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over
  533. pairs of chunks of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact contains
  534. inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would bring
  535. total memory consumption down to 3x the size of z. However, more memory can be saved by writing output chunks
  536. directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column" of z at
  537. the end of the ith iteration of the main loop. Despite this overwriting, the ith column is always one column
  538. ahead of previously overwritten columns and can be recovered directly from z. After the first iteration,
  539. however, the ith row of z is always at least partially overwritten. For this reason, we introduce the z-cache,
  540. a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd quadrants) of z. For
  541. 0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at the beginning of the ith
  542. iteration. Once i exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th quadrants of z instead.
  543. Though the 3rd quadrant of the original z is entirely overwritten at this point, it can be recovered from the
  544. z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from the reoriented z-cache.
  545. After the final iteration, z has been completely overwritten and contains the triangular multiplicative update.
  546. If with_add is True, it instead contains the sum of z and the triangular multiplicative update. In either case,
  547. peak memory consumption is just 2.5x the size of z, disregarding memory used for chunks and other small
  548. variables.
  549. """
  550. if mask is None:
  551. mask = z.new_ones(z.shape[:-1])
  552. mask = mask.unsqueeze(-1)
  553. def compute_projection_helper(pair, mask, a=True):
  554. if a:
  555. linear_g = self.linear_a_g
  556. linear_p = self.linear_a_p
  557. else:
  558. linear_g = self.linear_b_g
  559. linear_p = self.linear_b_p
  560. pair = self.layer_norm_in(pair)
  561. p = linear_g(pair)
  562. p.sigmoid_()
  563. p *= linear_p(pair)
  564. p *= mask
  565. p = permute_final_dims(p, (2, 0, 1))
  566. return p
  567. def compute_projection(pair, mask, a=True, chunked=True):
  568. need_transpose = self._outgoing ^ a
  569. if not chunked:
  570. p = compute_projection_helper(pair, mask, a)
  571. if need_transpose:
  572. p = p.transpose(-1, -2)
  573. else:
  574. # This computation is chunked so as not to exceed our 2.5x
  575. # budget with a large intermediate tensor
  576. linear_g = self.linear_a_g if a else self.linear_b_g
  577. c = linear_g.bias.shape[-1]
  578. out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1]
  579. p = pair.new_zeros(out_shape)
  580. for i in range(0, pair.shape[-3], inplace_chunk_size):
  581. pair_chunk = pair[..., i : i + inplace_chunk_size, :, :]
  582. pair_chunk = compute_projection_helper(
  583. pair[..., i : i + inplace_chunk_size, :, :],
  584. mask[..., i : i + inplace_chunk_size, :, :],
  585. a,
  586. )
  587. if need_transpose:
  588. pair_chunk = pair_chunk.transpose(-1, -2)
  589. p[..., i : i + inplace_chunk_size] = pair_chunk
  590. else:
  591. p[..., i : i + inplace_chunk_size, :] = pair_chunk
  592. del pair_chunk
  593. return p
  594. # We start by fully manifesting a. In addition to the input, this
  595. # brings total memory consumption to 2x z (disregarding size of chunks)
  596. # [*, N, N, c]
  597. a = compute_projection(z, mask, True, chunked=True)
  598. if inplace_chunk_size is not None:
  599. n = a.shape[-1]
  600. half_n = n // 2 + n % 2
  601. row_dim = -3
  602. col_dim = -2
  603. b_chunk_dim = row_dim if self._outgoing else col_dim
  604. def empty_slicer(t):
  605. return [slice(None) for _ in t.shape]
  606. def slice_tensor(t, start, end, dim):
  607. # Slices start:end from the dim dimension of t
  608. s = empty_slicer(t)
  609. s[dim] = slice(start, end)
  610. return t[s]
  611. def flip_z_cache_(z_cache, z):
  612. # "Reorient" the z_cache (see below), filling it with quadrants
  613. # 3---recovered from the z_cache---and 4---recovered from z---
  614. # of the input tensor z.
  615. quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim)
  616. z_cache = z_cache.transpose(row_dim, col_dim)
  617. # If n is odd, we need to shrink the z_cache by one row
  618. z_cache = z_cache[..., : (n // 2), :, :]
  619. # Move the 3rd quadrant of z into the
  620. first_half_slicer = empty_slicer(z_cache)
  621. first_half_slicer[col_dim] = slice(0, half_n)
  622. z_cache[first_half_slicer] = quadrant_3
  623. # Get the fourth quadrant of z
  624. quadrant_4 = slice_tensor(z, half_n, None, row_dim)
  625. quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim)
  626. # Insert said quadrant into the rotated z-cache
  627. quadrant_3_slicer = empty_slicer(z_cache)
  628. quadrant_3_slicer[col_dim] = slice(half_n, None)
  629. z_cache[quadrant_3_slicer] = quadrant_4
  630. return z_cache
  631. # Initialize the z cache to the left half of z.
  632. z_cache_shape = list(z.shape)
  633. z_cache_shape[col_dim] = half_n
  634. z_cache = z.new_zeros(z_cache_shape)
  635. z_cache_slicer = empty_slicer(z_cache)
  636. z_cache_slicer[col_dim] = slice(0, half_n)
  637. z_cache.copy_(z[z_cache_slicer])
  638. z_cache_rotated = False
  639. # We need to reorient the z-cache at the halfway point, and we
  640. # don't want a single chunk to straddle that point. We contract one
  641. # of the chunks in the middle to address that problem.
  642. i_range = list(range(0, half_n, inplace_chunk_size))
  643. initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])]
  644. after_half = list(range(half_n, n, inplace_chunk_size))
  645. after_half_offsets = [inplace_chunk_size for _ in after_half]
  646. combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets)
  647. for i, offset in combined_range_with_offsets:
  648. if not z_cache_rotated and i >= half_n:
  649. z_cache = flip_z_cache_(z_cache, z)
  650. z_cache_rotated = True
  651. z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim)
  652. mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim)
  653. z_chunk_b = z_chunk_b.clone()
  654. if b_chunk_dim == col_dim:
  655. z_chunk_b = slice_tensor(z, i, i + offset, col_dim)
  656. else: # b_chunk_dim == row_dim
  657. # In this case, the b-dimension (b_chunk_dim) is partially
  658. # overwritten at the end of each iteration. We need to
  659. # restore the missing component from the z-cache.
  660. if not z_cache_rotated:
  661. z_chunk_slicer = empty_slicer(z_chunk_b)
  662. z_chunk_slicer[col_dim] = slice(0, half_n)
  663. z_chunk_b[z_chunk_slicer] = slice_tensor(z_cache, i, i + offset, row_dim)
  664. else:
  665. z_cache_offset = i - half_n
  666. z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim)
  667. b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False)
  668. del z_chunk_b
  669. x_chunk = torch.matmul(a, b_chunk)
  670. x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
  671. x_chunk = self.layer_norm_out(x_chunk)
  672. x_chunk = self.linear_z(x_chunk)
  673. # The g dimension (col_dim) is parallel to and ahead of the
  674. # overwrites in z. We can extract the g chunk normally.
  675. z_chunk_g = slice_tensor(z, i, i + offset, col_dim)
  676. g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
  677. g_chunk.sigmoid_()
  678. del z_chunk_g
  679. x_chunk *= g_chunk
  680. # Write the columns into z in-place
  681. z_slicer = empty_slicer(z)
  682. z_slicer[col_dim] = slice(i, i + offset)
  683. if with_add:
  684. z[z_slicer] += x_chunk
  685. else:
  686. z[z_slicer] = x_chunk
  687. else:
  688. b = compute_projection(z, mask, False, False)
  689. x = torch.matmul(a, b)
  690. x = self.layer_norm_out(x)
  691. x = self.linear_z(x)
  692. g = self.linear_g(z)
  693. g.sigmoid_()
  694. x *= g
  695. if with_add:
  696. z += x
  697. else:
  698. z = x
  699. return z
  700. def forward(
  701. self,
  702. z: torch.Tensor,
  703. mask: torch.Tensor | None = None,
  704. inplace_safe: bool = False,
  705. _add_with_inplace: bool = False,
  706. _inplace_chunk_size: int | None = 256,
  707. ) -> torch.Tensor:
  708. """
  709. Args:
  710. x:
  711. [*, N_res, N_res, C_z] input tensor
  712. mask:
  713. [*, N_res, N_res] input mask
  714. Returns:
  715. [*, N_res, N_res, C_z] output tensor
  716. """
  717. if inplace_safe:
  718. x = self._inference_forward(
  719. z,
  720. mask,
  721. inplace_chunk_size=_inplace_chunk_size,
  722. with_add=_add_with_inplace,
  723. )
  724. return x
  725. if mask is None:
  726. mask = z.new_ones(z.shape[:-1])
  727. mask = mask.unsqueeze(-1)
  728. z = self.layer_norm_in(z)
  729. a = mask
  730. a = a * self.sigmoid(self.linear_a_g(z))
  731. a = a * self.linear_a_p(z)
  732. b = mask
  733. b = b * self.sigmoid(self.linear_b_g(z))
  734. b = b * self.linear_b_p(z)
  735. device_type = a.device.type if a.device.type != "mps" else "cpu"
  736. if is_fp16_enabled(device_type):
  737. with maybe_autocast(device_type=device_type, enabled=False):
  738. x = self._combine_projections(a.float(), b.float())
  739. else:
  740. x = self._combine_projections(a, b)
  741. del a, b
  742. x = self.layer_norm_out(x)
  743. x = self.linear_z(x)
  744. g = self.sigmoid(self.linear_g(z))
  745. x = x * g
  746. return x
  747. class EsmFoldPreTrainedModel(EsmPreTrainedModel):
  748. """
  749. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  750. models.
  751. """
  752. # Subclass `EsMPreTrainedModel` to deal with special init
  753. @torch.no_grad()
  754. def _init_weights(self, module):
  755. """Initialize the weights"""
  756. if isinstance(module, EsmFoldLinear):
  757. with torch.no_grad():
  758. if module.init_fn is not None:
  759. module.init_fn(module.weight, module.bias)
  760. elif module.init == "default":
  761. shape = module.weight.shape
  762. scale = 1.0 / max(1, shape[1])
  763. std = math.sqrt(scale)
  764. init.normal_(module.weight, std=std)
  765. elif module.init == "relu":
  766. shape = module.weight.shape
  767. scale = 2.0 / max(1, shape[1])
  768. std = math.sqrt(scale)
  769. init.normal_(module.weight, std=std)
  770. elif module.init == "glorot":
  771. init.xavier_uniform_(module.weight, gain=1)
  772. elif module.init == "gating":
  773. init.zeros_(module.weight)
  774. if module.bias:
  775. init.ones_(module.bias)
  776. elif module.init == "normal":
  777. init.kaiming_normal_(module.weight, nonlinearity="linear")
  778. elif module.init == "final":
  779. init.zeros_(module.weight)
  780. elif isinstance(module, EsmFoldInvariantPointAttention):
  781. softplus_inverse_1 = 0.541324854612918
  782. init.constant_(module.head_weights, softplus_inverse_1)
  783. elif isinstance(module, EsmFoldTriangularSelfAttentionBlock):
  784. init.zeros_(module.tri_mul_in.linear_z.weight)
  785. init.zeros_(module.tri_mul_in.linear_z.bias)
  786. init.zeros_(module.tri_mul_out.linear_z.weight)
  787. init.zeros_(module.tri_mul_out.linear_z.bias)
  788. init.zeros_(module.tri_att_start.mha.linear_o.weight)
  789. init.zeros_(module.tri_att_start.mha.linear_o.bias)
  790. init.zeros_(module.tri_att_end.mha.linear_o.weight)
  791. init.zeros_(module.tri_att_end.mha.linear_o.bias)
  792. init.zeros_(module.sequence_to_pair.o_proj.weight)
  793. init.zeros_(module.sequence_to_pair.o_proj.bias)
  794. init.zeros_(module.pair_to_sequence.linear.weight)
  795. init.zeros_(module.seq_attention.o_proj.weight)
  796. init.zeros_(module.seq_attention.o_proj.bias)
  797. init.zeros_(module.mlp_seq.mlp[-2].weight)
  798. init.zeros_(module.mlp_seq.mlp[-2].bias)
  799. init.zeros_(module.mlp_pair.mlp[-2].weight)
  800. init.zeros_(module.mlp_pair.mlp[-2].bias)
  801. else:
  802. super()._init_weights(module)
  803. class EsmFoldSelfAttention(nn.Module):
  804. def __init__(self, embed_dim, num_heads, head_width, gated=False):
  805. super().__init__()
  806. assert embed_dim == num_heads * head_width
  807. self.embed_dim = embed_dim
  808. self.num_heads = num_heads
  809. self.head_width = head_width
  810. self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
  811. self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
  812. self.gated = gated
  813. if gated:
  814. self.g_proj = nn.Linear(embed_dim, embed_dim)
  815. init.zeros_(self.g_proj.weight)
  816. init.ones_(self.g_proj.bias)
  817. self.rescale_factor = self.head_width**-0.5
  818. init.zeros_(self.o_proj.bias)
  819. def forward(self, x, mask=None, bias=None, indices=None):
  820. """
  821. Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths,
  822. use mask.
  823. Inputs:
  824. x: batch of input sequences (.. x L x C) mask: batch of boolean masks where 1=valid, 0=padding position (..
  825. x L_k) bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads)
  826. Outputs:
  827. sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads)
  828. """
  829. t = self.proj(x).view(*x.shape[:2], self.num_heads, -1)
  830. t = t.permute(0, 2, 1, 3)
  831. q, k, v = t.chunk(3, dim=-1)
  832. q = self.rescale_factor * q
  833. a = torch.einsum("...qc,...kc->...qk", q, k)
  834. # Add external attention bias.
  835. if bias is not None:
  836. a = a + bias.permute(0, 3, 1, 2)
  837. # Do not attend to padding tokens.
  838. if mask is not None:
  839. mask = mask[:, None, None]
  840. a = a.masked_fill(mask == False, -np.inf) # noqa: E712
  841. a = nn.functional.softmax(a, dim=-1)
  842. y = torch.einsum("...hqk,...hkc->...qhc", a, v)
  843. y = y.reshape(*y.shape[:2], -1)
  844. if self.gated:
  845. y = self.g_proj(x).sigmoid() * y
  846. y = self.o_proj(y)
  847. return y, a.permute(0, 3, 1, 2)
  848. class EsmFoldDropout(nn.Module):
  849. """
  850. Implementation of dropout with the ability to share the dropout mask along a particular dimension.
  851. """
  852. def __init__(self, r: float, batch_dim: int | list[int]):
  853. super().__init__()
  854. self.r = r
  855. if isinstance(batch_dim, int):
  856. batch_dim = [batch_dim]
  857. self.batch_dim = batch_dim
  858. self.dropout = nn.Dropout(self.r)
  859. def forward(self, x: torch.Tensor) -> torch.Tensor:
  860. shape = list(x.shape)
  861. if self.batch_dim is not None:
  862. for bd in self.batch_dim:
  863. shape[bd] = 1
  864. return x * self.dropout(x.new_ones(shape))
  865. class EsmFoldSequenceToPair(nn.Module):
  866. def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
  867. super().__init__()
  868. self.layernorm = nn.LayerNorm(sequence_state_dim)
  869. self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
  870. self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)
  871. init.zeros_(self.proj.bias)
  872. init.zeros_(self.o_proj.bias)
  873. def forward(self, sequence_state):
  874. """
  875. Inputs:
  876. sequence_state: B x L x sequence_state_dim
  877. Output:
  878. pairwise_state: B x L x L x pairwise_state_dim
  879. Intermediate state:
  880. B x L x L x 2*inner_dim
  881. """
  882. assert len(sequence_state.shape) == 3
  883. s = self.layernorm(sequence_state)
  884. s = self.proj(s)
  885. q, k = s.chunk(2, dim=-1)
  886. prod = q[:, None, :, :] * k[:, :, None, :]
  887. diff = q[:, None, :, :] - k[:, :, None, :]
  888. x = torch.cat([prod, diff], dim=-1)
  889. x = self.o_proj(x)
  890. return x
  891. class EsmFoldPairToSequence(nn.Module):
  892. def __init__(self, pairwise_state_dim, num_heads):
  893. super().__init__()
  894. self.layernorm = nn.LayerNorm(pairwise_state_dim)
  895. self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)
  896. def forward(self, pairwise_state):
  897. """
  898. Inputs:
  899. pairwise_state: B x L x L x pairwise_state_dim
  900. Output:
  901. pairwise_bias: B x L x L x num_heads
  902. """
  903. assert len(pairwise_state.shape) == 4
  904. z = self.layernorm(pairwise_state)
  905. pairwise_bias = self.linear(z)
  906. return pairwise_bias
  907. class EsmFoldResidueMLP(nn.Module):
  908. def __init__(self, embed_dim, inner_dim, dropout=0):
  909. super().__init__()
  910. self.mlp = nn.Sequential(
  911. nn.LayerNorm(embed_dim),
  912. nn.Linear(embed_dim, inner_dim),
  913. nn.ReLU(),
  914. nn.Linear(inner_dim, embed_dim),
  915. nn.Dropout(dropout),
  916. )
  917. def forward(self, x):
  918. return x + self.mlp(x)
  919. class EsmFoldTriangularSelfAttentionBlock(nn.Module):
  920. def __init__(self, config):
  921. super().__init__()
  922. self.config = config
  923. sequence_state_dim = config.sequence_state_dim
  924. pairwise_state_dim = config.pairwise_state_dim
  925. sequence_num_heads = sequence_state_dim // config.sequence_head_width
  926. pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width
  927. self.layernorm_1 = nn.LayerNorm(sequence_state_dim)
  928. self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim)
  929. self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads)
  930. self.seq_attention = EsmFoldSelfAttention(
  931. sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True
  932. )
  933. self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True)
  934. self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False)
  935. self.tri_att_start = EsmFoldTriangleAttention(
  936. pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True
  937. )
  938. self.tri_att_end = EsmFoldTriangleAttention(
  939. pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False
  940. )
  941. self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout)
  942. self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout)
  943. self.drop = nn.Dropout(config.dropout)
  944. self.row_drop = EsmFoldDropout(config.dropout * 2, 2)
  945. self.col_drop = EsmFoldDropout(config.dropout * 2, 1)
  946. def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
  947. """
  948. Inputs:
  949. sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim mask: B x L boolean
  950. tensor of valid positions
  951. Output:
  952. sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim
  953. """
  954. if len(sequence_state.shape) != 3:
  955. raise ValueError(f"`sequence_state` should be a 3d-tensor, got {len(sequence_state.shape)} dims.")
  956. if len(pairwise_state.shape) != 4:
  957. raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {len(pairwise_state.shape)} dims.")
  958. if mask is not None and len(mask.shape) != 2:
  959. raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
  960. batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
  961. pairwise_state_dim = pairwise_state.shape[3]
  962. if sequence_state_dim != self.config.sequence_state_dim:
  963. raise ValueError(
  964. "`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got "
  965. f"{sequence_state_dim} != {self.config.sequence_state_dim}."
  966. )
  967. if pairwise_state_dim != self.config.pairwise_state_dim:
  968. raise ValueError(
  969. "`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got "
  970. f"{pairwise_state_dim} != {self.config.pairwise_state_dim}."
  971. )
  972. if batch_dim != pairwise_state.shape[0]:
  973. raise ValueError(
  974. f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != "
  975. f"{pairwise_state.shape[0]}."
  976. )
  977. if seq_dim != pairwise_state.shape[1] or seq_dim != pairwise_state.shape[2]:
  978. raise ValueError(
  979. f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != "
  980. f"{pairwise_state.shape[1]} or {pairwise_state.shape[2]}."
  981. )
  982. # Update sequence state
  983. bias = self.pair_to_sequence(pairwise_state)
  984. # Self attention with bias + mlp.
  985. y = self.layernorm_1(sequence_state)
  986. y, _ = self.seq_attention(y, mask=mask, bias=bias)
  987. sequence_state = sequence_state + self.drop(y)
  988. sequence_state = self.mlp_seq(sequence_state)
  989. # Update pairwise state
  990. pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)
  991. # Axial attention with triangular bias.
  992. tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
  993. pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask))
  994. pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask))
  995. pairwise_state = pairwise_state + self.row_drop(
  996. self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
  997. )
  998. pairwise_state = pairwise_state + self.col_drop(
  999. self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
  1000. )
  1001. # MLP over pairs.
  1002. pairwise_state = self.mlp_pair(pairwise_state)
  1003. return sequence_state, pairwise_state
  1004. class EsmCategoricalMixture:
  1005. def __init__(self, param, bins=50, start=0, end=1):
  1006. # All tensors are of shape ..., bins.
  1007. self.logits = param
  1008. bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype)
  1009. self.v_bins = (bins[:-1] + bins[1:]) / 2
  1010. def log_prob(self, true):
  1011. # Shapes are:
  1012. # self.probs: ... x bins
  1013. # true : ...
  1014. true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1)
  1015. nll = self.logits.log_softmax(-1)
  1016. return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)
  1017. def mean(self):
  1018. return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)
  1019. def categorical_lddt(logits, bins=50):
  1020. # Logits are ..., 37, bins.
  1021. return EsmCategoricalMixture(logits, bins=bins).mean()
  1022. def get_axial_mask(mask):
  1023. """
  1024. Helper to convert B x L mask of valid positions to axial mask used in row column attentions.
  1025. Input:
  1026. mask: B x L tensor of booleans
  1027. Output:
  1028. mask: B x L x L tensor of booleans
  1029. """
  1030. if mask is None:
  1031. return None
  1032. if len(mask.shape) != 2:
  1033. raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
  1034. batch_dim, seq_dim = mask.shape
  1035. m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
  1036. m = m.reshape(batch_dim * seq_dim, seq_dim)
  1037. return m
  1038. class EsmFoldRelativePosition(nn.Module):
  1039. def __init__(self, config):
  1040. super().__init__()
  1041. self.bins = config.position_bins
  1042. # Note an additional offset is used so that the 0th position
  1043. # is reserved for masked pairs.
  1044. self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim)
  1045. def forward(self, residue_index, mask=None):
  1046. """
  1047. Input:
  1048. residue_index: B x L tensor of indices (dtype=torch.long) mask: B x L tensor of booleans
  1049. Output:
  1050. pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
  1051. """
  1052. if residue_index.dtype != torch.long:
  1053. raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.")
  1054. if mask is not None and residue_index.shape != mask.shape:
  1055. raise ValueError(
  1056. f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}."
  1057. )
  1058. diff = residue_index[:, None, :] - residue_index[:, :, None]
  1059. diff = diff.clamp(-self.bins, self.bins)
  1060. diff = diff + self.bins + 1 # Add 1 to adjust for padding index.
  1061. if mask is not None:
  1062. mask = mask[:, None, :] * mask[:, :, None]
  1063. diff[mask == False] = 0 # noqa: E712
  1064. output = self.embedding(diff)
  1065. return output
  1066. class EsmFoldAngleResnetBlock(nn.Module):
  1067. def __init__(self, config):
  1068. super().__init__()
  1069. self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu")
  1070. self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final")
  1071. self.relu = nn.ReLU()
  1072. def forward(self, a: torch.Tensor) -> torch.Tensor:
  1073. s_initial = a
  1074. a = self.relu(a)
  1075. a = self.linear_1(a)
  1076. a = self.relu(a)
  1077. a = self.linear_2(a)
  1078. return a + s_initial
  1079. class EsmFoldAngleResnet(nn.Module):
  1080. """
  1081. Implements Algorithm 20, lines 11-14
  1082. """
  1083. def __init__(self, config):
  1084. super().__init__()
  1085. self.config = config
  1086. self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
  1087. self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
  1088. self.layers = nn.ModuleList()
  1089. for _ in range(config.num_resnet_blocks):
  1090. layer = EsmFoldAngleResnetBlock(config)
  1091. self.layers.append(layer)
  1092. self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2)
  1093. self.relu = nn.ReLU()
  1094. def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  1095. """
  1096. Args:
  1097. s:
  1098. [*, C_hidden] single embedding
  1099. s_initial:
  1100. [*, C_hidden] single embedding as of the start of the StructureModule
  1101. Returns:
  1102. [*, no_angles, 2] predicted angles
  1103. """
  1104. # NOTE: The ReLU's applied to the inputs are absent from the supplement
  1105. # pseudocode but present in the source. For maximal compatibility with
  1106. # the pretrained weights, I'm going with the source.
  1107. # [*, C_hidden]
  1108. s_initial = self.relu(s_initial)
  1109. s_initial = self.linear_initial(s_initial)
  1110. s = self.relu(s)
  1111. s = self.linear_in(s)
  1112. s = s + s_initial
  1113. for l in self.layers:
  1114. s = l(s)
  1115. s = self.relu(s)
  1116. # [*, no_angles * 2]
  1117. s = self.linear_out(s)
  1118. # [*, no_angles, 2]
  1119. s = s.view(s.shape[:-1] + (-1, 2))
  1120. unnormalized_s = s
  1121. norm_denom = torch.sqrt(
  1122. torch.clamp(
  1123. torch.sum(s**2, dim=-1, keepdim=True),
  1124. min=self.config.epsilon,
  1125. )
  1126. )
  1127. s = s / norm_denom
  1128. return unnormalized_s, s
class EsmFoldInvariantPointAttention(nn.Module):
    """
    Implements Algorithm 22.

    Attention logits are the sum of three terms: a scaled scalar query/key product, a bias projected from the pair
    representation, and a negatively weighted squared distance between query and key points after they have been
    mapped through the rigid transformations `r`. The output concatenates attended scalar values, attended points
    (mapped back through the inverse of `r`) with their norms, and attended pair features.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        c_s = config.sequence_dim
        c_z = config.pairwise_dim
        self.hidden_dim = config.ipa_dim
        self.num_heads = config.num_heads_ipa
        self.num_qk_points = config.num_qk_points
        self.num_v_points = config.num_v_points

        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = config.ipa_dim * config.num_heads_ipa
        self.linear_q = EsmFoldLinear(c_s, hc)
        self.linear_kv = EsmFoldLinear(c_s, 2 * hc)

        # Three coordinates per query point and head.
        hpq = config.num_heads_ipa * config.num_qk_points * 3
        self.linear_q_points = EsmFoldLinear(c_s, hpq)

        # Key points and value points are produced by one shared projection.
        hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3
        self.linear_kv_points = EsmFoldLinear(c_s, hpkv)

        # Per-head attention bias computed from the pair representation.
        self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa)

        # Raw per-head weights; passed through softplus in `forward`.
        self.head_weights = nn.Parameter(torch.zeros(config.num_heads_ipa))

        # Per head: pair features (c_z), scalar values (ipa_dim), and for each value point
        # three coordinates plus one norm (num_v_points * 4).
        concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4)
        self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final")

        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

    def forward(
        self,
        s: torch.Tensor,
        z: torch.Tensor | None,
        r: Rigid,
        mask: torch.Tensor,
        _offload_inference: bool = False,
        _z_reference_list: Sequence[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation. NOTE(review): annotated as optional, but it is passed
                to `linear_b` unconditionally - confirm whether `None` is ever a valid input.
            r:
                [*, N_res] transformation object
            mask:
                [*, N_res] mask
            _offload_inference:
                When True, the pair representation is moved to the CPU after the bias has been computed and moved
                back to the device of the intermediate results before it is used again.
            _z_reference_list:
                Accepted but not read by this method.
        Returns:
            [*, N_res, C_s] single representation update
        """
        # Hold the pair tensor in a one-element list so that it can be swapped for a copy on
        # another device below without keeping an additional local name bound to it.
        z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)
        kv = self.linear_kv(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.num_heads, -1))

        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.num_heads, -1))

        # [*, N_res, H, C_hidden]
        k, v = torch.split(kv, self.hidden_dim, dim=-1)

        # [*, N_res, H * P_q * 3]
        q_pts = self.linear_q_points(s)

        # This is kind of clunky, but it's how the original does it
        # [*, N_res, H * P_q, 3]
        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
        q_pts = torch.stack(q_pts, dim=-1)
        # Map the points through each residue's transformation.
        q_pts = r[..., None].apply(q_pts)

        # [*, N_res, H, P_q, 3]
        q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3))

        # [*, N_res, H * (P_q + P_v) * 3]
        kv_pts = self.linear_kv_points(s)

        # [*, N_res, H * (P_q + P_v), 3]
        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
        kv_pts = torch.stack(kv_pts, dim=-1)
        kv_pts = r[..., None].apply(kv_pts)

        # [*, N_res, H, (P_q + P_v), 3]
        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3))

        # [*, N_res, H, P_q/P_v, 3]
        k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2)

        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        if _offload_inference:
            # Requires that the list entry is the only reference visible here besides the
            # temporary created by `getrefcount` itself - presumably so that replacing it with
            # a CPU copy actually releases the original tensor; TODO confirm.
            assert sys.getrefcount(z[0]) == 2
            z[0] = z[0].cpu()

        # [*, H, N_res, N_res]
        # "mps" is mapped to "cpu" before querying / configuring autocast.
        device_type = q.device.type if q.device.type != "mps" else "cpu"
        if is_fp16_enabled(device_type):
            # Compute the scalar logits in float32 with autocast switched off.
            with maybe_autocast(device_type=device_type, enabled=False):
                a = torch.matmul(
                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
                )
        else:
            a = torch.matmul(
                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
            )

        # Scale the scalar term and add the scaled pair bias (both in place).
        a *= math.sqrt(1.0 / (3 * self.hidden_dim))
        a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))

        # [*, N_res, N_res, H, P_q, 3]
        pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        pt_att = pt_att**2

        # [*, N_res, N_res, H, P_q]
        # Sum over the coordinate axis: squared distance between query and key points.
        pt_att = sum(torch.unbind(pt_att, dim=-1))
        # Reshape the per-head weights to broadcast against pt_att: leading 1s, then H, then 1.
        head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1)))
        head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
        pt_att = pt_att * head_weights

        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
        # [*, N_res, N_res]
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        # 0 where the mask product is 1 and -inf-like (-config.inf) where it is 0.
        square_mask = self.config.inf * (square_mask - 1)

        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))

        a = a + pt_att
        a = a + square_mask.unsqueeze(-3)
        a = self.softmax(a)

        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
        o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, H, 3, N_res, P_v]
        o_pt = torch.sum(
            (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]),
            dim=-2,
        )

        # [*, N_res, H, P_v, 3]
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
        # Map the attended points back through the inverse of each residue's transformation.
        o_pt = r[..., None, None].invert_apply(o_pt)

        # [*, N_res, H * P_v]
        # `config.epsilon` is added under the square root.
        o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2)

        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        if _offload_inference:
            # Bring the pair representation back to the device of the intermediate results.
            z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z]
        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)

        # [*, N_res, C_s]
        # The concatenated features are cast to the dtype of the pair representation first.
        s = self.linear_out(
            torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype)
        )

        return s
  1284. class EsmFoldBackboneUpdate(nn.Module):
  1285. """
  1286. Implements part of Algorithm 23.
  1287. """
  1288. def __init__(self, config):
  1289. super().__init__()
  1290. self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final")
  1291. def forward(self, s: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  1292. """
  1293. Args:
  1294. [*, N_res, C_s] single representation
  1295. Returns:
  1296. [*, N_res, 6] update vector
  1297. """
  1298. # [*, 6]
  1299. update = self.linear(s)
  1300. return update
  1301. class EsmFoldStructureModuleTransitionLayer(nn.Module):
  1302. def __init__(self, config):
  1303. super().__init__()
  1304. self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
  1305. self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
  1306. self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final")
  1307. self.relu = nn.ReLU()
  1308. def forward(self, s):
  1309. s_initial = s
  1310. s = self.linear_1(s)
  1311. s = self.relu(s)
  1312. s = self.linear_2(s)
  1313. s = self.relu(s)
  1314. s = self.linear_3(s)
  1315. s = s + s_initial
  1316. return s
  1317. class EsmFoldStructureModuleTransition(nn.Module):
  1318. def __init__(self, config):
  1319. super().__init__()
  1320. self.config = config
  1321. self.layers = nn.ModuleList()
  1322. for _ in range(config.num_transition_layers):
  1323. l = EsmFoldStructureModuleTransitionLayer(config)
  1324. self.layers.append(l)
  1325. self.dropout = nn.Dropout(config.dropout_rate)
  1326. self.layer_norm = LayerNorm(config.sequence_dim)
  1327. def forward(self, s):
  1328. for l in self.layers:
  1329. s = l(s)
  1330. s = self.dropout(s)
  1331. s = self.layer_norm(s)
  1332. return s
  1333. class EsmFoldStructureModule(nn.Module):
  1334. def __init__(self, config):
  1335. super().__init__()
  1336. self.config = config
  1337. # Buffers to be lazily initialized later
  1338. # self.default_frames
  1339. # self.group_idx
  1340. # self.atom_mask
  1341. # self.lit_positions
  1342. self.layer_norm_s = LayerNorm(config.sequence_dim)
  1343. self.layer_norm_z = LayerNorm(config.pairwise_dim)
  1344. self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim)
  1345. self.ipa = EsmFoldInvariantPointAttention(config)
  1346. self.ipa_dropout = nn.Dropout(config.dropout_rate)
  1347. self.layer_norm_ipa = LayerNorm(config.sequence_dim)
  1348. self.transition = EsmFoldStructureModuleTransition(config)
  1349. self.bb_update = EsmFoldBackboneUpdate(config)
  1350. self.angle_resnet = EsmFoldAngleResnet(config)
  1351. def forward(
  1352. self,
  1353. evoformer_output_dict,
  1354. aatype,
  1355. mask=None,
  1356. _offload_inference=False,
  1357. ):
  1358. """
  1359. Args:
  1360. evoformer_output_dict:
  1361. Dictionary containing:
  1362. "single":
  1363. [*, N_res, C_s] single representation
  1364. "pair":
  1365. [*, N_res, N_res, C_z] pair representation
  1366. aatype:
  1367. [*, N_res] amino acid indices
  1368. mask:
  1369. Optional [*, N_res] sequence mask
  1370. Returns:
  1371. A dictionary of outputs
  1372. """
  1373. s = evoformer_output_dict["single"]
  1374. if mask is None:
  1375. # [*, N]
  1376. mask = s.new_ones(s.shape[:-1])
  1377. # [*, N, C_s]
  1378. s = self.layer_norm_s(s)
  1379. # [*, N, N, C_z]
  1380. z = self.layer_norm_z(evoformer_output_dict["pair"])
  1381. z_reference_list = None
  1382. if _offload_inference:
  1383. assert sys.getrefcount(evoformer_output_dict["pair"]) == 2
  1384. evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
  1385. z_reference_list = [z]
  1386. z = None
  1387. # [*, N, C_s]
  1388. s_initial = s
  1389. s = self.linear_in(s)
  1390. # [*, N]
  1391. rigids = Rigid.identity(
  1392. s.shape[:-1],
  1393. s.dtype,
  1394. s.device,
  1395. self.training,
  1396. fmt="quat",
  1397. )
  1398. outputs = []
  1399. for i in range(self.config.num_blocks):
  1400. # [*, N, C_s]
  1401. s = s + self.ipa(
  1402. s,
  1403. z,
  1404. rigids,
  1405. mask,
  1406. _offload_inference=_offload_inference,
  1407. _z_reference_list=z_reference_list,
  1408. )
  1409. s = self.ipa_dropout(s)
  1410. s = self.layer_norm_ipa(s)
  1411. s = self.transition(s)
  1412. # [*, N]
  1413. rigids = rigids.compose_q_update_vec(self.bb_update(s))
  1414. # To hew as closely as possible to AlphaFold, we convert our
  1415. # quaternion-based transformations to rotation-matrix ones
  1416. # here
  1417. backb_to_global = Rigid(
  1418. Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None),
  1419. rigids.get_trans(),
  1420. )
  1421. backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor)
  1422. # [*, N, 7, 2]
  1423. unnormalized_angles, angles = self.angle_resnet(s, s_initial)
  1424. all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype)
  1425. pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype)
  1426. scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor)
  1427. preds = {
  1428. "frames": scaled_rigids.to_tensor_7(),
  1429. "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
  1430. "unnormalized_angles": unnormalized_angles,
  1431. "angles": angles,
  1432. "positions": pred_xyz,
  1433. "states": s,
  1434. }
  1435. outputs.append(preds)
  1436. rigids = rigids.stop_rot_gradient()
  1437. del z, z_reference_list
  1438. if _offload_inference:
  1439. evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device)
  1440. outputs = dict_multimap(torch.stack, outputs)
  1441. outputs["single"] = s
  1442. return outputs
  1443. def _init_residue_constants(self, float_dtype, device):
  1444. if not hasattr(self, "default_frames"):
  1445. self.register_buffer(
  1446. "default_frames",
  1447. torch.tensor(
  1448. residue_constants.restype_rigid_group_default_frame,
  1449. dtype=float_dtype,
  1450. device=device,
  1451. requires_grad=False,
  1452. ),
  1453. persistent=False,
  1454. )
  1455. if not hasattr(self, "group_idx"):
  1456. self.register_buffer(
  1457. "group_idx",
  1458. torch.tensor(
  1459. residue_constants.restype_atom14_to_rigid_group,
  1460. device=device,
  1461. requires_grad=False,
  1462. ),
  1463. persistent=False,
  1464. )
  1465. if not hasattr(self, "atom_mask"):
  1466. self.register_buffer(
  1467. "atom_mask",
  1468. torch.tensor(
  1469. residue_constants.restype_atom14_mask,
  1470. dtype=float_dtype,
  1471. device=device,
  1472. requires_grad=False,
  1473. ),
  1474. persistent=False,
  1475. )
  1476. if not hasattr(self, "lit_positions"):
  1477. self.register_buffer(
  1478. "lit_positions",
  1479. torch.tensor(
  1480. residue_constants.restype_atom14_rigid_group_positions,
  1481. dtype=float_dtype,
  1482. device=device,
  1483. requires_grad=False,
  1484. ),
  1485. persistent=False,
  1486. )
  1487. def torsion_angles_to_frames(self, r, alpha, f):
  1488. # Lazily initialize the residue constants on the correct device
  1489. self._init_residue_constants(alpha.dtype, alpha.device)
  1490. # Separated purely to make testing less annoying
  1491. return torsion_angles_to_frames(r, alpha, f, self.default_frames)
  1492. def frames_and_literature_positions_to_atom14_pos(self, r, f): # [*, N, 8] # [*, N]
  1493. # Lazily initialize the residue constants on the correct device
  1494. self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
  1495. return frames_and_literature_positions_to_atom14_pos(
  1496. r,
  1497. f,
  1498. self.default_frames,
  1499. self.group_idx,
  1500. self.atom_mask,
  1501. self.lit_positions,
  1502. )
  1503. class EsmFoldingTrunk(nn.Module):
  1504. def __init__(self, config):
  1505. super().__init__()
  1506. self.config = config
  1507. c_s = config.sequence_state_dim
  1508. c_z = config.pairwise_state_dim
  1509. self.pairwise_positional_embedding = EsmFoldRelativePosition(config)
  1510. self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)])
  1511. self.recycle_bins = 15
  1512. self.recycle_s_norm = nn.LayerNorm(c_s)
  1513. self.recycle_z_norm = nn.LayerNorm(c_z)
  1514. self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
  1515. self.recycle_disto.weight[0].detach().zero_()
  1516. self.structure_module = EsmFoldStructureModule(config.structure_module)
  1517. self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim)
  1518. self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim)
  1519. self.chunk_size = config.chunk_size
  1520. def set_chunk_size(self, chunk_size):
  1521. # This parameter means the axial attention will be computed
  1522. # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
  1523. # It's equivalent to running a for loop over chunks of the dimension we're iterative over,
  1524. # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-length chunks.
  1525. self.chunk_size = chunk_size
  1526. def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
  1527. """
  1528. Inputs:
  1529. seq_feats: B x L x C tensor of sequence features pair_feats: B x L x L x C tensor of pair features residx: B
  1530. x L long tensor giving the position in the sequence mask: B x L boolean tensor indicating valid residues
  1531. Output:
  1532. predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object
  1533. """
  1534. device = seq_feats.device
  1535. s_s_0 = seq_feats
  1536. s_z_0 = pair_feats
  1537. if no_recycles is None:
  1538. no_recycles = self.config.max_recycles
  1539. else:
  1540. if no_recycles < 0:
  1541. raise ValueError("Number of recycles must not be negative.")
  1542. no_recycles += 1 # First 'recycle' is just the standard forward pass through the model.
  1543. def trunk_iter(s, z, residx, mask):
  1544. z = z + self.pairwise_positional_embedding(residx, mask=mask)
  1545. for block in self.blocks:
  1546. s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
  1547. return s, z
  1548. s_s = s_s_0
  1549. s_z = s_z_0
  1550. recycle_s = torch.zeros_like(s_s)
  1551. recycle_z = torch.zeros_like(s_z)
  1552. recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)
  1553. for recycle_idx in range(no_recycles):
  1554. with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
  1555. # === Recycling ===
  1556. recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
  1557. recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
  1558. recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)
  1559. s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)
  1560. # === Structure module ===
  1561. structure = self.structure_module(
  1562. {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
  1563. true_aa,
  1564. mask.float(),
  1565. )
  1566. recycle_s = s_s
  1567. recycle_z = s_z
  1568. # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
  1569. recycle_bins = EsmFoldingTrunk.distogram(
  1570. structure["positions"][-1][:, :, :3],
  1571. 3.375,
  1572. 21.375,
  1573. self.recycle_bins,
  1574. )
  1575. structure["s_s"] = s_s
  1576. structure["s_z"] = s_z
  1577. return structure
  1578. @staticmethod
  1579. def distogram(coords, min_bin, max_bin, num_bins):
  1580. # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates.
  1581. boundaries = torch.linspace(
  1582. min_bin,
  1583. max_bin,
  1584. num_bins - 1,
  1585. device=coords.device,
  1586. )
  1587. boundaries = boundaries**2
  1588. N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
  1589. # Infer CB coordinates.
  1590. b = CA - N
  1591. c = C - CA
  1592. a = b.cross(c, dim=-1)
  1593. CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
  1594. dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
  1595. bins = torch.sum(dists > boundaries, dim=-1) # [..., L, L]
  1596. return bins
  1597. # TODO Add information to the docstring about any methods that convert to PDB format, or otherwise prepare
  1598. # the outputs for downstream use.
  1599. @auto_docstring(
  1600. custom_intro="""
  1601. ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed
  1602. by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to
  1603. the rest of the model combined! It outputs a dictionary containing predicted structural information about the input
  1604. protein(s).
  1605. """
  1606. )
  1607. class EsmForProteinFolding(EsmPreTrainedModel):
  1608. _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]
  1609. _supports_flash_attn = False
  1610. _supports_sdpa = False
  1611. _supports_attention_backend = False
  1612. _can_record_outputs = None
  1613. def _init_weights(self, module):
  1614. super()._init_weights(module)
  1615. if isinstance(module, EsmForProteinFolding):
  1616. init.copy_(module.af2_to_esm, module._af2_to_esm_from_vocab_list(module.config.vocab_list))
  1617. def __init__(self, config):
  1618. super().__init__(config)
  1619. self.config = config
  1620. self.distogram_bins = 64
  1621. self.esm = EsmModel(config, add_pooling_layer=False)
  1622. self.esm.requires_grad_(False)
  1623. if self.config.esmfold_config.fp16_esm:
  1624. self.esm.half()
  1625. self.esm_feats = self.config.hidden_size
  1626. self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads
  1627. self.esm_layers = self.config.num_hidden_layers
  1628. self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list))
  1629. self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1))
  1630. trunk_config = self.config.esmfold_config.trunk
  1631. c_s = trunk_config.sequence_state_dim
  1632. c_z = trunk_config.pairwise_state_dim
  1633. self.esm_s_mlp = nn.Sequential(
  1634. LayerNorm(self.esm_feats),
  1635. nn.Linear(self.esm_feats, c_s),
  1636. nn.ReLU(),
  1637. nn.Linear(c_s, c_s),
  1638. )
  1639. # 0 is padding, N is unknown residues, N + 1 is mask.
  1640. self.n_tokens_embed = residue_constants.restype_num + 3
  1641. self.pad_idx = 0
  1642. self.unk_idx = self.n_tokens_embed - 2
  1643. self.mask_idx = self.n_tokens_embed - 1
  1644. self.esm_dict_cls_idx = self.config.vocab_list.index("<cls>")
  1645. self.esm_dict_mask_idx = self.config.vocab_list.index("<mask>")
  1646. self.esm_dict_eos_idx = self.config.vocab_list.index("<eos>")
  1647. self.esm_dict_padding_idx = self.config.vocab_list.index("<pad>")
  1648. if self.config.esmfold_config.embed_aa:
  1649. self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)
  1650. self.trunk = EsmFoldingTrunk(trunk_config)
  1651. self.distogram_head = nn.Linear(c_z, self.distogram_bins)
  1652. self.ptm_head = nn.Linear(c_z, self.distogram_bins)
  1653. self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
  1654. self.lddt_bins = 50
  1655. structure_module_config = trunk_config.structure_module
  1656. self.lddt_head = nn.Sequential(
  1657. nn.LayerNorm(structure_module_config.sequence_dim),
  1658. nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim),
  1659. nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim),
  1660. nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins),
  1661. )
  1662. self.post_init()
  1663. @staticmethod
  1664. def _af2_to_esm_from_vocab_list(vocab_list: list[str]) -> torch.Tensor:
  1665. # Remember that t is shifted from residue_constants by 1 (0 is padding).
  1666. esm_reorder = [vocab_list.index("<pad>")] + [vocab_list.index(v) for v in residue_constants.restypes_with_x]
  1667. return torch.tensor(esm_reorder)
  1668. @auto_docstring
  1669. def forward(
  1670. self,
  1671. input_ids: torch.Tensor,
  1672. attention_mask: torch.Tensor | None = None,
  1673. position_ids: torch.Tensor | None = None,
  1674. masking_pattern: torch.Tensor | None = None,
  1675. num_recycles: int | None = None,
  1676. output_hidden_states: bool | None = False,
  1677. **kwargs,
  1678. ) -> EsmForProteinFoldingOutput:
  1679. r"""
  1680. masking_pattern (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1681. Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`.
  1682. num_recycles (`int`, *optional*, defaults to `None`):
  1683. Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling"
  1684. consists of passing the output of the folding trunk back in as input to the trunk. During training, the
  1685. number of recycles should vary with each batch, to ensure that the model learns to output valid predictions
  1686. after each recycle. During inference, num_recycles should be set to the highest value that the model was
  1687. trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is
  1688. used.
  1689. Example:
  1690. ```python
  1691. >>> from transformers import AutoTokenizer, EsmForProteinFolding
  1692. >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
  1693. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
  1694. >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False) # A tiny random peptide
  1695. >>> outputs = model(**inputs)
  1696. >>> folded_positions = outputs.positions
  1697. ```
  1698. """
  1699. cfg = self.config.esmfold_config
  1700. aa = input_ids # B x L
  1701. B = aa.shape[0]
  1702. L = aa.shape[1]
  1703. device = input_ids.device
  1704. if attention_mask is None:
  1705. attention_mask = torch.ones_like(aa, device=device)
  1706. if position_ids is None:
  1707. position_ids = torch.arange(L, device=device).expand_as(input_ids)
  1708. # === ESM ===
  1709. esmaa = self.af2_idx_to_esm_idx(aa, attention_mask)
  1710. if masking_pattern is not None:
  1711. masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern)
  1712. else:
  1713. masked_aa = aa
  1714. mlm_targets = None
  1715. # We get sequence and pair representations from whatever version of ESM /
  1716. # configuration we are using. The sequence representation esm_s is always
  1717. # present. The pair embedding esm_z may be present depending on the
  1718. # configuration of the model. If esm_z is not used by the model then it
  1719. # is returned as None here.
  1720. esm_s = self.compute_language_model_representations(esmaa)
  1721. # Convert esm_s and esm_z, if present, to the precision used by the trunk and
  1722. # the structure module. These tensors may be a lower precision if, for example,
  1723. # we're running the language model in fp16 precision.
  1724. esm_s = esm_s.to(self.esm_s_combine.dtype)
  1725. if cfg.esm_ablate_sequence:
  1726. esm_s = esm_s * 0
  1727. esm_s = esm_s.detach()
  1728. # === preprocessing ===
  1729. esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)
  1730. s_s_0 = self.esm_s_mlp(esm_s)
  1731. s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim)
  1732. if self.config.esmfold_config.embed_aa:
  1733. s_s_0 += self.embedding(masked_aa)
  1734. structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles)
  1735. # Documenting what we expect:
  1736. structure = {
  1737. k: v
  1738. for k, v in structure.items()
  1739. if k
  1740. in [
  1741. "s_z",
  1742. "s_s",
  1743. "frames",
  1744. "sidechain_frames",
  1745. "unnormalized_angles",
  1746. "angles",
  1747. "positions",
  1748. "states",
  1749. ]
  1750. }
  1751. # Add BERT mask for the loss to use, if available.
  1752. if mlm_targets:
  1753. structure["mlm_targets"] = mlm_targets
  1754. disto_logits = self.distogram_head(structure["s_z"])
  1755. disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
  1756. structure["distogram_logits"] = disto_logits
  1757. lm_logits = self.lm_head(structure["s_s"])
  1758. structure["lm_logits"] = lm_logits
  1759. structure["aatype"] = aa
  1760. make_atom14_masks(structure)
  1761. # Of course, this doesn't respect the true mask because it doesn't know about it...
  1762. # We're not going to properly mask change of index tensors:
  1763. # "residx_atom14_to_atom37",
  1764. # "residx_atom37_to_atom14",
  1765. for k in [
  1766. "atom14_atom_exists",
  1767. "atom37_atom_exists",
  1768. ]:
  1769. structure[k] *= attention_mask.unsqueeze(-1)
  1770. structure["residue_index"] = position_ids
  1771. lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins)
  1772. structure["lddt_head"] = lddt_head
  1773. plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
  1774. structure["plddt"] = plddt
  1775. ptm_logits = self.ptm_head(structure["s_z"])
  1776. structure["ptm_logits"] = ptm_logits
  1777. structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins)
  1778. structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins))
  1779. return EsmForProteinFoldingOutput(**structure)
  1780. def af2_idx_to_esm_idx(self, aa, mask):
  1781. # avoid indexing on different devices
  1782. if self.af2_to_esm.device != aa.device:
  1783. self.af2_to_esm = self.af2_to_esm.to(aa.device)
  1784. aa = (aa + 1).masked_fill(mask != 1, 0)
  1785. return self.af2_to_esm[aa]
  1786. def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor:
  1787. device = next(self.parameters()).device
  1788. B, L = esmaa.shape # B = batch size, L = sequence length.
  1789. if self.config.esmfold_config.bypass_lm:
  1790. esm_s = torch.zeros(B, L, self.esm_s_combine.size[0], -1, self.esm_feats, device=device)
  1791. return esm_s
  1792. bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx
  1793. bos = esmaa.new_full((B, 1), bosi)
  1794. eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx)
  1795. esmaa = torch.cat([bos, esmaa, eos], dim=1)
  1796. # Use the first padding index as eos during inference.
  1797. esmaa[range(B), (esmaa != 1).sum(1)] = eosi
  1798. # _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map)
  1799. # Because we do not support use_esm_attn_map in the HF port as it is not used in any public models,
  1800. # esm_z is always None
  1801. esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"]
  1802. esm_s = torch.stack(esm_hidden_states, dim=2)
  1803. esm_s = esm_s[:, 1:-1] # B, L, nLayers, C
  1804. return esm_s
  1805. def bert_mask(self, aa, esmaa, mask, pattern):
  1806. new_aa = aa.clone()
  1807. target = aa.clone()
  1808. new_esmaa = esmaa.clone()
  1809. new_aa[pattern == 1] = self.mask_idx
  1810. target[pattern != 1] = 0
  1811. new_esmaa[pattern == 1] = self.esm_dict_mask_idx
  1812. return new_aa, new_esmaa, target
  1813. @torch.no_grad()
  1814. def infer(
  1815. self,
  1816. seqs: str | list[str],
  1817. position_ids=None,
  1818. ):
  1819. if isinstance(seqs, str):
  1820. lst = [seqs]
  1821. else:
  1822. lst = seqs
  1823. # Returns the raw outputs of the model given an input sequence.
  1824. device = next(self.parameters()).device
  1825. aatype = collate_dense_tensors(
  1826. [
  1827. torch.from_numpy(
  1828. residue_constants.sequence_to_onehot(
  1829. sequence=seq,
  1830. mapping=residue_constants.restype_order_with_x,
  1831. map_unknown_to_x=True,
  1832. )
  1833. )
  1834. .to(device)
  1835. .argmax(dim=1)
  1836. for seq in lst
  1837. ]
  1838. ) # B=1 x L
  1839. mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
  1840. position_ids = (
  1841. torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
  1842. if position_ids is None
  1843. else position_ids.to(device)
  1844. )
  1845. if position_ids.ndim == 1:
  1846. position_ids = position_ids.unsqueeze(0)
  1847. return self.forward(
  1848. aatype,
  1849. mask,
  1850. position_ids=position_ids,
  1851. )
  1852. @staticmethod
  1853. def output_to_pdb(output: dict) -> list[str]:
  1854. """Returns the pdb (file) string from the model given the model output."""
  1855. output = {k: v.to("cpu").numpy() for k, v in output.items()}
  1856. pdbs = []
  1857. final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
  1858. final_atom_mask = output["atom37_atom_exists"]
  1859. for i in range(output["aatype"].shape[0]):
  1860. aa = output["aatype"][i]
  1861. pred_pos = final_atom_positions[i]
  1862. mask = final_atom_mask[i]
  1863. resid = output["residue_index"][i] + 1
  1864. pred = OFProtein(
  1865. aatype=aa,
  1866. atom_positions=pred_pos,
  1867. atom_mask=mask,
  1868. residue_index=resid,
  1869. b_factors=output["plddt"][i],
  1870. )
  1871. pdbs.append(to_pdb(pred))
  1872. return pdbs
  1873. def infer_pdb(self, seqs, *args, **kwargs) -> str:
  1874. """Returns the pdb (file) string from the model given an input sequence."""
  1875. assert isinstance(seqs, str)
  1876. output = self.infer(seqs, *args, **kwargs)
  1877. return self.output_to_pdb(output)[0]
  1878. def infer_pdbs(self, seqs: list[str], *args, **kwargs) -> list[str]:
  1879. """Returns the pdb (file) string from the model given an input sequence."""
  1880. output = self.infer(seqs, *args, **kwargs)
  1881. return self.output_to_pdb(output)
# Public names exported by this module via `from ... import *`.
__all__ = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"]