| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518 |
- import time
- import warnings
- from abc import ABC
- from collections import OrderedDict
- from copy import deepcopy
- import numpy as np
- import torch
- from torch.nn import functional as F
- from ..tokenization_utils_base import PreTrainedTokenizerBase
- from ..utils import add_start_docstrings, logging
# Module-level logger from the transformers logging utilities
logger = logging.get_logger(__name__)

# We maintain a module-level cache of the embedding vectors for the stop string criterion
# because they are slow to compute. StopStringCriteria.clean_and_embed_tokens_with_cache keys it on
# (vocab tokens, vocab ids, stop strings) and uses it as a small LRU cache: hits are moved to the end,
# and the oldest entry is evicted once more than 8 entries are stored.
STOP_STRING_EMBEDDING_CACHE = OrderedDict()
# Shared docstring attached to every criterion's `__call__` through `add_start_docstrings`.
# Every tensor-returning criterion produces one boolean per batch row, i.e. a tensor of shape `(batch_size,)`.
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
            or scores for each vocabulary token after SoftMax. If this stopping criteria depends on the `scores` input,
            make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`.
        kwargs (`dict[str, Any]`, *optional*):
            Additional stopping criteria specific kwargs.

    Return:
        `torch.BoolTensor` of shape `(batch_size,)`:
            `True` indicates we stop generation for a particular row.
            `False` indicates we should continue.

"""
class StoppingCriteria(ABC):
    """Base class that every stopping criterion used during generation derives from.

    Concrete criteria override `__call__`, which receives the current `input_ids` and `scores` and decides, row by
    row, whether generation should stop. If your stopping criteria depends on the `scores` input, make sure you pass
    `return_dict_in_generate=True, output_scores=True` to `generate`.
    """

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # The base class carries no stopping logic of its own; subclasses must provide it.
        message = "StoppingCriteria needs to be subclassed"
        raise NotImplementedError(message)
class MaxLengthCriteria(StoppingCriteria):
    """
    Stops generation once the total number of tokens in the sequence reaches `max_length`. For decoder-only
    transformers, keep in mind that this count includes the tokens of the initial prompt.

    Args:
        max_length (`int`):
            The maximum length that the output sequence can have in number of tokens.
        max_position_embeddings (`int`, *optional*):
            The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
    """

    def __init__(self, max_length: int, max_position_embeddings: int | None = None):
        self.max_length = max_length
        self.max_position_embeddings = max_position_embeddings

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        batch_size, cur_len = input_ids.shape[0], input_ids.shape[1]
        is_done = cur_len >= self.max_length

        # Only warn when generation will keep going past the model's positional range.
        exceeds_model_limit = self.max_position_embeddings is not None and cur_len > self.max_position_embeddings
        if exceeds_model_limit and not is_done:
            logger.warning_once(
                "This is a friendly reminder - the current text generation call has exceeded the model's predefined "
                f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
                "exceptions, performance degradation, or nothing at all."
            )

        # The length check is identical for every row, so all rows share the same verdict.
        return torch.full((batch_size,), is_done, device=input_ids.device, dtype=torch.bool)
class MaxTimeCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
    time will start being counted when you initialize this class. You can override this by passing an
    `initial_timestamp`.

    Args:
        max_time (`float`):
            The maximum allowed time in seconds for the generation.
        initial_timestamp (`float`, *optional*, defaults to `time.time()`):
            The start of the generation allowed time, as a `time.time()` timestamp in seconds.
    """

    def __init__(self, max_time: float, initial_timestamp: float | None = None):
        self.max_time = max_time
        # Default to "now", so the time budget starts when the criterion is created.
        self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # Elapsed wall-clock time is a property of the whole call, so every row gets the same verdict.
        is_done = time.time() - self.initial_timestamp > self.max_time
        return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
class StopStringCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever specific string sequences are generated. It preprocesses
    the strings together with the tokenizer vocab to find positions where tokens can validly complete the stop strings.

    Generation is stopped as soon as a token is generated that completes any of the stop strings.
    We want to catch any instance in which the stop string would be present in the decoded output, which means
    we must also catch cases with "overhangs" off one or both ends. To make this more concrete, for the stop string
    "stop", any of the following token sequences would trigger the match:

    - ["st", "op"]
    - ["stop"]
    - ["st", "opera"]
    - ["sto", "pper"]
    - ["las", "topper"]
    - ["s", "to", "pped"]

    Note that a match will only be triggered if the stop string is at the end of the generated sequence. In other
    words, these sequences will not trigger a match:

    - ["stop", "at"]
    - ["st", "op", "at"]
    - ["st", "opera", "tion"]

    The reason these are not a match is that the stop string does not overlap with the final token. If you can remove
    one or more tokens from the end of the sequence without destroying the stop string, then this criterion will not
    match that stop string. This is by design; because this check is run after each token is generated, we can't miss a
    valid stop string if one is generated, but we don't want to halt generation just because the stop string exists
    somewhere in the past input_ids.

    How is the match actually performed, though? We do it in quite a confusing way, because we want the entire match
    process to be compilable with Torch or XLA, which means we cannot use standard string methods. However, it is possible,
    with some work, to do string matching with pure tensor operations. We'll begin by describing the algorithm we use
    with standard string operations, and then at the end we'll explain how this is converted to pure tensor operations.

    The key to the algorithm is an observation: Because the stop string must overlap with the end of the token sequence, we can start at
    the end of the sequence and work backwards. Specifically, we check that there is an overlap between the start of
    the final token and the end of the stop_string, or to put it another way, stop_string[-i:] == token[:i] for
    some i > 0. If you look at the positive examples above, you'll see the last token in all of them fulfills this
    property:

    - ["st", "op"] (overlap is "op", overlap length == 2)
    - ["stop"] (overlap is "stop", overlap length == 4)
    - ["st", "opera"] (overlap is "op", overlap length == 2)
    - ["sto", "pper"] (overlap is "p", overlap length == 1)
    - ["las", "topper"] (overlap is "top", overlap length == 3)
    - ["s", "to", "pped"] (overlap is "p", overlap length == 1)

    It's impossible to construct a matching sequence that does not have this property (feel free to verify this
    yourself). However, although this overlap between the start of the final token and the end of the stop string is
    necessary for a match, it is not sufficient. We also need to check that the rest of the token sequence is
    consistent with the stop string.

    How do we do that? Let's use ["s", "to", "pped"] as an example. We know that the final token, "pped", has an
    overlap of 1 with the stop string, "stop". We then go back to the previous token, "to". Since we have already
    matched 1 character from the stop string, the remainder to check is "sto". We check that the next token "to"
    matches the end of the remainder, which it does. We have now matched 3 characters from the stop string, and the
    remainder to match is "s". We go back to the previous token again, which is also "s". This is a match, and so
    we have matched the entire stop string.

    How does it work when the tokens run off the start of the stop string, though? Let's consider the example of
    ["las", "topper"]. The final token, "topper", has an overlap of 3 with the stop string, "stop". Therefore,
    the remaining stop string to match is "s". We go back to the previous token, "las". Because the remainder to
    match is just "s", with length 1, we consider only the final 1 character from the token, which is "s". This
    matches the stop string, and so the entire string is matched.

    How do we compute these matches with tensor operations, though? Simply: we efficiently precompute the necessary
    information for all tokens! For every token, we compute:

    - Its overlap with the end of the stop string, if any
    - The positions inside the stop string where the token matches, including matches that run off the start.
    - The total length of the token

    For example, for the token "pped", we would compute an end overlap of 1, no internal matching positions,
    and a length of 4. For the token "to", we would compute no end overlap, a single internal matching position
    of 1 (counting from the end), and a length of 2. For the token "s", we would compute no end overlap,
    a single internal matching position of 3 (again counting from the end) and a length of 1.

    As long as we have this information, we can execute the algorithm above without any string comparison
    operations. We simply perform the following steps:

    - Check if the final token has an end-overlap with the stop string
    - Continue backwards, keeping track of how much of the stop string we've matched so far
    - At each point, check if the next token has the current position as one of its valid positions
    - Continue until either a match fails, or we completely match the whole stop string

    Again, consider ["s", "to", "pped"] as an example. "pped" has an end overlap of 1, so we can begin a match.
    We have matched 1 character so far, so we check that the next token "to" has 1 as a valid position (again,
    counting from the end). It does, so we add the length of "to" to our position tracker. We have now matched
    3 characters, so we check that the next token "s" has 3 as a valid position. It does, so we add its length
    to the position tracker. The position tracker is now 4, which is the length of the stop string. We have matched the
    entire stop string.

    In the second case, ["las", "topper"], "topper" has an end overlap of 3, so we can begin a match. We have
    matched 3 characters so far, so we check that the next token "las" has 3 as a valid position. It does, because we
    allow tokens to match positions that run off the start of the stop string. We add its length to the position
    tracker. The position tracker is now 6, which is greater than the length of the stop string! Don't panic, though -
    this also counts as a match of the stop string. We have matched the entire stop string.


    Args:
        tokenizer (`PreTrainedTokenizer`):
            The model's associated tokenizer (necessary to extract vocab and tokenize the termination sequences)
        stop_strings (`Union[str, list[str]]`):
            A list of strings that should end generation. If a string is passed, it will be treated like a
            list with a single element.

    Examples:

    ```python
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer

    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
    >>> model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
    >>> inputs = tokenizer("The biggest states in the USA by land area:", return_tensors="pt")

    >>> gen_out = model.generate(**inputs)
    >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
    The biggest states in the USA by land area:
    - Alaska
    - Texas
    - California

    >>> # Passing one or more stop strings will halt generation after those strings are emitted
    >>> # Note that generating with stop strings requires you to pass the tokenizer too
    >>> gen_out = model.generate(**inputs, stop_strings=["Texas"], tokenizer=tokenizer)
    >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
    The biggest states in the USA by land area:
    - Alaska
    - Texas
    ```
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: str | list[str]):
        if isinstance(stop_strings, str):
            stop_strings = [stop_strings]
        self.stop_strings: tuple[str, ...] = tuple(stop_strings)
        vocab = tokenizer.get_vocab()
        token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values())
        self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache(
            token_list, token_indices, tokenizer
        )

        # __call__ only looks at the last `maximum_token_len` tokens (assumes >= 1 character per token)
        self.maximum_token_len = max(len(stop_string) for stop_string in self.stop_strings)
        self.num_stop_strings = len(self.stop_strings)
        # Use the normalized tuple: `stop_strings` may be a one-shot iterable already consumed by tuple() above
        self.target_lens = torch.tensor([len(stop_string) for stop_string in self.stop_strings], dtype=torch.int32)

    def clean_and_embed_tokens_with_cache(self, token_list, token_indices, tokenizer):
        """Return `(embedding_vec, max_valid_positions, max_valid_end_lens)` for this vocab and these stop strings.

        Results are memoized in the module-level `STOP_STRING_EMBEDDING_CACHE`, which keeps at most 8 entries and
        evicts the least recently used one.
        """
        # We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality
        cache_key = (token_list, token_indices, self.stop_strings)
        if cache_key in STOP_STRING_EMBEDDING_CACHE:
            embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[cache_key]
            # Mark the entry as most recently used
            STOP_STRING_EMBEDDING_CACHE.move_to_end(cache_key)
        else:
            clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer)
            embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec(
                clean_token_list, clean_token_indices, self.stop_strings
            )
            STOP_STRING_EMBEDDING_CACHE[cache_key] = (
                embedding_vec,
                max_valid_positions,
                max_valid_end_lens,
            )
            if len(STOP_STRING_EMBEDDING_CACHE) > 8:
                STOP_STRING_EMBEDDING_CACHE.popitem(last=False)  # Pop from the start, the least recently used item
        return embedding_vec, max_valid_positions, max_valid_end_lens

    @staticmethod
    def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"):
        """
        This method turns a tokenizer vocab into a "clean" vocab where each token represents the actual string
        it will yield, without any special prefixes like "##" or "Ġ". This is trickier than it looks - the method
        tokenizer.convert_tokens_to_string() does not always return the correct string because of issues with prefix
        space addition/removal. To work around this, we add a static prefix to the start of the token, then remove
        it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string().

        Returns:
            A tuple `(clean_token_list, clean_token_indices)` of equal-length tuples: the decoded string of each
            vocab token and its token id, in vocab iteration order.
        """
        vocab = tokenizer.get_vocab()
        clean_token_list = []
        clean_token_indices = []
        # Tokenize the static prefix once; its tokens are prepended to every vocab token before decoding
        sentence_base = tokenizer(static_prefix, add_special_tokens=False)["input_ids"]
        tokens_base = [tokenizer._convert_id_to_token(tok) for tok in sentence_base]
        for token, token_idx in vocab.items():
            token_string = tokenizer.convert_tokens_to_string(tokens_base + [token])
            # Keep only what follows the prefix, discarding any spacing the tokenizer inserted before it
            token_string = token_string[token_string.index(static_prefix) + len(static_prefix) :]
            clean_token_list.append(token_string)
            clean_token_indices.append(token_idx)
        return tuple(clean_token_list), tuple(clean_token_indices)

    @staticmethod
    def _stop_string_get_matching_positions(
        token_list, token_indices, stop_strings
    ) -> tuple[dict[str, dict[int, list[int]]], dict[str, dict[int, list[int]]]]:
        """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can
        validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the
        token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters
        from the end of the stop string that overlap with the start of the token, which can have more than one value.

        The reason for computing these may seem a bit cryptic - please see the docstring for StopStringCriteria for a full
        explanation of what these values are for!

        Returns:
            Two dicts keyed by stop string; each maps a token id to its list of valid positions (first dict) or
            its list of possible end-overlap lengths (second dict). Tokens with no entries are omitted.
        """

        token_valid_positions = {}
        token_end_overlaps = {}
        for stop_string in stop_strings:
            # Work in the reversed frame, so position 0 is the end of the stop string
            reversed_stop_string = stop_string[::-1]
            token_valid_positions[stop_string] = {}
            token_end_overlaps[stop_string] = {}
            for token, tok_idx in zip(token_list, token_indices):
                reversed_token = token[::-1]
                matching_positions = []
                possible_end_lengths = []
                for i in range(1 - len(token), len(stop_string)):
                    if i < 0:
                        # The token overhangs the end of the stop string by -i characters: drop them
                        tok = reversed_token[-i:]
                        i = 0
                    else:
                        tok = reversed_token
                    # `stop` may be shorter than `tok` when the token runs off the start of the stop string
                    stop = reversed_stop_string[i : i + len(tok)]
                    if tok.startswith(stop):
                        if i == 0:
                            possible_end_lengths.append(min(len(tok), len(stop)))
                        else:
                            matching_positions.append(i)

                if matching_positions:
                    token_valid_positions[stop_string][tok_idx] = matching_positions
                if possible_end_lengths:
                    token_end_overlaps[stop_string][tok_idx] = possible_end_lengths
        return token_valid_positions, token_end_overlaps

    @staticmethod
    def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -> tuple[torch.Tensor, int, int]:
        """This function precomputes everything needed for the run-time checks in StopStringCriteria, and packs
        them into an embedding tensor that can be accessed with pure tensor operations. For the specifics of the values
        that are precomputed and what they are used for, please refer to the StopStringCriteria docstring!

        Returns:
            `(gather_vec, max_valid_positions, max_valid_end_lens)`. Each row of the int32 tensor `gather_vec` holds,
            for one token id: `max_valid_positions` slots per stop string, then `max_valid_end_lens` slots per stop
            string, then the token length in the last column. Unused slots are -1.

        Raises:
            ValueError: if no token can end any of the stop strings.
        """
        token_valid_positions, token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions(
            token_list, token_indices, stop_strings
        )
        all_valid_positions = [len(val) for positions in token_valid_positions.values() for val in positions.values()]
        # In some cases, tokens may have no valid internal positions (such as single-character stop strings), so
        # we need a fallback to handle this case
        max_valid_positions = max(all_valid_positions) if all_valid_positions else 1
        # Matching is impossible unless some token can end a stop string, so there is no fallback here: we raise
        valid_end_lens = [len(val) for positions in token_end_overlaps.values() for val in positions.values()]
        if not valid_end_lens:
            raise ValueError(
                "Stop string preprocessing was unable to identify tokens matching one or more of the "
                "supplied stop string(s). This is most often caused by the stop "
                "strings containing unusual characters that are not in the tokenizer vocabulary."
            )
        max_valid_end_lens = max(valid_end_lens)
        num_stop_strings = len(stop_strings)
        vec_size = num_stop_strings * (max_valid_positions + max_valid_end_lens) + 1
        # We use +2 instead of +1 so we can have a dummy entry at the end. We will clamp all token values
        # over the max to this, ensuring they do not contribute to stop string matching.
        gather_vec = np.full((max(token_indices) + 2, vec_size), dtype=np.int32, fill_value=-1)

        # Since this is lots of very small assignments of lists, we build it with numpy rather
        # than torch for speed + simplicity, then convert to torch at the end
        end_lens_base = max_valid_positions * num_stop_strings
        for i, stop_string in enumerate(stop_strings):
            positions = token_valid_positions[stop_string]
            end_lens = token_end_overlaps[stop_string]

            positions_start = max_valid_positions * i
            for token_idx, valid_positions in positions.items():
                gather_vec[token_idx, positions_start : positions_start + len(valid_positions)] = valid_positions

            end_lens_start = end_lens_base + max_valid_end_lens * i
            for token_idx, possible_end_lens in end_lens.items():
                gather_vec[token_idx, end_lens_start : end_lens_start + len(possible_end_lens)] = possible_end_lens

        # The last column always holds the (cleaned) token length
        for token, token_idx in zip(token_list, token_indices):
            gather_vec[token_idx, -1] = len(token)

        gather_vec = torch.tensor(gather_vec, dtype=torch.int32)

        return gather_vec, max_valid_positions, max_valid_end_lens

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        self.embedding_vec = self.embedding_vec.to(input_ids.device)
        self.target_lens = self.target_lens.to(input_ids.device)
        # The maximum length we need to consider is 1 token per character. Note that input_ids can also be
        # *shorter* than the global max, and the code below should be ready for that
        input_ids = input_ids[:, -self.maximum_token_len :]

        # Flip input_ids because we're only matching strings at the end of the generated sequence
        flipped_ids = torch.flip(input_ids, (1,))

        # Clip out-of-vocab values to the dummy value at the end of the embedding vector
        flipped_ids = torch.clamp(flipped_ids, max=self.embedding_vec.size(0) - 1)

        # Size of the vector of positions a single token can match
        max_valid_positions = self.max_valid_positions

        # The embedding vec contains the valid positions, end_lengths and total lengths for each token
        embedded = F.embedding(flipped_ids, self.embedding_vec)

        # Now we split the embedding vector. valid_positions is the positions in the stop string the token can fit
        valid_positions = embedded[:, 1:, : max_valid_positions * self.num_stop_strings].unflatten(
            -1, (self.num_stop_strings, -1)
        )
        # end_lengths is the number of characters from the string, counting from the end, that the token
        # contains. It can have multiple values if the same token can overlap different end lengths
        end_lengths = embedded[:, :1, max_valid_positions * self.num_stop_strings : -1].unflatten(
            -1, (self.num_stop_strings, -1)
        )
        # Lengths is the total length of each token. Unlike the others, it always has a single value
        lengths = embedded[:, 1:, None, -1:]  # Insert a dummy dimension for stop_strings even though lengths are const

        # Concatenate lengths onto each possible end_lengths value
        lengths = lengths.expand((-1, -1, end_lengths.shape[-2], end_lengths.shape[-1]))
        lengths_with_ends = torch.cat([end_lengths, lengths], dim=1)

        # cumsum() to get the number of matched characters in the stop string after each token
        cumsum = lengths_with_ends.cumsum(dim=1)  # B x maximum_token_len x num_stop_strings x max_valid_end_lens

        # The calculation above assumes that all tokens are in valid positions. Now we mask the ones that are not.
        # First, tokens match the start of the string if they have a positive value in the end_lengths vector
        initial_match = end_lengths > 0

        # Tokens continue the string if the cumsum() so far is one of the valid positions for that token
        # Note that we're actually tracking one cumsum() for each possible end_length
        later_match = torch.any(cumsum[:, :-1, :, None] == valid_positions[:, :, :, :, None], dim=-2)

        # The match vector is a boolean vector that indicates which positions have valid tokens
        match = torch.cat([initial_match, later_match], dim=1)

        # Once a single position does not match, all positions following that position are masked
        mask = (~match).cumsum(dim=1, dtype=torch.int32)
        mask = mask == 0

        # The string is matched if we reached a cumsum equal to or greater than the length of the string
        # before hitting the mask
        string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= self.target_lens[None, :]

        # We return a per-sample vector that is True if any stop string is matched for that sample
        return torch.any(string_matches, dim=-1)
class EosTokenCriteria(StoppingCriteria):
    """
    Stops generation for every row whose most recently generated token is an "end-of-sequence" token.
    By default, it uses the `model.generation_config.eos_token_id`.

    Args:
        eos_token_id (`Union[int, list[int], torch.Tensor]`):
            The id(s) of the *end-of-sequence* token.
    """

    def __init__(self, eos_token_id: int | list[int] | torch.Tensor):
        if isinstance(eos_token_id, torch.Tensor):
            self.eos_token_id = eos_token_id
        else:
            # A single int id is wrapped in a list before being turned into a tensor
            ids = [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
            self.eos_token_id = torch.tensor(ids)

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # Move the ids next to the inputs; the moved tensor is stored back on self for later calls
        self.eos_token_id = self.eos_token_id.to(input_ids.device)
        last_tokens = input_ids[:, -1]
        return torch.isin(last_tokens, self.eos_token_id)
class ConfidenceCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the assistant model's confidence in its prediction for the
    current token is lower than the threshold `model.generation_config.assistant_confidence_threshold`, even if the
    number of speculative tokens (defined by `num_assistant_tokens`) is not yet reached.

    Args:
        assistant_confidence_threshold (`float`):
            The value of the threshold.
    """

    def __init__(self, assistant_confidence_threshold):
        self.assistant_confidence_threshold = assistant_confidence_threshold

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return `True` when the probability of the last token of row 0 is below the confidence threshold.

        Unlike the other criteria, this returns a plain Python `bool` rather than a per-row tensor, and only the
        first row of the batch is inspected.
        """
        # NOTE(review): `scores[-1]` is indexed as 2-D below, so `scores` is presumably a sequence of per-step
        # (batch, vocab) score tensors here rather than a single (batch, vocab) tensor - confirm against the caller.
        probs = scores[-1].softmax(-1)
        p = probs[0, input_ids[0, -1]].item()
        return bool(p < self.assistant_confidence_threshold)
class StoppingCriteriaList(list):
    """A list of stopping criteria that is itself callable: calling it ORs together the verdicts of its members."""

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        done = torch.zeros(input_ids.shape[0], device=input_ids.device, dtype=torch.bool)
        for criterion in self:
            # A row stops as soon as any criterion asks it to
            done = done | criterion(input_ids, scores, **kwargs)
        return done

    @property
    def max_length(self) -> int | None:
        """`max_length` of the first `MaxLengthCriteria` in the list, or `None` if the list has none."""
        return next(
            (criterion.max_length for criterion in self if isinstance(criterion, MaxLengthCriteria)),
            None,
        )
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
    """Return a deep copy of `stopping_criteria` that is guaranteed to contain a max-length limit.

    If the list has no `MaxLengthCriteria`, one with `max_length` is appended to the copy. If it already has one
    with a different limit, the copy is returned unchanged and a `UserWarning` is emitted.
    """
    configured_max_length = stopping_criteria.max_length
    validated = deepcopy(stopping_criteria)
    if configured_max_length is None:
        validated.append(MaxLengthCriteria(max_length=max_length))
    elif configured_max_length != max_length:
        warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
    return validated
|