| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- # Copyright 2025 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Processor class for EVOLLA.
- """
- from ...feature_extraction_utils import BatchFeature
- from ...processing_utils import (
- ProcessorMixin,
- )
- from ...utils import auto_docstring
# Keys a protein dictionary handed to `EvollaProcessor` is allowed to contain:
#   "aa_seq"   -- the amino acid sequence
#   "foldseek" -- the foldseek string
#   "msa"      -- accepted by validation, but not read when the protein tokens are built
PROTEIN_VALID_KEYS = [
    "aa_seq",
    "foldseek",
    "msa",
]
@auto_docstring
class EvollaProcessor(ProcessorMixin):
    # Combines a protein tokenizer (fed interleaved amino-acid / foldseek strings) with a chat-template
    # text tokenizer, and returns both token streams in a single BatchFeature.

    def __init__(self, protein_tokenizer, tokenizer=None, protein_max_length=1024, text_max_length=512, **kwargs):
        r"""
        protein_tokenizer (`EsmTokenizer`):
            An instance of [`EsmTokenizer`]. The protein tokenizer is a required input.
        protein_max_length (`int`, *optional*, defaults to 1024):
            The maximum number of protein tokens kept per sequence; longer protein inputs are truncated.
        text_max_length (`int`, *optional*, defaults to 512):
            The maximum number of text tokens kept per prompt; longer prompts are truncated.
        """
        # `tokenizer` has a None default in the signature, but it is still mandatory.
        if protein_tokenizer is None:
            raise ValueError("You need to specify a `protein_tokenizer`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        # After this call the tokenizers are reachable as `self.protein_tokenizer` and `self.tokenizer`.
        super().__init__(protein_tokenizer, tokenizer)
        # Batched prompts are padded (see `process_text`), so the text tokenizer needs a pad token.
        # NOTE(review): assumes this reserved token exists in the text tokenizer's vocabulary -- confirm.
        self.tokenizer.pad_token = "<|reserved_special_token_0|>"
        self.protein_max_length = protein_max_length
        self.text_max_length = text_max_length

    def process_proteins(self, proteins, protein_max_length: int | None = None):
        """
        Tokenize a batch of proteins as structure-aware sequences.

        Each residue is interleaved with its foldseek character: the amino acid upper-cased followed by the
        foldseek character lower-cased (e.g. aa_seq "MK" with foldseek "dv" becomes "MdKv").

        Args:
            proteins (`list[dict]`):
                Protein dictionaries, each providing `"aa_seq"` and `"foldseek"` strings of equal length.
            protein_max_length (`int`, *optional*):
                Truncation length for the protein tokenizer. Defaults to `self.protein_max_length`.

        Returns:
            The protein tokenizer output (including `"input_ids"` and `"attention_mask"`) as PyTorch tensors,
            padded with `padding=True`.

        Raises:
            ValueError: If a protein lacks `"aa_seq"` or `"foldseek"`, or the two strings differ in length.
        """
        if protein_max_length is None:
            protein_max_length = self.protein_max_length

        sa_sequences = []
        for index, protein in enumerate(proteins):
            aa_seq = protein.get("aa_seq")
            foldseek = protein.get("foldseek")
            if aa_seq is None or foldseek is None:
                raise ValueError(f"Protein at index {index} must provide both 'aa_seq' and 'foldseek'.")
            # zip() would silently drop trailing residues on a length mismatch, so reject it explicitly.
            if len(aa_seq) != len(foldseek):
                raise ValueError(
                    f"Protein at index {index} has an 'aa_seq' of length {len(aa_seq)} but a 'foldseek' "
                    f"of length {len(foldseek)}; they must have the same length."
                )
            sa_sequences.append("".join(s.upper() + f.lower() for s, f in zip(aa_seq, foldseek)))

        return self.protein_tokenizer(
            sa_sequences, return_tensors="pt", truncation=True, max_length=protein_max_length, padding=True
        )

    def process_text(
        self,
        texts,
        text_max_length: int | None = None,
    ):
        """
        Render each conversation with the chat template and tokenize the resulting prompts.

        Args:
            texts (`list[list[dict]]`):
                One conversation per batch entry; each conversation is a list of `{"role", "content"}` messages.
            text_max_length (`int`, *optional*):
                Truncation length for the text tokenizer. Defaults to `self.text_max_length`.

        Returns:
            The text tokenizer output (including `"input_ids"` and `"attention_mask"`) as PyTorch tensors,
            padded to the longest prompt in the batch.
        """
        if text_max_length is None:
            text_max_length = self.text_max_length

        # add_generation_prompt=True appends the prompt that opens the model's reply turn.
        prompts = [
            self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            for messages in texts
        ]
        # Presumably the rendered template already contains the special tokens, hence add_special_tokens=False.
        return self.tokenizer(
            prompts,
            add_special_tokens=False,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=text_max_length,
        )

    @staticmethod
    def _validate_proteins(proteins):
        """Normalize `proteins` to a non-empty list/tuple of dicts with valid keys and return it."""
        # A single protein dict is wrapped into a batch of one.
        if isinstance(proteins, dict):
            proteins = [proteins]
        if not isinstance(proteins, (list, tuple)):
            raise ValueError(
                f"The proteins should be a dictionary or a list of dictionaries, but it's {type(proteins)}."
            )
        if len(proteins) == 0:
            raise ValueError("The proteins should not be empty.")
        if not all(isinstance(p, dict) for p in proteins):
            raise ValueError("The proteins should be a list of dictionaries, but not all elements are dictionaries.")

        required_keys = {"aa_seq", "foldseek"}
        for protein in proteins:
            keys = set(protein)
            # Unknown keys are rejected, and the two keys `process_proteins` reads must be present.
            if not keys.issubset(PROTEIN_VALID_KEYS) or not required_keys.issubset(keys):
                raise ValueError(
                    "There should be a list of dictionaries with keys: "
                    f"{', '.join(PROTEIN_VALID_KEYS)} for each protein "
                    "('aa_seq' and 'foldseek' are required). "
                    f"But got: {proteins}"
                )
        return proteins

    @staticmethod
    def _validate_messages_list(messages_list):
        """Normalize `messages_list` to a non-empty batch of conversations, validate it and return it."""
        if not isinstance(messages_list, (list, tuple)):
            raise ValueError(
                f"The messages_list should be a list of lists of dictionaries, but it's {type(messages_list)}."
            )
        if len(messages_list) == 0:
            raise ValueError("The messages_list should not be empty.")
        # A single conversation (a flat list of message dicts) is wrapped into a batch of one.
        if not isinstance(messages_list[0], (list, tuple)):
            messages_list = [messages_list]

        for messages in messages_list:
            if not isinstance(messages, (list, tuple)):
                raise TypeError(f"Each messages in messages_list should be a list instead of {type(messages)}.")
            if not all(isinstance(m, dict) for m in messages):
                raise ValueError(
                    "Each message in messages_list should be a list of dictionaries, but not all elements are dictionaries."
                )
            # Dict keys are unique, so set equality also guarantees exactly two keys.
            if any(set(m) != {"role", "content"} for m in messages):
                raise ValueError(
                    "Each message in messages_list should be a list of dictionaries with two keys: 'role' and 'content'. "
                    f"But got: {messages}"
                )
        return messages_list

    @auto_docstring
    def __call__(
        self,
        proteins: list[dict] | dict | None = None,
        messages_list: list[list[dict]] | list[dict] | None = None,
        protein_max_length: int | None = None,
        text_max_length: int | None = None,
        **kwargs,
    ):
        r"""
        proteins (`Union[List[dict], dict]`):
            A list of dictionaries or a single dictionary containing the following keys:
            - `"aa_seq"` (`str`) -- The amino acid sequence of the protein.
            - `"foldseek"` (`str`) -- The foldseek string of the protein, one character per amino acid.
        messages_list (`Union[List[List[dict]], List[dict]]`):
            A list of lists of dictionaries or a list of dictionaries containing the following keys:
            - `"role"` (`str`) -- The role of the message.
            - `"content"` (`str`) -- The content of the message.
            Exactly one conversation must be given per protein.
        protein_max_length (`int`, *optional*):
            The maximum number of protein tokens per sequence. Defaults to the value set at initialization.
        text_max_length (`int`, *optional*):
            The maximum number of text tokens per prompt. Defaults to the value set at initialization.
        Return:
            a dict with following keys:
            - `protein_input_ids` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The input IDs for the protein sequence.
            - `protein_attention_mask` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The attention mask for the protein sequence.
            - `input_ids` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The input IDs for the text sequence.
            - `attention_mask` (`torch.Tensor` of shape `(batch_size, sequence_length)`) -- The attention mask for the text sequence.
        """
        # NOTE(review): extra **kwargs are accepted but currently ignored.
        if proteins is None or messages_list is None:
            raise ValueError("You need to specify `messages_list` and `proteins`.")

        protein_max_length = protein_max_length if protein_max_length is not None else self.protein_max_length
        text_max_length = text_max_length if text_max_length is not None else self.text_max_length

        proteins = self._validate_proteins(proteins)
        messages_list = self._validate_messages_list(messages_list)
        # Both token streams share one batch dimension in the returned BatchFeature.
        if len(proteins) != len(messages_list):
            raise ValueError(
                f"Got {len(proteins)} proteins but {len(messages_list)} conversations in `messages_list`; "
                "there must be exactly one conversation per protein."
            )

        sa_tokens = self.process_proteins(proteins, protein_max_length)
        text_tokens = self.process_text(messages_list, text_max_length)
        return BatchFeature(
            data={
                "protein_input_ids": sa_tokens["input_ids"],
                "protein_attention_mask": sa_tokens["attention_mask"],
                "input_ids": text_tokens["input_ids"],
                "attention_mask": text_tokens["attention_mask"],
            }
        )

    def batch_decode(self, *args, **kwargs):
        """Forward all arguments to the text tokenizer's `batch_decode`."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward all arguments to the text tokenizer's `decode`."""
        return self.tokenizer.decode(*args, **kwargs)

    def protein_batch_decode(self, *args, **kwargs):
        """Forward all arguments to the protein tokenizer's `batch_decode`."""
        return self.protein_tokenizer.batch_decode(*args, **kwargs)

    def protein_decode(self, *args, **kwargs):
        """Forward all arguments to the protein tokenizer's `decode`."""
        return self.protein_tokenizer.decode(*args, **kwargs)


__all__ = ["EvollaProcessor"]
|