streamers.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. # Copyright 2023 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import annotations
  15. import asyncio
  16. import sys
  17. from queue import Queue
  18. from typing import TYPE_CHECKING, Any, cast
  19. if TYPE_CHECKING:
  20. from ..tokenization_utils_base import PreTrainedTokenizerBase
  21. class BaseStreamer:
  22. """
  23. Base class from which `.generate()` streamers should inherit.
  24. """
  25. def put(self, value):
  26. """Function that is called by `.generate()` to push new tokens"""
  27. raise NotImplementedError()
  28. def end(self):
  29. """Function that is called by `.generate()` to signal the end of generation"""
  30. raise NotImplementedError()
  31. class TextStreamer(BaseStreamer):
  32. """
  33. Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
  34. <Tip warning={true}>
  35. The API for the streamer classes is still under development and may change in the future.
  36. </Tip>
  37. Parameters:
  38. tokenizer (`AutoTokenizer`):
  39. The tokenized used to decode the tokens.
  40. skip_prompt (`bool`, *optional*, defaults to `False`):
  41. Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
  42. decode_kwargs (`dict`, *optional*):
  43. Additional keyword arguments to pass to the tokenizer's `decode` method.
  44. Examples:
  45. ```python
  46. >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
  47. >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
  48. >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
  49. >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
  50. >>> streamer = TextStreamer(tok)
  51. >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
  52. >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
  53. An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
  54. ```
  55. """
  56. def __init__(self, tokenizer: PreTrainedTokenizerBase, skip_prompt: bool = False, **decode_kwargs: Any):
  57. self.tokenizer = tokenizer
  58. self.skip_prompt = skip_prompt
  59. self.decode_kwargs = decode_kwargs
  60. # variables used in the streaming process
  61. self.token_cache: list[int] = []
  62. self.print_len = 0
  63. self.next_tokens_are_prompt = True
  64. def put(self, value):
  65. """
  66. Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
  67. """
  68. if len(value.shape) > 1 and value.shape[0] > 1:
  69. raise ValueError("TextStreamer only supports batch size 1")
  70. elif len(value.shape) > 1:
  71. value = value[0]
  72. if self.skip_prompt and self.next_tokens_are_prompt:
  73. self.next_tokens_are_prompt = False
  74. return
  75. # Add the new token to the cache and decodes the entire thing.
  76. self.token_cache.extend(value.tolist())
  77. text = cast(str, self.tokenizer.decode(self.token_cache, **self.decode_kwargs))
  78. # After the symbol for a new line, we flush the cache.
  79. if text.endswith("\n"):
  80. printable_text = text[self.print_len :]
  81. self.token_cache = []
  82. self.print_len = 0
  83. # If the last token is a CJK character, we print the characters.
  84. elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
  85. printable_text = text[self.print_len :]
  86. self.print_len += len(printable_text)
  87. # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
  88. # which may change with the subsequent token -- there are probably smarter ways to do this!)
  89. else:
  90. printable_text = text[self.print_len : text.rfind(" ") + 1]
  91. self.print_len += len(printable_text)
  92. self.on_finalized_text(printable_text)
  93. def end(self):
  94. """Flushes any remaining cache and prints a newline to stdout."""
  95. # Flush the cache, if it exists
  96. if len(self.token_cache) > 0:
  97. text = cast(str, self.tokenizer.decode(self.token_cache, **self.decode_kwargs))
  98. printable_text = text[self.print_len :]
  99. self.token_cache = []
  100. self.print_len = 0
  101. else:
  102. printable_text = ""
  103. self.next_tokens_are_prompt = True
  104. self.on_finalized_text(printable_text, stream_end=True)
  105. def on_finalized_text(self, text: str, stream_end: bool = False):
  106. """Prints the new text to stdout. If the stream is ending, also prints a newline."""
  107. print(text, flush=True, end="" if not stream_end else None)
  108. def _is_chinese_char(self, cp):
  109. """Checks whether CP is the codepoint of a CJK character."""
  110. # This defines a "chinese character" as anything in the CJK Unicode block:
  111. # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
  112. #
  113. # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
  114. # despite its name. The modern Korean Hangul alphabet is a different block,
  115. # as is Japanese Hiragana and Katakana. Those alphabets are used to write
  116. # space-separated words, so they are not treated specially and handled
  117. # like the all of the other languages.
  118. if (
  119. (cp >= 0x4E00 and cp <= 0x9FFF)
  120. or (cp >= 0x3400 and cp <= 0x4DBF)
  121. or (cp >= 0x20000 and cp <= 0x2A6DF)
  122. or (cp >= 0x2A700 and cp <= 0x2B73F)
  123. or (cp >= 0x2B740 and cp <= 0x2B81F)
  124. or (cp >= 0x2B820 and cp <= 0x2CEAF)
  125. or (cp >= 0xF900 and cp <= 0xFAFF)
  126. or (cp >= 0x2F800 and cp <= 0x2FA1F)
  127. ):
  128. return True
  129. return False
  130. class TextIteratorStreamer(TextStreamer):
  131. """
  132. Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
  133. useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an interactive
  134. Gradio demo).
  135. <Tip warning={true}>
  136. The API for the streamer classes is still under development and may change in the future.
  137. </Tip>
  138. Parameters:
  139. tokenizer (`AutoTokenizer`):
  140. The tokenized used to decode the tokens.
  141. skip_prompt (`bool`, *optional*, defaults to `False`):
  142. Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
  143. timeout (`float`, *optional*):
  144. The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
  145. in `.generate()`, when it is called in a separate thread.
  146. decode_kwargs (`dict`, *optional*):
  147. Additional keyword arguments to pass to the tokenizer's `decode` method.
  148. Examples:
  149. ```python
  150. >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  151. >>> from threading import Thread
  152. >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
  153. >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
  154. >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
  155. >>> streamer = TextIteratorStreamer(tok)
  156. >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
  157. >>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
  158. >>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
  159. >>> thread.start()
  160. >>> generated_text = ""
  161. >>> for new_text in streamer:
  162. ... generated_text += new_text
  163. >>> generated_text
  164. 'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
  165. ```
  166. """
  167. def __init__(
  168. self,
  169. tokenizer: PreTrainedTokenizerBase,
  170. skip_prompt: bool = False,
  171. timeout: float | None = None,
  172. **decode_kwargs: Any,
  173. ):
  174. super().__init__(tokenizer, skip_prompt, **decode_kwargs)
  175. self.text_queue = Queue()
  176. self.stop_signal = None
  177. self.timeout = timeout
  178. def on_finalized_text(self, text: str, stream_end: bool = False):
  179. """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
  180. self.text_queue.put(text, timeout=self.timeout)
  181. if stream_end:
  182. self.text_queue.put(self.stop_signal, timeout=self.timeout)
  183. def __iter__(self):
  184. return self
  185. def __next__(self):
  186. value = self.text_queue.get(timeout=self.timeout)
  187. if value == self.stop_signal:
  188. raise StopIteration()
  189. else:
  190. return value
  191. class AsyncTextIteratorStreamer(TextStreamer):
  192. """
  193. Streamer that stores print-ready text in a queue, to be used by a downstream application as an async iterator.
  194. This is useful for applications that benefit from accessing the generated text asynchronously (e.g. in an
  195. interactive Gradio demo).
  196. <Tip warning={true}>
  197. The API for the streamer classes is still under development and may change in the future.
  198. </Tip>
  199. Parameters:
  200. tokenizer (`AutoTokenizer`):
  201. The tokenized used to decode the tokens.
  202. skip_prompt (`bool`, *optional*, defaults to `False`):
  203. Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
  204. timeout (`float`, *optional*):
  205. The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
  206. in `.generate()`, when it is called in a separate thread.
  207. decode_kwargs (`dict`, *optional*):
  208. Additional keyword arguments to pass to the tokenizer's `decode` method.
  209. Raises:
  210. TimeoutError: If token generation time exceeds timeout value.
  211. Examples:
  212. ```python
  213. >>> from transformers import AutoModelForCausalLM, AutoTokenizer, AsyncTextIteratorStreamer
  214. >>> from threading import Thread
  215. >>> import asyncio
  216. >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
  217. >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
  218. >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
  219. >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
  220. >>> async def main():
  221. ... # Important: AsyncTextIteratorStreamer must be initialized inside a coroutine!
  222. ... streamer = AsyncTextIteratorStreamer(tok)
  223. ... generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
  224. ... thread = Thread(target=model.generate, kwargs=generation_kwargs)
  225. ... thread.start()
  226. ... generated_text = ""
  227. ... async for new_text in streamer:
  228. ... generated_text += new_text
  229. >>> print(generated_text)
  230. >>> asyncio.run(main())
  231. An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
  232. ```
  233. """
  234. def __init__(
  235. self,
  236. tokenizer: PreTrainedTokenizerBase,
  237. skip_prompt: bool = False,
  238. timeout: float | None = None,
  239. **decode_kwargs: Any,
  240. ):
  241. super().__init__(tokenizer, skip_prompt, **decode_kwargs)
  242. self.text_queue = asyncio.Queue()
  243. self.stop_signal = None
  244. self.timeout = timeout
  245. self.loop = asyncio.get_running_loop()
  246. timeout_context = getattr(asyncio, "timeout", None)
  247. self.has_asyncio_timeout = sys.version_info >= (3, 11) and callable(timeout_context)
  248. self.asyncio_timeout = timeout_context if self.has_asyncio_timeout else None
  249. def on_finalized_text(self, text: str, stream_end: bool = False):
  250. """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
  251. self.loop.call_soon_threadsafe(self.text_queue.put_nowait, text)
  252. if stream_end:
  253. self.loop.call_soon_threadsafe(self.text_queue.put_nowait, self.stop_signal)
  254. def __aiter__(self):
  255. return self
  256. async def __anext__(self):
  257. try:
  258. if self.has_asyncio_timeout and self.asyncio_timeout is not None:
  259. async with self.asyncio_timeout(self.timeout):
  260. value = await self.text_queue.get()
  261. else:
  262. value = await asyncio.wait_for(self.text_queue.get(), timeout=self.timeout)
  263. except asyncio.TimeoutError:
  264. raise TimeoutError()
  265. else:
  266. if value == self.stop_signal:
  267. raise StopAsyncIteration()
  268. else:
  269. return value