| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671 |
- # Copyright 2025 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import asyncio
- import json
- import os
- import platform
- import re
- import string
- import time
- from collections.abc import AsyncIterator
- from typing import Annotated, Any
- from urllib.parse import urljoin, urlparse
- import httpx
- import requests
- import typer
- import yaml
- from huggingface_hub import AsyncInferenceClient, ChatCompletionStreamOutput
- from transformers import GenerationConfig
- from transformers.utils import is_rich_available
- try:
- import readline # noqa importing this enables GNU readline capabilities
- except ImportError:
- # some platforms may not support readline: https://docs.python.org/3/library/readline.html
- pass
- if platform.system() != "Windows":
- import pwd
- if is_rich_available():
- from rich import filesize
- from rich.console import Console
- from rich.live import Live
- from rich.markdown import Markdown
- from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn, TimeElapsedColumn
- from rich.text import Text
# Host and port used to build the default `base_url` of the chat command.
DEFAULT_HTTP_ENDPOINT = dict(hostname="localhost", port=8000)

# Character whitelists. NOTE(review): neither set is referenced in the visible part of this file; presumably
# they validate flag names/values elsewhere -- confirm before relying on them.
ALLOWED_KEY_CHARS = set(string.ascii_letters) | set(string.whitespace)
ALLOWED_VALUE_CHARS = set(
    "".join(
        (
            string.ascii_letters,
            string.digits,
            string.whitespace,
            r".!\"#$%&'()*+,\-/:<=>?@[]^_`{|}~",
        )
    )
)
# Built-in prompts selectable with `!example NAME`; each entry maps the example name to {"text": prompt}.
DEFAULT_EXAMPLES = {
    example_name: {"text": example_prompt}
    for example_name, example_prompt in (
        ("llama", "There is a Llama in my lawn, how can I get rid of it?"),
        (
            "code",
            "Write a Python function that integrates any Python function f(x) numerically over an arbitrary "
            "interval [x_start, x_end].",
        ),
        ("helicopter", "How many helicopters can a human eat in one sitting?"),
        ("numbers", "Count to 10 but skip every number ending with an 'e'"),
        ("birds", "Why aren't birds real?"),
        ("socks", "Why is it important to eat socks after meditating?"),
        ("numbers2", "Which number is larger, 9.9 or 9.11?"),
    )
}
# Printed at the start of a chat session.
# The text is rendered as Markdown, so the title, the intro sentence and the command list are separated by blank
# lines; without them Markdown merges the title and the sentence into a single paragraph.
HELP_STRING_MINIMAL = """

**TRANSFORMERS CHAT INTERFACE**

Chat interface to try out a model. Besides chatting with the model, here are some basic commands:

- **!help**: shows all available commands (set generation settings, save chat, etc.)
- **!status**: shows the current status of the model and generation settings
- **!clear**: clears the current conversation and starts a new one
- **!exit**: closes the interface
"""
# Printed when the user types `!help` in the chat session.
# The text is rendered as Markdown: blank lines keep the title, the intro line and the command list as separate
# blocks. Doubled braces (`{{NAME}}`) are literal braces in the f-string; the only interpolated value is the list
# of example names.
HELP_STRING = f"""

**TRANSFORMERS CHAT INTERFACE HELP**

Full command list:

- **!help**: shows this help message
- **!clear**: clears the current conversation and starts a new one
- **!status**: shows the current status of the model and generation settings
- **!example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input.
Available example names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
- **!set {{ARG_1}}={{VALUE_1}} {{ARG_2}}={{VALUE_2}}** ...: changes the system prompt or generation settings (multiple
settings are separated by a space). Accepts the same flags and format as the `generate_flags` CLI argument.
If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options
- **!save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
`./chat_history/{{MODEL_ID}}/chat_{{DATETIME}}.json` or `{{SAVE_NAME}}` if provided
- **!exit**: closes the interface
"""
class RichInterface:
    """Rich-based console front-end used by the chat loop.

    Wraps a `rich` `Console` and exposes the operations the chat session needs: reading user input, streaming
    the model answer as Markdown, showing model-loading progress, and printing help/status/colored messages.
    """

    def __init__(self, model_id: str, user_id: str, base_url: str):
        """
        Args:
            model_id: Model name, shown as the label of model messages and in the status output.
            user_id: Name shown as the label of user messages.
            base_url: Server base URL; `print_model_load` posts to `<base_url>/load_model`.
        """
        self._console = Console()
        self.model_id = model_id
        self.user_id = user_id
        self.base_url = base_url

    async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) -> tuple[str, str | Any | None]:
        """Renders a streamed chat completion live as Markdown.

        Args:
            stream: Awaitable that resolves to an async iterator of completion chunks.

        Returns:
            A `(text, finish_reason)` tuple: the accumulated (Markdown-escaped) text, and the last non-null
            finish reason reported by the stream (`None` if none was reported).
        """
        self._console.print(f"[bold blue]<{self.model_id}>:")
        with Live(console=self._console, refresh_per_second=4) as live:
            text = ""
            completion_tokens = 0
            start_time = time.time()
            finish_reason: str | None = None
            async for token in await stream:
                usage = getattr(token, "usage", None)
                if usage is not None:
                    # Ignore a missing/null count so the tokens-per-second math below never sees `None`.
                    reported_tokens = getattr(usage, "completion_tokens", None)
                    if reported_tokens:
                        completion_tokens = reported_tokens

                # A chunk may carry no choices at all (e.g. an OpenAI-style usage-only chunk); indexing
                # `choices[0]` unconditionally would raise IndexError and abort the stream.
                if not token.choices:
                    continue
                choice = token.choices[0]

                # Keep the last non-null finish reason: a later chunk reporting `None` must not erase it.
                finish_reason = getattr(choice, "finish_reason", None) or finish_reason

                outputs = choice.delta.content
                if not outputs:
                    continue

                # Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown.
                # It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>)
                outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs)

                text += outputs

                # Render the accumulated text as Markdown.
                # NOTE: this is a workaround for rendering "non-standard markdown" in rich. Chatbot output treats
                # "\n" as a new line, whereas standard Markdown treats a single "\n" in normal text as a space.
                # The workaround is to end every line with two spaces, which Markdown reads as a hard line
                # break. It is not perfect (it also adds trailing spaces inside code blocks), but the console
                # does not care about trailing spaces.
                lines = []
                for line in text.splitlines():
                    lines.append(line)
                    if line.startswith("```"):
                        # Code block marker - do not add trailing spaces, as it would
                        # break the syntax highlighting
                        lines.append("\n")
                    else:
                        # Exactly two trailing spaces: a single space is only a soft break in Markdown.
                        lines.append("  \n")
                markdown = Markdown("".join(lines).strip(), code_theme="github-dark")

                # Update the Live console output
                live.update(markdown, refresh=True)

        elapsed = time.time() - start_time
        if elapsed > 0 and completion_tokens > 0:
            tok_per_sec = completion_tokens / elapsed
            self._console.print()
            self._console.print(f"[dim]{completion_tokens} tokens in {elapsed:.1f}s ({tok_per_sec:.1f} tok/s)[/dim]")
        self._console.print()
        return text, finish_reason

    def input(self) -> str:
        """Gets user input from the console."""
        # Local name deliberately differs from the builtin `input`.
        user_text = self._console.input(f"[bold red]<{self.user_id}>:\n")
        self._console.print()
        return user_text

    def clear(self):
        """Clears the console."""
        self._console.clear()

    def print_user_message(self, text: str):
        """Prints a user message to the console."""
        self._console.print(f"[bold red]<{self.user_id}>:[/ bold red]\n{text}")
        self._console.print()

    def print_color(self, text: str, color: str):
        """Prints text in a given color to the console."""
        self._console.print(f"[bold {color}]{text}")
        self._console.print()

    def confirm(self, message: str, default: bool = False) -> bool:
        """Displays a yes/no prompt to the user, returning True for confirmation.

        An empty answer returns `default`; otherwise only `y`/`yes` (case-insensitive) count as confirmation.
        """
        default_hint = "Y/n" if default else "y/N"
        response = self._console.input(f"[bold yellow]{message} ({default_hint}): ")
        self._console.print()

        response = response.strip().lower()
        if not response:
            return default
        return response in {"y", "yes"}

    def print_help(self, minimal: bool = False):
        """Prints the help message to the console."""
        self._console.print(Markdown(HELP_STRING_MINIMAL if minimal else HELP_STRING))
        self._console.print()

    def print_model_load(self, model: str):
        """Asks the server to load `model` and shows a live progress bar until it is ready.

        Posts `{"model": model}` to `<base_url>/load_model` and reads the streamed response line by line. Lines
        starting with `data: ` are parsed as JSON events whose `status` is one of `loading` (updates the bar),
        `ready` (stops reading) or `error`.

        Raises:
            requests.HTTPError: if the server answers with an error status code.
            RuntimeError: if the server reports an `error` event.
        """

        class StatsColumn(ProgressColumn):
            """Progress column showing `done/total`, plus speed and ETA for byte-sized tasks."""

            def render(self, task):
                if not task.total:
                    return Text("")
                if task.fields.get("unit") == "bytes":
                    done = filesize.decimal(int(task.completed))
                    tot = filesize.decimal(int(task.total))
                    speed = f" {filesize.decimal(int(task.speed))}/s" if task.speed else ""
                    if task.time_remaining is not None:
                        eta = f" {int(task.time_remaining // 60)}:{int(task.time_remaining % 60):02d}"
                    else:
                        eta = ""
                    return Text(f"{done}/{tot}{speed}{eta}", style="progress.download")
                return Text(f"{int(task.completed)}/{int(task.total)}")

        stage_labels = {
            "processor": "Loading processor",
            "config": "Loading config",
            "download": "Downloading files",
            "weights": "Loading into memory",
        }

        # Include the model name prefix in descriptions only when the terminal is wide enough.
        # The bar, stats, and elapsed columns need ~70 chars; the model prefix needs len(model)+5.
        show_model_prefix = self._console.width >= len(model) + 5 + 70

        def _label(stage_key):
            # Unknown stage keys are displayed as-is.
            stage_text = stage_labels.get(stage_key, stage_key)
            if show_model_prefix:
                return f"{model} → {stage_text}"
            return stage_text

        progress = Progress(
            TextColumn("[bold]{task.description}"),
            BarColumn(bar_width=40),
            StatsColumn(),
            TimeElapsedColumn(),
            console=self._console,
        )
        task_id = progress.add_task(_label("processor"), total=None)

        cached = False
        load_url = f"{self.base_url.rstrip('/')}/load_model"
        # The context manager releases the streamed connection on every exit path (ready, error, exception).
        with requests.post(load_url, json={"model": model}, stream=True) as response:
            response.raise_for_status()
            with Live(progress, console=self._console, transient=True):
                for line in response.iter_lines():
                    if not line or not line.startswith(b"data: "):
                        continue
                    # Strip the 6-byte `data: ` prefix before decoding the JSON payload.
                    event = json.loads(line[6:])
                    status = event.get("status")

                    if status == "ready":
                        cached = event.get("cached", False)
                        break
                    if status == "error":
                        raise RuntimeError(event.get("message", "Unknown error"))
                    if status == "loading":
                        stage = event.get("stage")
                        prog = event.get("progress")
                        label = _label(stage)
                        if prog:
                            unit = "bytes" if stage == "download" else "items"
                            progress.update(
                                task_id, description=label, completed=prog["current"], total=prog.get("total"), unit=unit
                            )
                        else:
                            # No progress payload: show an indeterminate bar for this stage.
                            progress.update(task_id, description=label, completed=0, total=None)

        if cached:
            self._console.print(Markdown(f"_*{model} was already loaded.*_"))
        else:
            self._console.print(Markdown(f"_*{model} is warm.*_"))
        self._console.print()

    def print_status(self, config: GenerationConfig):
        """Prints the status of the model and generation settings to the console."""
        self._console.print(f"[bold blue]Model: {self.model_id}\n")
        self._console.print(f"[bold blue]{config}")
        self._console.print()
