| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
- import asyncio
- import os
- import signal
- import traceback
- from typing import Optional
- import typer
- from ...utils import ANSI
- from ._cli_hacks import _async_prompt, _patch_anyio_open_process
- from .agent import Agent
- from .utils import _load_agent_config
# Root Typer application exposed by this module.
app = typer.Typer(
    help="A squad of lightweight composable AI applications built on Hugging Face's Inference Client and MCP stack.",
    rich_markup_mode="rich",
)

# Sub-application holding the `run` command. `invoke_without_command=True` lets its
# callback execute even when no further sub-command is given on the command line.
run_cli = typer.Typer(
    invoke_without_command=True,
    help="Run the Agent in the CLI",
    name="run",
)

# Mount the sub-application on the root application under the `run` name.
app.add_typer(run_cli, name="run")
def _placeholder_mapping(server: dict) -> dict:
    """
    Return the mapping of `server` in which `${input:<id>}` placeholders may appear.

    This is the `"env"` mapping for servers of type `"stdio"` and the `"headers"` mapping for every other server
    type. An empty dict is returned when the server does not define that mapping.
    """
    return server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})


async def run_agent(
    agent_path: Optional[str],
) -> None:
    """
    Tiny Agent loop.

    Loads the agent configuration, resolves the `${input:<id>}` placeholders it declares (by prompting the user,
    falling back on environment variables when the answer is left blank), then runs an interactive loop that sends
    each user prompt to the agent and prints the streamed chunks.

    Ctrl+C handling: a first SIGINT sets an abort event (also handed to `agent.run`), a second one sets an exit
    event which makes this coroutine return. The SIGINT handling installed here is undone before returning.

    Args:
        agent_path (`str`, *optional*):
            Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` or `AGENTS.md` file or a built-in agent stored in a Hugging Face dataset.

    Raises:
        `Exception`: any exception escaping the setup or the agent context is printed with its traceback and
        re-raised. Errors raised while running a single prompt are printed and the loop continues.
    """
    _patch_anyio_open_process()  # Hacky way to prevent stdio connections to be stopped by Ctrl+C

    config, prompt = _load_agent_config(agent_path)

    inputs = config.get("inputs", [])
    servers = config.get("servers", [])

    abort_event = asyncio.Event()
    exit_event = asyncio.Event()
    first_sigint = True

    loop = asyncio.get_running_loop()
    original_sigint_handler = signal.getsignal(signal.SIGINT)
    # Initialised before the `try` so the `finally` clause can always read it.
    sigint_registered_in_loop = False

    def _sigint_handler() -> None:
        nonlocal first_sigint
        if first_sigint:
            # First Ctrl+C: only request the abortion of what is currently running.
            first_sigint = False
            abort_event.set()
            print(ANSI.red("\nInterrupted. Press Ctrl+C again to quit."), flush=True)
            return

        # Second Ctrl+C: request a full exit.
        print(ANSI.red("\nExiting..."), flush=True)
        exit_event.set()

    try:
        try:
            loop.add_signal_handler(signal.SIGINT, _sigint_handler)
            sigint_registered_in_loop = True
        except (AttributeError, NotImplementedError):
            # Windows (or any loop that doesn't support it) : fall back to sync
            signal.signal(signal.SIGINT, lambda *_: _sigint_handler())

        # Handle inputs (i.e. env variables injection)
        resolved_inputs: dict[str, str] = {}

        if inputs:
            print(
                ANSI.bold(
                    ANSI.blue(
                        "Some initial inputs are required by the agent. "
                        "Please provide a value or leave empty to load from env."
                    )
                )
            )
            for input_item in inputs:
                input_id = input_item["id"]
                description = input_item["description"]
                env_special_value = f"${{input:{input_id}}}"

                # Check if the input is used by any server or as an apiKey
                input_usages = set()
                for server in servers:
                    # Check stdio's "env" and http/sse's "headers" mappings
                    for key, value in _placeholder_mapping(server).items():
                        # Non-string values cannot contain a placeholder (and `in` would raise on them).
                        if isinstance(value, str) and env_special_value in value:
                            input_usages.add(key)

                raw_api_key = config.get("apiKey")
                if isinstance(raw_api_key, str) and env_special_value in raw_api_key:
                    input_usages.add("apiKey")

                if not input_usages:
                    print(
                        ANSI.yellow(
                            f"Input '{input_id}' defined in config but not used by any server or as an API key."
                            " Skipping."
                        )
                    )
                    continue

                # Prompt user for input
                env_variable_key = input_id.replace("-", "_").upper()
                print(
                    ANSI.blue(f" • {input_id}") + f": {description}. (default: load from {env_variable_key}).",
                    end=" ",
                )
                user_input = (await _async_prompt(exit_event=exit_event)).strip()
                if exit_event.is_set():
                    return

                # Fallback to environment variable when user left blank
                final_value = user_input
                if not final_value:
                    final_value = os.getenv(env_variable_key, "")
                    if final_value:
                        print(ANSI.green(f"Value successfully loaded from '{env_variable_key}'"))
                    else:
                        print(
                            ANSI.yellow(
                                f"No value found for '{env_variable_key}' in environment variables. Continuing."
                            )
                        )
                resolved_inputs[input_id] = final_value

                # Inject resolved value (can be empty) into stdio's env or http/sse's headers
                for server in servers:
                    env_or_headers = _placeholder_mapping(server)
                    # Only existing keys are reassigned, so the mapping size is stable while iterating.
                    for key, value in env_or_headers.items():
                        if isinstance(value, str) and env_special_value in value:
                            env_or_headers[key] = value.replace(env_special_value, final_value)

            print()

        # Substitute every resolved input into the API key (the key may reference several inputs).
        raw_api_key = config.get("apiKey")
        if isinstance(raw_api_key, str):
            substituted_api_key = raw_api_key
            for input_id, val in resolved_inputs.items():
                substituted_api_key = substituted_api_key.replace(f"${{input:{input_id}}}", val)
            config["apiKey"] = substituted_api_key

        # Main agent loop
        async with Agent(
            provider=config.get("provider"),  # type: ignore
            model=config.get("model"),
            base_url=config.get("endpointUrl"),  # type: ignore[arg-type]
            api_key=config.get("apiKey"),
            servers=servers,  # type: ignore[arg-type]
            prompt=prompt,
        ) as agent:
            await agent.load_tools()
            print(ANSI.bold(ANSI.blue(f"Agent loaded with {len(agent.available_tools)} tools:")))
            for t in agent.available_tools:
                print(ANSI.blue(f" • {t.function.name}"))

            while True:
                abort_event.clear()

                # Check if we should exit
                if exit_event.is_set():
                    return

                try:
                    user_input = await _async_prompt(exit_event=exit_event)
                    # A prompt was read: the next Ctrl+C counts as a "first" one again.
                    first_sigint = True
                except EOFError:
                    print(ANSI.red("\nEOF received, exiting."), flush=True)
                    break
                except KeyboardInterrupt:
                    if not first_sigint and abort_event.is_set():
                        continue
                    else:
                        print(ANSI.red("\nKeyboard interrupt during input processing."), flush=True)
                        break

                try:
                    async for chunk in agent.run(user_input, abort_event=abort_event):
                        if abort_event.is_set() and not first_sigint:
                            break
                        if exit_event.is_set():
                            return

                        if hasattr(chunk, "choices"):
                            if not chunk.choices:
                                # A chunk without any choice has no delta to print; skip it
                                # instead of failing on `choices[0]`.
                                continue
                            delta = chunk.choices[0].delta
                            if delta.content:
                                print(delta.content, end="", flush=True)
                            if delta.tool_calls:
                                for call in delta.tool_calls:
                                    if call.id:
                                        print(f"<Tool {call.id}>", end="")
                                    if call.function.name:
                                        print(f"{call.function.name}", end=" ")
                                    if call.function.arguments:
                                        print(f"{call.function.arguments}", end="")
                        else:
                            # Chunks without a `choices` attribute are printed as tool results.
                            print(
                                ANSI.green(f"\n\nTool[{chunk.name}] {chunk.tool_call_id}\n{chunk.content}\n"),
                                flush=True,
                            )

                    print()

                except Exception as e:
                    # An error on one prompt must not kill the session: report it and keep looping.
                    tb_str = traceback.format_exc()
                    print(ANSI.red(f"\nError during agent run: {e}\n{tb_str}"), flush=True)
                    first_sigint = True  # Allow graceful interrupt for the next command

    except Exception as e:
        tb_str = traceback.format_exc()
        print(ANSI.red(f"\nAn unexpected error occurred: {e}\n{tb_str}"), flush=True)
        raise

    finally:
        if sigint_registered_in_loop:
            try:
                loop.remove_signal_handler(signal.SIGINT)
            except (AttributeError, NotImplementedError):
                # Best-effort cleanup: nothing to undo if the loop cannot remove handlers.
                pass
        elif original_sigint_handler is not None:
            # `signal.getsignal` returns None when the previous handler was not installed from
            # Python; such a value cannot be handed back to `signal.signal`.
            signal.signal(signal.SIGINT, original_sigint_handler)
# Callback of the `run` sub-application: runs the agent loop to completion in a fresh
# asyncio event loop and reports how the session ended.
@run_cli.callback()
def run(
    path: Optional[str] = typer.Argument(
        None,
        help=(
            "Path to a local folder containing an agent.json file or a built-in agent "
            "stored in the 'tiny-agents/tiny-agents' Hugging Face dataset "
            "(https://huggingface.co/datasets/tiny-agents/tiny-agents)"
        ),
        show_default=False,
    ),
):
    try:
        agent_session = run_agent(path)
        asyncio.run(agent_session)
    except KeyboardInterrupt:
        # 130 is the conventional shell exit status for a Ctrl+C termination (128 + SIGINT).
        interrupted_message = ANSI.red("\nApplication terminated by KeyboardInterrupt.")
        print(interrupted_message, flush=True)
        raise typer.Exit(code=130)
    except Exception as error:
        # Report the failure to the user, then let it propagate unchanged.
        failure_message = ANSI.red(f"\nAn unexpected error occurred: {error}")
        print(failure_message, flush=True)
        raise error
# Script entry point: invoke the Typer application only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    app()
|