processing_evolla.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # Copyright 2025 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Processor class for EVOLLA.
  16. """
  17. from ...feature_extraction_utils import BatchFeature
  18. from ...processing_utils import (
  19. ProcessorMixin,
  20. )
  21. from ...utils import auto_docstring
# Keys a protein dict may contain when it is passed to `EvollaProcessor.__call__`; a dict with
# any other key is rejected there. Within this file only "aa_seq" and "foldseek" are actually
# read (by `EvollaProcessor.process_proteins`); "msa" is accepted but not consumed here.
PROTEIN_VALID_KEYS = ["aa_seq", "foldseek", "msa"]
  23. @auto_docstring
  24. class EvollaProcessor(ProcessorMixin):
  25. def __init__(self, protein_tokenizer, tokenizer=None, protein_max_length=1024, text_max_length=512, **kwargs):
  26. r"""
  27. protein_tokenizer (`EsmTokenizer`):
  28. An instance of [`EsmTokenizer`]. The protein tokenizer is a required input.
  29. protein_max_length (`int`, *optional*, defaults to 1024):
  30. The maximum length of the sequence to be generated.
  31. text_max_length (`int`, *optional*, defaults to 512):
  32. The maximum length of the text to be generated.
  33. """
  34. if protein_tokenizer is None:
  35. raise ValueError("You need to specify an `protein_tokenizer`.")
  36. if tokenizer is None:
  37. raise ValueError("You need to specify a `tokenizer`.")
  38. super().__init__(protein_tokenizer, tokenizer)
  39. self.tokenizer.pad_token = "<|reserved_special_token_0|>"
  40. self.protein_max_length = protein_max_length
  41. self.text_max_length = text_max_length
  42. def process_proteins(self, proteins, protein_max_length=1024):
  43. sa_sequences = []
  44. for protein in proteins:
  45. aa_seq = protein.get("aa_seq")
  46. foldseek = protein.get("foldseek")
  47. sa_sequence = "".join([s.upper() + f.lower() for s, f in zip(aa_seq, foldseek)])
  48. sa_sequences.append(sa_sequence)
  49. sa_tokens = self.protein_tokenizer(
  50. sa_sequences, return_tensors="pt", truncation=True, max_length=protein_max_length, padding=True
  51. )
  52. return sa_tokens
  53. def process_text(
  54. self,
  55. texts,
  56. text_max_length: int = 512,
  57. ):
  58. prompts = []
  59. for messages in texts:
  60. prompt = self.tokenizer.apply_chat_template(
  61. messages,
  62. tokenize=False,
  63. add_generation_prompt=True,
  64. )
  65. prompts.append(prompt)
  66. prompt_inputs = self.tokenizer(
  67. prompts,
  68. add_special_tokens=False,
  69. return_tensors="pt",
  70. padding="longest",
  71. truncation=True,
  72. max_length=text_max_length,
  73. )
  74. return prompt_inputs
  75. @auto_docstring
  76. def __call__(
  77. self,
  78. proteins: list[dict] | dict | None = None,
  79. messages_list: list[list[dict]] | list[dict] | None = None,
  80. protein_max_length: int | None = None,
  81. text_max_length: int | None = None,
  82. **kwargs,
  83. ):
  84. r"""
  85. proteins (`Union[List[dict], dict]`):
  86. A list of dictionaries or a single dictionary containing the following keys:
  87. - `"aa_seq"` (`str`) -- The amino acid sequence of the protein.
  88. - `"foldseek"` (`str`) -- The foldseek string of the protein.
  89. messages_list (`Union[List[List[dict]], List[dict]]`):
  90. A list of lists of dictionaries or a list of dictionaries containing the following keys:
  91. - `"role"` (`str`) -- The role of the message.
  92. - `"content"` (`str`) -- The content of the message.
  93. protein_max_length (`int`, *optional*, defaults to 1024):
  94. The maximum length of the sequence to be generated.
  95. text_max_length (`int`, *optional*, defaults to 512):
  96. The maximum length of the text.
  97. Return:
  98. a dict with following keys:
  99. - `protein_input_ids` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The input IDs for the protein sequence.
  100. - `protein_attention_mask` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The attention mask for the protein sequence.
  101. - `text_input_ids` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The input IDs for the text sequence.
  102. - `text_attention_mask` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The attention mask for the text sequence.
  103. """
  104. # proteins and messages_list should be provided
  105. if proteins is None or messages_list is None:
  106. raise ValueError("You need to specify `messages_list` and `proteins`.")
  107. protein_max_length = protein_max_length if protein_max_length is not None else self.protein_max_length
  108. text_max_length = text_max_length if text_max_length is not None else self.text_max_length
  109. # proteins should be List[dict]
  110. if isinstance(proteins, dict):
  111. proteins = [proteins]
  112. # messages_list should be List[List[dict]]
  113. if isinstance(messages_list, (list, tuple)) and not isinstance(messages_list[0], (list, tuple)):
  114. messages_list = [messages_list]
  115. # Check if batched proteins are in the correct format
  116. if isinstance(proteins, (list, tuple)) and not all(isinstance(p, dict) for p in proteins):
  117. raise ValueError("The proteins should be a list of dictionaries, but not all elements are dictionaries.")
  118. if isinstance(proteins, (list, tuple)) and not all(
  119. all(k in PROTEIN_VALID_KEYS for k in p.keys()) for p in proteins
  120. ):
  121. raise ValueError(
  122. "There should be a list of dictionaries with keys: "
  123. f"{', '.join(PROTEIN_VALID_KEYS)} for each protein."
  124. f"But got: {proteins}"
  125. )
  126. # Check if batched messages_list is in the correct format
  127. if isinstance(messages_list, (list, tuple)):
  128. for messages in messages_list:
  129. if not isinstance(messages, (list, tuple)):
  130. raise TypeError(f"Each messages in messages_list should be a list instead of {type(messages)}.")
  131. if not all(isinstance(m, dict) for m in messages):
  132. raise ValueError(
  133. "Each message in messages_list should be a list of dictionaries, but not all elements are dictionaries."
  134. )
  135. if any(len(m.keys()) != 2 for m in messages) or any(
  136. set(m.keys()) != {"role", "content"} for m in messages
  137. ):
  138. raise ValueError(
  139. "Each message in messages_list should be a list of dictionaries with two keys: 'role' and 'content'."
  140. f"But got: {messages}"
  141. )
  142. else:
  143. raise ValueError(
  144. f"The messages_list should be a list of lists of dictionaries, but it's {type(messages_list)}."
  145. )
  146. sa_tokens = self.process_proteins(proteins, protein_max_length)
  147. text_tokens = self.process_text(messages_list, text_max_length)
  148. return BatchFeature(
  149. data={
  150. "protein_input_ids": sa_tokens["input_ids"],
  151. "protein_attention_mask": sa_tokens["attention_mask"],
  152. "input_ids": text_tokens["input_ids"],
  153. "attention_mask": text_tokens["attention_mask"],
  154. }
  155. )
  156. def batch_decode(self, *args, **kwargs):
  157. return self.tokenizer.batch_decode(*args, **kwargs)
  158. def decode(self, *args, **kwargs):
  159. return self.tokenizer.decode(*args, **kwargs)
  160. def protein_batch_decode(self, *args, **kwargs):
  161. return self.protein_tokenizer.batch_decode(*args, **kwargs)
  162. def protein_decode(self, *args, **kwargs):
  163. return self.protein_tokenizer.decode(*args, **kwargs)
# Public names of this module: only the processor class is exported.
__all__ = ["EvollaProcessor"]