class Chat:
    """Chat with a model from the command line."""

    # NOTE: the class and `__init__` docstrings are kept to a single line on purpose: the constructor signature is
    # annotated for typer, and CLI frameworks may surface these docstrings as the command help text.
    #
    # Defining a class to help with internal state but in practice it's just a method to call
    # TODO: refactor into a proper module with helpers + 1 main method

    def __init__(
        self,
        model_id: Annotated[str, typer.Argument(help="ID of the model to use (e.g. 'HuggingFaceTB/SmolLM3-3B').")],
        base_url: Annotated[
            str | None, typer.Argument(help="Base url to connect to (e.g. http://localhost:8000/v1).")
        ] = f"http://{DEFAULT_HTTP_ENDPOINT['hostname']}:{DEFAULT_HTTP_ENDPOINT['port']}",
        generate_flags: Annotated[
            list[str] | None,
            typer.Argument(
                help=(
                    "Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, "
                    "and lists of integers, more advanced parameterization should be set through --generation-config. "
                    "Example: `transformers chat <model_id> <base_url> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. "
                    "If you're a new user, check this basic flag guide: "
                    "https://huggingface.co/docs/transformers/llm_tutorial#common-options"
                )
            ),
        ] = None,
        # General settings
        user: Annotated[
            str | None,
            typer.Option(help="Username to display in chat interface. Defaults to the current user's name."),
        ] = None,
        system_prompt: Annotated[str | None, typer.Option(help="System prompt.")] = None,
        save_folder: Annotated[str, typer.Option(help="Folder to save chat history.")] = "./chat_history/",
        examples_path: Annotated[str | None, typer.Option(help="Path to a yaml file with examples.")] = None,
        # Generation settings
        generation_config: Annotated[
            str | None,
            typer.Option(
                help="Path to a local generation config file or to a HuggingFace repo containing a `generation_config.json` file. Other generation settings passed as CLI arguments will be applied on top of this generation config."
            ),
        ] = None,
    ) -> None:
        """Chat with a model from the command line."""
        self.base_url = base_url

        # Only the default local endpoint is health-checked up front; other URLs are used as given.
        parsed = urlparse(self.base_url)
        if parsed.hostname == DEFAULT_HTTP_ENDPOINT["hostname"] and parsed.port == DEFAULT_HTTP_ENDPOINT["port"]:
            self.check_health(self.base_url)

        self.model_id = model_id
        self.system_prompt = system_prompt
        self.save_folder = save_folder

        # Generation settings: loaded config, then the chat defaults, then the user-provided flags.
        config = load_generation_config(generation_config)
        config.update(do_sample=True, max_new_tokens=256)  # some default values
        config.update(**parse_generate_flags(generate_flags))
        self.config = config

        # Snapshot of the session settings; the `config` entry is refreshed whenever the chat is saved.
        self.settings = {"base_url": base_url, "model_id": model_id, "config": self.config.to_dict()}

        # User settings
        self.user = user if user is not None else get_username()

        # Load examples
        if examples_path:
            with open(examples_path, encoding="utf-8") as f:
                self.examples = yaml.safe_load(f)
        else:
            self.examples = DEFAULT_EXAMPLES

        # Check requirements
        if not is_rich_available():
            raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")

        # Run chat session
        asyncio.run(self._inner_run())

    @staticmethod
    def check_health(url):
        """Checks that a server answers `GET <url>/health` with status 200.

        Returns:
            `True` when the server is healthy.

        Raises:
            ValueError: if the connection fails or the server answers with a non-200 status code.
        """
        health_url = urljoin(url + "/", "health")
        try:
            output = httpx.get(health_url)
        except httpx.ConnectError as err:
            raise ValueError(
                f"No server currently running on {url}. To run a local server, please run `transformers serve` in a "
                f"separate shell. Find more information here: https://huggingface.co/docs/transformers/serving"
            ) from err

        if output.status_code != 200:
            raise ValueError(
                f"The server running on {url} returned status code {output.status_code} on health check (/health)."
            )
        return True

    def _save_command(
        self, user_input: str, interface: RichInterface, chat: list[dict], config: GenerationConfig
    ) -> None:
        """Handles `!save [SAVE_NAME]`: writes the chat and the current settings to disk and reports the outcome."""
        split_input = user_input.split()
        if len(split_input) == 2:
            # An absolute path always has a directory component, so creating the parent directory cannot be
            # handed an empty path for a bare file name such as `chat.json`.
            filename = os.path.abspath(split_input[1])
        else:
            filename = os.path.join(self.save_folder, self.model_id, f"chat_{time.strftime('%Y-%m-%d_%H-%M-%S')}.json")

        # Save the generation config as it is now, including changes made with `!set` during the session.
        settings = {**self.settings, "config": config.to_dict()}
        try:
            save_chat(filename=filename, chat=chat, settings=settings)
        except OSError as err:
            # A failed save (permissions, invalid path, ...) should not end the chat session.
            interface.print_color(text=f"Could not save chat to {filename}: {err.strerror or err}", color="red")
            return
        interface.print_color(text=f"Chat saved to {filename}!", color="green")

    @staticmethod
    def _set_command(user_input: str, interface: RichInterface, config: GenerationConfig) -> None:
        """Handles `!set flag=value ...`: validates the flags and applies them to `config` in place."""
        # splits the new args into a list of strings, each string being a `flag=value` pair (same format as
        # `generate_flags`)
        new_generate_flags = user_input[4:].strip().split()

        # sanity check: each member in the list must have an =
        for flag in new_generate_flags:
            if "=" not in flag:
                interface.print_color(
                    text=(
                        f"Invalid flag format, missing `=` after `{flag}`. Please use the format "
                        "`arg_1=value_1 arg_2=value_2 ...`."
                    ),
                    color="red",
                )
                return

        try:
            # Update config from user flags
            config.update(**parse_generate_flags(new_generate_flags))
        except ValueError as err:
            # A malformed value should be reported, not crash the session. The message may contain square
            # brackets, so escape it to keep rich from reading them as markup tags.
            from rich.markup import escape

            interface.print_color(text=escape(str(err)), color="red")

    def handle_non_exit_user_commands(
        self,
        user_input: str,
        interface: RichInterface,
        examples: dict[str, dict[str, str]],
        config: GenerationConfig,
        chat: list[dict],
    ) -> tuple[list[dict], bool, GenerationConfig]:
        """
        Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the
        generation config (e.g. set a new flag).

        Returns:
            A `(chat, valid_command, config)` tuple. `valid_command` is `False` only when `user_input` matched no
            known command (the help message is printed in that case).
        """
        valid_command = True

        if user_input == "!clear":
            chat = new_chat_history(self.system_prompt)
            interface.clear()

        elif user_input == "!help":
            interface.print_help()

        # `!save` alone or `!save SAVE_NAME`
        elif user_input.startswith("!save") and len(user_input.split()) <= 2:
            self._save_command(user_input, interface, chat, config)

        elif user_input.startswith("!set"):
            self._set_command(user_input, interface, config)

        elif user_input.startswith("!example") and len(user_input.split()) == 2:
            example_name = user_input.split()[1]
            if example_name in examples:
                interface.clear()
                # Start over like `!clear` does, i.e. keeping the system prompt (if any).
                chat = new_chat_history(self.system_prompt)
                interface.print_user_message(examples[example_name]["text"])
                chat.append({"role": "user", "content": examples[example_name]["text"]})
            else:
                example_error = (
                    f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
                )
                interface.print_color(text=example_error, color="red")

        elif user_input == "!status":
            interface.print_status(config=config)

        else:
            valid_command = False
            interface.print_color(text=f"'{user_input}' is not a valid command. Showing help message.", color="red")
            interface.print_help()

        return chat, valid_command, config

    async def _inner_run(self):
        """Runs the interactive loop: reads input, executes `!` commands, and streams model answers."""
        interface = RichInterface(model_id=self.model_id, user_id=self.user, base_url=self.base_url)
        interface.clear()
        chat = new_chat_history(self.system_prompt)

        # Starts the session with a minimal help message at the top, so that a user doesn't get stuck
        interface.print_help(minimal=True)
        interface.print_model_load(self.model_id)

        config = self.config

        async with AsyncInferenceClient(base_url=self.base_url) as client:
            # Message queued by the "Continue generating?" prompt; consumed instead of asking the user.
            pending_user_input: str | None = None
            while True:
                try:
                    if pending_user_input is not None:
                        user_input = pending_user_input
                        pending_user_input = None
                        interface.print_user_message(user_input)
                    else:
                        user_input = interface.input()

                    # User commands
                    if user_input == "!exit":
                        break

                    elif user_input == "!clear":
                        chat = new_chat_history(self.system_prompt)
                        interface.clear()
                        continue

                    elif user_input == "!help":
                        interface.print_help()
                        continue

                    # `!save` alone or `!save SAVE_NAME`
                    elif user_input.startswith("!save") and len(user_input.split()) <= 2:
                        self._save_command(user_input, interface, chat, config)
                        continue

                    elif user_input.startswith("!set"):
                        self._set_command(user_input, interface, config)
                        continue

                    elif user_input.startswith("!example") and len(user_input.split()) == 2:
                        example_name = user_input.split()[1]
                        if example_name not in self.examples:
                            example_error = f"Example {example_name} not found in list of available examples: {list(self.examples.keys())}."
                            interface.print_color(text=example_error, color="red")
                            # Nothing was added to the chat: do not query the model with the unchanged history.
                            continue
                        interface.clear()
                        # Start over like `!clear` does, i.e. keeping the system prompt (if any).
                        chat = new_chat_history(self.system_prompt)
                        interface.print_user_message(self.examples[example_name]["text"])
                        chat.append({"role": "user", "content": self.examples[example_name]["text"]})
                        # No `continue`: the example is sent to the model below.

                    elif user_input == "!status":
                        interface.print_status(config=config)
                        continue

                    elif user_input.startswith("!"):
                        interface.print_color(
                            text=f"'{user_input}' is not a valid command. Showing help message.", color="red"
                        )
                        interface.print_help()
                        continue

                    else:
                        chat.append({"role": "user", "content": user_input})

                    extra_body = {
                        "generation_config": config.to_json_string(),
                        "model": self.model_id,
                    }

                    stream = client.chat_completion(
                        chat,
                        stream=True,
                        model=self.model_id,
                        extra_body=extra_body,
                    )

                    model_output, finish_reason = await interface.stream_output(stream)

                    chat.append({"role": "assistant", "content": model_output})

                    if finish_reason == "length":
                        interface.print_color("Generation stopped after reaching the token limit.", "yellow")
                        if interface.confirm("Continue generating?"):
                            pending_user_input = "Please continue. Do not repeat text."
                            continue
                except KeyboardInterrupt:
                    break
