utils.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
"""Utility functions for batch processing."""
import logging
from typing import TYPE_CHECKING, Any, Union

if TYPE_CHECKING:
    # Only evaluated by static type checkers; this module does not import
    # transformers at runtime.
    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

# Alias used to annotate tokenizer-like arguments. The transformers classes
# are written as string forward references because they are only imported
# under TYPE_CHECKING; the trailing `Any` means any object satisfies the alias.
AnyTokenizer = Union["PreTrainedTokenizer", "PreTrainedTokenizerFast", Any]

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
  8. def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
  9. """Get tokenizer with cached properties.
  10. This will patch the tokenizer object in place.
  11. By default, transformers will recompute multiple tokenizer properties
  12. each time they are called, leading to a significant slowdown. This
  13. function caches these properties for faster access.
  14. Args:
  15. tokenizer: The tokenizer object.
  16. Returns:
  17. The patched tokenizer object.
  18. """
  19. chat_template = getattr(tokenizer, "chat_template", None)
  20. # For VLM, the text tokenizer is wrapped by a processor.
  21. if hasattr(tokenizer, "tokenizer"):
  22. tokenizer = tokenizer.tokenizer
  23. # Some VLM's tokenizer has chat_template attribute (e.g. Qwen/Qwen2-VL-7B-Instruct),
  24. # however some other VLM's tokenizer does not have chat_template attribute (e.g.
  25. # mistral-community/pixtral-12b). Therefore, we cache the processor's chat_template.
  26. if chat_template is None:
  27. chat_template = getattr(tokenizer, "chat_template", None)
  28. tokenizer_all_special_ids = set(tokenizer.all_special_ids)
  29. tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
  30. # all_special_tokens_extended is removed in transformers v5, used in latest
  31. # SGLang version. We require this SGLang version bc it's ABI compatible with
  32. # PyTorch 2.9, which is installed by vLLM.
  33. # TODO(seiji) remove the attribute completely once vLLM moves to transformers v5.
  34. tokenizer_all_special_tokens_extended = getattr(
  35. tokenizer, "all_special_tokens_extended", None
  36. )
  37. tokenizer_len = len(tokenizer)
  38. class CachedTokenizer(tokenizer.__class__): # type: ignore
  39. @property
  40. def all_special_ids(self):
  41. return tokenizer_all_special_ids
  42. @property
  43. def all_special_tokens(self):
  44. return tokenizer_all_special_tokens
  45. @property
  46. def all_special_tokens_extended(self):
  47. return tokenizer_all_special_tokens_extended
  48. @property
  49. def chat_template(self):
  50. return chat_template
  51. def __len__(self):
  52. return tokenizer_len
  53. CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
  54. tokenizer.__class__ = CachedTokenizer
  55. return tokenizer