cli.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. import asyncio
  2. import os
  3. import signal
  4. import traceback
  5. from typing import Optional
  6. import typer
  7. from ...utils import ANSI
  8. from ._cli_hacks import _async_prompt, _patch_anyio_open_process
  9. from .agent import Agent
  10. from .utils import _load_agent_config
  11. app = typer.Typer(
  12. rich_markup_mode="rich",
  13. help="A squad of lightweight composable AI applications built on Hugging Face's Inference Client and MCP stack.",
  14. )
  15. run_cli = typer.Typer(
  16. name="run",
  17. help="Run the Agent in the CLI",
  18. invoke_without_command=True,
  19. )
  20. app.add_typer(run_cli, name="run")
  21. async def run_agent(
  22. agent_path: Optional[str],
  23. ) -> None:
  24. """
  25. Tiny Agent loop.
  26. Args:
  27. agent_path (`str`, *optional*):
  28. Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` or `AGENTS.md` file or a built-in agent stored in a Hugging Face dataset.
  29. """
  30. _patch_anyio_open_process() # Hacky way to prevent stdio connections to be stopped by Ctrl+C
  31. config, prompt = _load_agent_config(agent_path)
  32. inputs = config.get("inputs", [])
  33. servers = config.get("servers", [])
  34. abort_event = asyncio.Event()
  35. exit_event = asyncio.Event()
  36. first_sigint = True
  37. loop = asyncio.get_running_loop()
  38. original_sigint_handler = signal.getsignal(signal.SIGINT)
  39. def _sigint_handler() -> None:
  40. nonlocal first_sigint
  41. if first_sigint:
  42. first_sigint = False
  43. abort_event.set()
  44. print(ANSI.red("\nInterrupted. Press Ctrl+C again to quit."), flush=True)
  45. return
  46. print(ANSI.red("\nExiting..."), flush=True)
  47. exit_event.set()
  48. try:
  49. sigint_registered_in_loop = False
  50. try:
  51. loop.add_signal_handler(signal.SIGINT, _sigint_handler)
  52. sigint_registered_in_loop = True
  53. except (AttributeError, NotImplementedError):
  54. # Windows (or any loop that doesn't support it) : fall back to sync
  55. signal.signal(signal.SIGINT, lambda *_: _sigint_handler())
  56. # Handle inputs (i.e. env variables injection)
  57. resolved_inputs: dict[str, str] = {}
  58. if len(inputs) > 0:
  59. print(
  60. ANSI.bold(
  61. ANSI.blue(
  62. "Some initial inputs are required by the agent. "
  63. "Please provide a value or leave empty to load from env."
  64. )
  65. )
  66. )
  67. for input_item in inputs:
  68. input_id = input_item["id"]
  69. description = input_item["description"]
  70. env_special_value = f"${{input:{input_id}}}"
  71. # Check if the input is used by any server or as an apiKey
  72. input_usages = set()
  73. for server in servers:
  74. # Check stdio's "env" and http/sse's "headers" mappings
  75. env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})
  76. for key, value in env_or_headers.items():
  77. if env_special_value in value:
  78. input_usages.add(key)
  79. raw_api_key = config.get("apiKey")
  80. if isinstance(raw_api_key, str) and env_special_value in raw_api_key:
  81. input_usages.add("apiKey")
  82. if not input_usages:
  83. print(
  84. ANSI.yellow(
  85. f"Input '{input_id}' defined in config but not used by any server or as an API key."
  86. " Skipping."
  87. )
  88. )
  89. continue
  90. # Prompt user for input
  91. env_variable_key = input_id.replace("-", "_").upper()
  92. print(
  93. ANSI.blue(f" • {input_id}") + f": {description}. (default: load from {env_variable_key}).",
  94. end=" ",
  95. )
  96. user_input = (await _async_prompt(exit_event=exit_event)).strip()
  97. if exit_event.is_set():
  98. return
  99. # Fallback to environment variable when user left blank
  100. final_value = user_input
  101. if not final_value:
  102. final_value = os.getenv(env_variable_key, "")
  103. if final_value:
  104. print(ANSI.green(f"Value successfully loaded from '{env_variable_key}'"))
  105. else:
  106. print(
  107. ANSI.yellow(
  108. f"No value found for '{env_variable_key}' in environment variables. Continuing."
  109. )
  110. )
  111. resolved_inputs[input_id] = final_value
  112. # Inject resolved value (can be empty) into stdio's env or http/sse's headers
  113. for server in servers:
  114. env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})
  115. for key, value in env_or_headers.items():
  116. if env_special_value in value:
  117. env_or_headers[key] = env_or_headers[key].replace(env_special_value, final_value)
  118. print()
  119. raw_api_key = config.get("apiKey")
  120. if isinstance(raw_api_key, str):
  121. substituted_api_key = raw_api_key
  122. for input_id, val in resolved_inputs.items():
  123. substituted_api_key = substituted_api_key.replace(f"${{input:{input_id}}}", val)
  124. config["apiKey"] = substituted_api_key
  125. # Main agent loop
  126. async with Agent(
  127. provider=config.get("provider"), # type: ignore
  128. model=config.get("model"),
  129. base_url=config.get("endpointUrl"), # type: ignore[arg-type]
  130. api_key=config.get("apiKey"),
  131. servers=servers, # type: ignore[arg-type]
  132. prompt=prompt,
  133. ) as agent:
  134. await agent.load_tools()
  135. print(ANSI.bold(ANSI.blue("Agent loaded with {} tools:".format(len(agent.available_tools)))))
  136. for t in agent.available_tools:
  137. print(ANSI.blue(f" • {t.function.name}"))
  138. while True:
  139. abort_event.clear()
  140. # Check if we should exit
  141. if exit_event.is_set():
  142. return
  143. try:
  144. user_input = await _async_prompt(exit_event=exit_event)
  145. first_sigint = True
  146. except EOFError:
  147. print(ANSI.red("\nEOF received, exiting."), flush=True)
  148. break
  149. except KeyboardInterrupt:
  150. if not first_sigint and abort_event.is_set():
  151. continue
  152. else:
  153. print(ANSI.red("\nKeyboard interrupt during input processing."), flush=True)
  154. break
  155. try:
  156. async for chunk in agent.run(user_input, abort_event=abort_event):
  157. if abort_event.is_set() and not first_sigint:
  158. break
  159. if exit_event.is_set():
  160. return
  161. if hasattr(chunk, "choices"):
  162. delta = chunk.choices[0].delta
  163. if delta.content:
  164. print(delta.content, end="", flush=True)
  165. if delta.tool_calls:
  166. for call in delta.tool_calls:
  167. if call.id:
  168. print(f"<Tool {call.id}>", end="")
  169. if call.function.name:
  170. print(f"{call.function.name}", end=" ")
  171. if call.function.arguments:
  172. print(f"{call.function.arguments}", end="")
  173. else:
  174. print(
  175. ANSI.green(f"\n\nTool[{chunk.name}] {chunk.tool_call_id}\n{chunk.content}\n"),
  176. flush=True,
  177. )
  178. print()
  179. except Exception as e:
  180. tb_str = traceback.format_exc()
  181. print(ANSI.red(f"\nError during agent run: {e}\n{tb_str}"), flush=True)
  182. first_sigint = True # Allow graceful interrupt for the next command
  183. except Exception as e:
  184. tb_str = traceback.format_exc()
  185. print(ANSI.red(f"\nAn unexpected error occurred: {e}\n{tb_str}"), flush=True)
  186. raise e
  187. finally:
  188. if sigint_registered_in_loop:
  189. try:
  190. loop.remove_signal_handler(signal.SIGINT)
  191. except (AttributeError, NotImplementedError):
  192. pass
  193. else:
  194. signal.signal(signal.SIGINT, original_sigint_handler)
  195. @run_cli.callback()
  196. def run(
  197. path: Optional[str] = typer.Argument(
  198. None,
  199. help=(
  200. "Path to a local folder containing an agent.json file or a built-in agent "
  201. "stored in the 'tiny-agents/tiny-agents' Hugging Face dataset "
  202. "(https://huggingface.co/datasets/tiny-agents/tiny-agents)"
  203. ),
  204. show_default=False,
  205. ),
  206. ):
  207. try:
  208. asyncio.run(run_agent(path))
  209. except KeyboardInterrupt:
  210. print(ANSI.red("\nApplication terminated by KeyboardInterrupt."), flush=True)
  211. raise typer.Exit(code=130)
  212. except Exception as e:
  213. print(ANSI.red(f"\nAn unexpected error occurred: {e}"), flush=True)
  214. raise e
  215. if __name__ == "__main__":
  216. app()