def load_generation_config(generation_config: str | None) -> GenerationConfig:
    """Builds the base `GenerationConfig` for the session.

    Args:
        generation_config: `None` for a default config; a string containing `.json` is treated as the path of a
            local file; any other string is passed to `GenerationConfig.from_pretrained` unchanged.
    """
    if generation_config is None:
        return GenerationConfig()

    looks_like_local_file = ".json" in generation_config
    if not looks_like_local_file:
        return GenerationConfig.from_pretrained(generation_config)

    # Local file: hand the containing folder and the file name to `from_pretrained` separately.
    folder, config_file = os.path.split(generation_config)
    return GenerationConfig.from_pretrained(folder, config_file)
- def parse_generate_flags(generate_flags: list[str] | None) -> dict:
- """Parses the generate flags from the user input into a dictionary of `generate` kwargs."""
- if generate_flags is None or len(generate_flags) == 0:
- return {}
- # Assumption: `generate_flags` is a list of strings, each string being a `flag=value` pair, that can be parsed
- # into a json string if we:
- # 1. Add quotes around each flag name
- generate_flags_as_dict = {'"' + flag.split("=")[0] + '"': flag.split("=")[1] for flag in generate_flags}
- # 2. Handle types:
- # 2. a. booleans should be lowercase, None should be null
- generate_flags_as_dict = {
- k: v.lower() if v.lower() in ["true", "false"] else v for k, v in generate_flags_as_dict.items()
- }
- generate_flags_as_dict = {k: "null" if v == "None" else v for k, v in generate_flags_as_dict.items()}
- # 2. b. strings should be quoted
- def is_number(s: str) -> bool:
- # handle negative numbers
- s = s.removeprefix("-")
- return s.replace(".", "", 1).isdigit()
- generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()}
- # 2. c. [no processing needed] lists are lists of ints because `generate` doesn't take lists of strings :)
- # We also mention in the help message that we only accept lists of ints for now.
- # 3. Join the result into a comma separated string
- generate_flags_string = ", ".join([f"{k}: {v}" for k, v in generate_flags_as_dict.items()])
- # 4. Add the opening/closing brackets
- generate_flags_string = "{" + generate_flags_string + "}"
- # 5. Remove quotes around boolean/null and around lists
- generate_flags_string = generate_flags_string.replace('"null"', "null")
- generate_flags_string = generate_flags_string.replace('"true"', "true")
- generate_flags_string = generate_flags_string.replace('"false"', "false")
- generate_flags_string = generate_flags_string.replace('"[', "[")
- generate_flags_string = generate_flags_string.replace(']"', "]")
- # 6. Replace the `=` with `:`
- generate_flags_string = generate_flags_string.replace("=", ":")
- try:
- processed_generate_flags = json.loads(generate_flags_string)
- except json.JSONDecodeError:
- raise ValueError(
- "Failed to convert `generate_flags` into a valid JSON object."
- "\n`generate_flags` = {generate_flags}"
- "\nConverted JSON string = {generate_flags_string}"
- )
- return processed_generate_flags
- def new_chat_history(system_prompt: str | None = None) -> list[dict]:
- """Returns a new chat conversation."""
- return [{"role": "system", "content": system_prompt}] if system_prompt else []
def save_chat(filename: str, chat: list[dict], settings: dict) -> str:
    """Saves the chat history to a file.

    The file is a JSON object with two keys: `settings` and `chat_history`. Missing parent directories are
    created.

    Args:
        filename: Destination path; may be a bare file name (saved in the current working directory).
        chat: The conversation messages.
        settings: The session settings to store next to the conversation.

    Returns:
        The absolute path of the written file.
    """
    # A bare file name has an empty directory part, and `os.makedirs("")` raises FileNotFoundError.
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(filename, "w", encoding="utf-8") as f:
        json.dump({"settings": settings, "chat_history": chat}, f, indent=4)
    return os.path.abspath(filename)
def get_username() -> str:
    """Returns the username of the current user.

    Uses `os.getlogin()` on Windows and the password database entry of the current UID elsewhere. When that
    lookup fails (e.g. a UID without a passwd entry, or no controlling terminal), falls back to the `USER` /
    `USERNAME` environment variables and finally to the literal `"user"`, since the value is only a display name.
    """
    try:
        if platform.system() == "Windows":
            return os.getlogin()
        return pwd.getpwuid(os.getuid()).pw_name
    except (KeyError, OSError):
        # KeyError: the UID is not in the password database; OSError: the login name cannot be determined.
        return os.environ.get("USER") or os.environ.get("USERNAME") or "user"
if __name__ == "__main__":
    # Running this file directly starts a chat session with a fixed model and default settings.
    demo_model_id = "meta-llama/Llama-3.2-3b-Instruct"
    Chat(model_id=demo_model_id)
|