_webhooks_server.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. # Copyright 2023-present, the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Contains `WebhooksServer` and `webhook_endpoint` to create a webhook server easily."""
  15. import atexit
  16. import inspect
  17. import os
  18. from collections.abc import Callable
  19. from functools import wraps
  20. from typing import TYPE_CHECKING, Any, Optional
  21. from .utils import experimental, is_fastapi_available, is_gradio_available
  22. if TYPE_CHECKING:
  23. import gradio as gr
  24. from fastapi import Request
  25. if is_fastapi_available():
  26. from fastapi import FastAPI, Request
  27. from fastapi.responses import JSONResponse
  28. else:
  29. # Will fail at runtime if FastAPI is not available
  30. FastAPI = Request = JSONResponse = None # type: ignore
  31. _global_app: Optional["WebhooksServer"] = None
  32. _is_local = os.environ.get("SPACE_ID") is None
  33. @experimental
  34. class WebhooksServer:
  35. """
  36. The [`WebhooksServer`] class lets you create an instance of a Gradio app that can receive Huggingface webhooks.
  37. These webhooks can be registered using the [`~WebhooksServer.add_webhook`] decorator. Webhook endpoints are added to
  38. the app as a POST endpoint to the FastAPI router. Once all the webhooks are registered, the `launch` method has to be
  39. called to start the app.
  40. It is recommended to accept [`WebhookPayload`] as the first argument of the webhook function. It is a Pydantic
  41. model that contains all the information about the webhook event. The data will be parsed automatically for you.
  42. Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to set up your
  43. WebhooksServer and deploy it on a Space.
  44. > [!WARNING]
  45. > `WebhooksServer` is experimental. Its API is subject to change in the future.
  46. > [!WARNING]
  47. > You must have `gradio` installed to use `WebhooksServer` (`pip install --upgrade gradio`).
  48. Args:
  49. ui (`gradio.Blocks`, optional):
  50. A Gradio UI instance to be used as the Space landing page. If `None`, a UI displaying instructions
  51. about the configured webhooks is created.
  52. webhook_secret (`str`, optional):
  53. A secret key to verify incoming webhook requests. You can set this value to any secret you want as long as
  54. you also configure it in your [webhooks settings panel](https://huggingface.co/settings/webhooks). You
  55. can also set this value as the `WEBHOOK_SECRET` environment variable. If no secret is provided, the
  56. webhook endpoints are opened without any security.
  57. Example:
  58. ```python
  59. import gradio as gr
  60. from huggingface_hub import WebhooksServer, WebhookPayload
  61. with gr.Blocks() as ui:
  62. ...
  63. app = WebhooksServer(ui=ui, webhook_secret="my_secret_key")
  64. @app.add_webhook("/say_hello")
  65. async def hello(payload: WebhookPayload):
  66. return {"message": "hello"}
  67. app.launch()
  68. ```
  69. """
  70. def __new__(cls, *args, **kwargs) -> "WebhooksServer":
  71. if not is_gradio_available():
  72. raise ImportError(
  73. "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio`"
  74. " first."
  75. )
  76. if not is_fastapi_available():
  77. raise ImportError(
  78. "You must have `fastapi` installed to use `WebhooksServer`. Please run `pip install --upgrade fastapi`"
  79. " first."
  80. )
  81. return super().__new__(cls)
  82. def __init__(
  83. self,
  84. ui: Optional["gr.Blocks"] = None,
  85. webhook_secret: str | None = None,
  86. ) -> None:
  87. self._ui = ui
  88. self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET")
  89. self.registered_webhooks: dict[str, Callable] = {}
  90. _warn_on_empty_secret(self.webhook_secret)
  91. def add_webhook(self, path: str | None = None) -> Callable:
  92. """
  93. Decorator to add a webhook to the [`WebhooksServer`] server.
  94. Args:
  95. path (`str`, optional):
  96. The URL path to register the webhook function. If not provided, the function name will be used as the
  97. path. In any case, all webhooks are registered under `/webhooks`.
  98. Raises:
  99. ValueError: If the provided path is already registered as a webhook.
  100. Example:
  101. ```python
  102. from huggingface_hub import WebhooksServer, WebhookPayload
  103. app = WebhooksServer()
  104. @app.add_webhook
  105. async def trigger_training(payload: WebhookPayload):
  106. if payload.repo.type == "dataset" and payload.event.action == "update":
  107. # Trigger a training job if a dataset is updated
  108. ...
  109. app.launch()
  110. ```
  111. """
  112. # Usage: directly as decorator. Example: `@app.add_webhook`
  113. if callable(path):
  114. # If path is a function, it means it was used as a decorator without arguments
  115. return self.add_webhook()(path)
  116. # Usage: provide a path. Example: `@app.add_webhook(...)`
  117. @wraps(FastAPI.post)
  118. def _inner_post(*args, **kwargs):
  119. func = args[0]
  120. abs_path = f"/webhooks/{(path or func.__name__).strip('/')}"
  121. if abs_path in self.registered_webhooks:
  122. raise ValueError(f"Webhook {abs_path} already exists.")
  123. self.registered_webhooks[abs_path] = func
  124. return _inner_post
  125. def launch(self, prevent_thread_lock: bool = False, **launch_kwargs: Any) -> None:
  126. """Launch the Gradio app and register webhooks to the underlying FastAPI server.
  127. Input parameters are forwarded to Gradio when launching the app.
  128. """
  129. ui = self._ui or self._get_default_ui()
  130. # Start Gradio App
  131. # - as non-blocking so that webhooks can be added afterwards
  132. # - as shared if launch locally (to debug webhooks)
  133. launch_kwargs.setdefault("share", _is_local)
  134. self.fastapi_app, _, _ = ui.launch(prevent_thread_lock=True, **launch_kwargs)
  135. # Register webhooks to FastAPI app
  136. for path, func in self.registered_webhooks.items():
  137. # Add secret check if required
  138. if self.webhook_secret is not None:
  139. func = _wrap_webhook_to_check_secret(func, webhook_secret=self.webhook_secret)
  140. # Add route to FastAPI app
  141. self.fastapi_app.post(path)(func)
  142. # Print instructions and block main thread
  143. space_host = os.environ.get("SPACE_HOST")
  144. url = "https://" + space_host if space_host is not None else (ui.share_url or ui.local_url)
  145. if url is None:
  146. raise ValueError("Cannot find the URL of the app. Please provide a valid `ui` or update `gradio` version.")
  147. url = url.strip("/")
  148. message = "\nWebhooks are correctly setup and ready to use:"
  149. message += "\n" + "\n".join(f" - POST {url}{webhook}" for webhook in self.registered_webhooks)
  150. message += "\nGo to https://huggingface.co/settings/webhooks to setup your webhooks."
  151. print(message)
  152. if not prevent_thread_lock:
  153. ui.block_thread()
  154. def _get_default_ui(self) -> "gr.Blocks":
  155. """Default UI if not provided (lists webhooks and provides basic instructions)."""
  156. import gradio as gr
  157. with gr.Blocks() as ui:
  158. gr.Markdown("# This is an app to process 🤗 Webhooks")
  159. gr.Markdown(
  160. "Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on"
  161. " specific repos or to all repos belonging to particular set of users/organizations (not just your"
  162. " repos, but any repo). Check out this [guide](https://huggingface.co/docs/hub/webhooks) to get to"
  163. " know more about webhooks on the Huggingface Hub."
  164. )
  165. gr.Markdown(
  166. f"{len(self.registered_webhooks)} webhook(s) are registered:"
  167. + "\n\n"
  168. + "\n ".join(
  169. f"- [{webhook_path}]({_get_webhook_doc_url(webhook.__name__, webhook_path)})"
  170. for webhook_path, webhook in self.registered_webhooks.items()
  171. )
  172. )
  173. gr.Markdown(
  174. "Go to https://huggingface.co/settings/webhooks to setup your webhooks."
  175. + "\nYou app is running locally. Please look at the logs to check the full URL you need to set."
  176. if _is_local
  177. else (
  178. "\nThis app is running on a Space. You can find the corresponding URL in the options menu"
  179. " (top-right) > 'Embed the Space'. The URL looks like 'https://{username}-{repo_name}.hf.space'."
  180. )
  181. )
  182. return ui
  183. @experimental
  184. def webhook_endpoint(path: str | None = None) -> Callable:
  185. """Decorator to start a [`WebhooksServer`] and register the decorated function as a webhook endpoint.
  186. This is a helper to get started quickly. If you need more flexibility (custom landing page or webhook secret),
  187. you can use [`WebhooksServer`] directly. You can register multiple webhook endpoints (to the same server) by using
  188. this decorator multiple times.
  189. Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to set up your
  190. server and deploy it on a Space.
  191. > [!WARNING]
  192. > `webhook_endpoint` is experimental. Its API is subject to change in the future.
  193. > [!WARNING]
  194. > You must have `gradio` installed to use `webhook_endpoint` (`pip install --upgrade gradio`).
  195. Args:
  196. path (`str`, optional):
  197. The URL path to register the webhook function. If not provided, the function name will be used as the path.
  198. In any case, all webhooks are registered under `/webhooks`.
  199. Examples:
  200. The default usage is to register a function as a webhook endpoint. The function name will be used as the path.
  201. The server will be started automatically at exit (i.e. at the end of the script).
  202. ```python
  203. from huggingface_hub import webhook_endpoint, WebhookPayload
  204. @webhook_endpoint
  205. async def trigger_training(payload: WebhookPayload):
  206. if payload.repo.type == "dataset" and payload.event.action == "update":
  207. # Trigger a training job if a dataset is updated
  208. ...
  209. # Server is automatically started at the end of the script.
  210. ```
  211. Advanced usage: register a function as a webhook endpoint and start the server manually. This is useful if you
  212. are running it in a notebook.
  213. ```python
  214. from huggingface_hub import webhook_endpoint, WebhookPayload
  215. @webhook_endpoint
  216. async def trigger_training(payload: WebhookPayload):
  217. if payload.repo.type == "dataset" and payload.event.action == "update":
  218. # Trigger a training job if a dataset is updated
  219. ...
  220. # Start the server manually
  221. trigger_training.launch()
  222. ```
  223. """
  224. if callable(path):
  225. # If path is a function, it means it was used as a decorator without arguments
  226. return webhook_endpoint()(path)
  227. @wraps(WebhooksServer.add_webhook)
  228. def _inner(func: Callable) -> Callable:
  229. app = _get_global_app()
  230. app.add_webhook(path)(func)
  231. if len(app.registered_webhooks) == 1:
  232. # Register `app.launch` to run at exit (only once)
  233. atexit.register(app.launch)
  234. @wraps(app.launch)
  235. def _launch_now():
  236. # Run the app directly (without waiting atexit)
  237. atexit.unregister(app.launch)
  238. app.launch()
  239. func.launch = _launch_now # type: ignore
  240. return func
  241. return _inner
  242. def _get_global_app() -> WebhooksServer:
  243. global _global_app
  244. if _global_app is None:
  245. _global_app = WebhooksServer()
  246. return _global_app
  247. def _warn_on_empty_secret(webhook_secret: str | None) -> None:
  248. if webhook_secret is None:
  249. print("Webhook secret is not defined. This means your webhook endpoints will be open to everyone.")
  250. print(
  251. "To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: "
  252. "\n\t`app = WebhooksServer(webhook_secret='my_secret', ...)`"
  253. )
  254. print(
  255. "For more details about webhook secrets, please refer to"
  256. " https://huggingface.co/docs/hub/webhooks#webhook-secret."
  257. )
  258. else:
  259. print("Webhook secret is correctly defined.")
  260. def _get_webhook_doc_url(webhook_name: str, webhook_path: str) -> str:
  261. """Returns the anchor to a given webhook in the docs (experimental)"""
  262. return "/docs#/default/" + webhook_name + webhook_path.replace("/", "_") + "_post"
  263. def _wrap_webhook_to_check_secret(func: Callable, webhook_secret: str) -> Callable:
  264. """Wraps a webhook function to check the webhook secret before calling the function.
  265. This is a hacky way to add the `request` parameter to the function signature. Since FastAPI based itself on route
  266. parameters to inject the values to the function, we need to hack the function signature to retrieve the `Request`
  267. object (and hence the headers). A far cleaner solution would be to use a middleware. However, since
  268. `fastapi==0.90.1`, a middleware cannot be added once the app has started. And since the FastAPI app is started by
  269. Gradio internals (and not by us), we cannot add a middleware.
  270. This method is called only when a secret has been defined by the user. If a request is sent without the
  271. "x-webhook-secret", the function will return a 401 error (unauthorized). If the header is sent but is incorrect,
  272. the function will return a 403 error (forbidden).
  273. Inspired by https://stackoverflow.com/a/33112180.
  274. """
  275. initial_sig = inspect.signature(func)
  276. @wraps(func)
  277. async def _protected_func(request: Request, **kwargs):
  278. request_secret = request.headers.get("x-webhook-secret")
  279. if request_secret is None:
  280. return JSONResponse({"error": "x-webhook-secret header not set."}, status_code=401)
  281. if request_secret != webhook_secret:
  282. return JSONResponse({"error": "Invalid webhook secret."}, status_code=403)
  283. # Inject `request` in kwargs if required
  284. if "request" in initial_sig.parameters:
  285. kwargs["request"] = request
  286. # Handle both sync and async routes
  287. if inspect.iscoroutinefunction(func):
  288. return await func(**kwargs)
  289. else:
  290. return func(**kwargs)
  291. # Update signature to include request
  292. if "request" not in initial_sig.parameters:
  293. _protected_func.__signature__ = initial_sig.replace( # type: ignore
  294. parameters=(
  295. inspect.Parameter(name="request", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request),
  296. )
  297. + tuple(initial_sig.parameters.values())
  298. )
  299. # Return protected route
  300. return _protected_func