| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- """Utility functions for batch processing."""
- import logging
- from typing import TYPE_CHECKING, Any, Union
- if TYPE_CHECKING:
- from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
# Type alias for anything accepted as a tokenizer. The transformers classes are
# given as strings (forward references), so they only need to be importable
# under TYPE_CHECKING; ``Any`` keeps other tokenizer-like objects (e.g.
# processors wrapping a tokenizer) acceptable.
AnyTokenizer = Union["PreTrainedTokenizer", "PreTrainedTokenizerFast", Any]
logger = logging.getLogger(__name__)


def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """Get tokenizer with cached properties.

    This will patch the tokenizer object in place by swapping its class for a
    generated subclass. By default, transformers will recompute multiple
    tokenizer properties each time they are called, leading to a significant
    slowdown. This function caches these properties for faster access.

    Notes:
        * The cached values are snapshots taken when this function runs;
          later changes to the tokenizer (e.g. adding special tokens) are not
          reflected by the cached properties.
        * ``all_special_ids`` and ``all_special_tokens`` are returned as sets,
          and ``all_special_tokens_extended`` is ``None`` when the tokenizer
          does not provide it.
        * ``chat_template`` becomes a read-only property; assigning to it on
          the patched object raises ``AttributeError``.
        * The patched object stays picklable (and copyable): it is reduced to
          an un-patched instance of its original class and re-patched with
          this function when reconstructed.

    Args:
        tokenizer: The tokenizer object, or a processor that wraps the text
            tokenizer in a ``tokenizer`` attribute.

    Returns:
        The patched tokenizer object. When a processor is passed, this is the
        processor's inner tokenizer, not the processor itself.
    """
    chat_template = getattr(tokenizer, "chat_template", None)
    # For VLM, the text tokenizer is wrapped by a processor.
    if hasattr(tokenizer, "tokenizer"):
        tokenizer = tokenizer.tokenizer
    # Some VLM's tokenizer has chat_template attribute (e.g. Qwen/Qwen2-VL-7B-Instruct),
    # however some other VLM's tokenizer does not have chat_template attribute (e.g.
    # mistral-community/pixtral-12b). Therefore, we cache the processor's chat_template
    # and only fall back to the inner tokenizer's own template when it is missing.
    if chat_template is None:
        chat_template = getattr(tokenizer, "chat_template", None)

    # Sets give O(1) membership checks for the special ids/tokens.
    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
    # all_special_tokens_extended is removed in transformers v5, used in latest
    # SGLang version. We require this SGLang version bc it's ABI compatible with
    # PyTorch 2.9, which is installed by vLLM.
    # TODO(seiji) remove the attribute completely once vLLM moves to transformers v5.
    tokenizer_all_special_tokens_extended = getattr(
        tokenizer, "all_special_tokens_extended", None
    )
    tokenizer_len = len(tokenizer)
    original_cls = tokenizer.__class__

    class CachedTokenizer(original_cls):  # type: ignore
        @property
        def all_special_ids(self):
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self):
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self):
            return tokenizer_all_special_tokens_extended

        @property
        def chat_template(self):
            return chat_template

        def __len__(self):
            return tokenizer_len

        def __reduce__(self):
            # CachedTokenizer is a function-local class, which pickle cannot
            # look up by name. Serialize a plain instance of the original
            # class instead (so its own pickling logic is used for the state)
            # and re-apply the cache when the object is reconstructed.
            plain = original_cls.__new__(original_cls)
            plain.__dict__.update(self.__dict__)
            if chat_template is not None:
                # Preserve a processor-level template that the inner
                # tokenizer may not carry itself.
                plain.__dict__["chat_template"] = chat_template
            return get_cached_tokenizer, (plain,)

    CachedTokenizer.__name__ = f"Cached{original_cls.__name__}"
    tokenizer.__class__ = CachedTokenizer
    return tokenizer
|