_inference_endpoints.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. import time
  2. from dataclasses import dataclass, field
  3. from datetime import datetime
  4. from enum import Enum
  5. from typing import TYPE_CHECKING, Optional
  6. from huggingface_hub.errors import InferenceEndpointError, InferenceEndpointTimeoutError
  7. from .utils import get_session, logging, parse_datetime
  8. if TYPE_CHECKING:
  9. from .hf_api import HfApi
  10. from .inference._client import InferenceClient
  11. from .inference._generated._async_client import AsyncInferenceClient
  12. logger = logging.get_logger(__name__)
  13. class InferenceEndpointStatus(str, Enum):
  14. PENDING = "pending"
  15. INITIALIZING = "initializing"
  16. UPDATING = "updating"
  17. UPDATE_FAILED = "updateFailed"
  18. RUNNING = "running"
  19. PAUSED = "paused"
  20. FAILED = "failed"
  21. SCALED_TO_ZERO = "scaledToZero"
  22. class InferenceEndpointType(str, Enum):
  23. PUBlIC = "public"
  24. PROTECTED = "protected"
  25. PRIVATE = "private"
  26. class InferenceEndpointScalingMetric(str, Enum):
  27. PENDING_REQUESTS = "pendingRequests"
  28. HARDWARE_USAGE = "hardwareUsage"
  29. @dataclass
  30. class InferenceEndpoint:
  31. """
  32. Contains information about a deployed Inference Endpoint.
  33. Args:
  34. name (`str`):
  35. The unique name of the Inference Endpoint.
  36. namespace (`str`):
  37. The namespace where the Inference Endpoint is located.
  38. repository (`str`):
  39. The name of the model repository deployed on this Inference Endpoint.
  40. status ([`InferenceEndpointStatus`]):
  41. The current status of the Inference Endpoint.
  42. url (`str`, *optional*):
  43. The URL of the Inference Endpoint, if available. Only a deployed Inference Endpoint will have a URL.
  44. framework (`str`):
  45. The machine learning framework used for the model.
  46. revision (`str`):
  47. The specific model revision deployed on the Inference Endpoint.
  48. task (`str`):
  49. The task associated with the deployed model.
  50. created_at (`datetime.datetime`):
  51. The timestamp when the Inference Endpoint was created.
  52. updated_at (`datetime.datetime`):
  53. The timestamp of the last update of the Inference Endpoint.
  54. type ([`InferenceEndpointType`]):
  55. The type of the Inference Endpoint (public, protected, private).
  56. raw (`dict`):
  57. The raw dictionary data returned from the API.
  58. token (`str` or `bool`, *optional*):
  59. Authentication token for the Inference Endpoint, if set when requesting the API. Will default to the
  60. locally saved token if not provided. Pass `token=False` if you don't want to send your token to the server.
  61. Example:
  62. ```python
  63. >>> from huggingface_hub import get_inference_endpoint
  64. >>> endpoint = get_inference_endpoint("my-text-to-image")
  65. >>> endpoint
  66. InferenceEndpoint(name='my-text-to-image', ...)
  67. # Get status
  68. >>> endpoint.status
  69. 'running'
  70. >>> endpoint.url
  71. 'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud'
  72. # Run inference
  73. >>> endpoint.client.text_to_image(...)
  74. # Pause endpoint to save $$$
  75. >>> endpoint.pause()
  76. # ...
  77. # Resume and wait for deployment
  78. >>> endpoint.resume()
  79. >>> endpoint.wait()
  80. >>> endpoint.client.text_to_image(...)
  81. ```
  82. """
  83. # Field in __repr__
  84. name: str = field(init=False)
  85. namespace: str
  86. repository: str = field(init=False)
  87. status: InferenceEndpointStatus = field(init=False)
  88. health_route: str = field(init=False)
  89. url: str | None = field(init=False)
  90. # Other fields
  91. framework: str = field(repr=False, init=False)
  92. revision: str = field(repr=False, init=False)
  93. task: str = field(repr=False, init=False)
  94. created_at: datetime = field(repr=False, init=False)
  95. updated_at: datetime = field(repr=False, init=False)
  96. type: InferenceEndpointType = field(repr=False, init=False)
  97. # Raw dict from the API
  98. raw: dict = field(repr=False)
  99. # Internal fields
  100. _token: str | bool | None = field(repr=False, compare=False)
  101. _api: "HfApi" = field(repr=False, compare=False)
  102. @classmethod
  103. def from_raw(
  104. cls, raw: dict, namespace: str, token: str | bool | None = None, api: Optional["HfApi"] = None
  105. ) -> "InferenceEndpoint":
  106. """Initialize object from raw dictionary."""
  107. if api is None:
  108. from .hf_api import HfApi
  109. api = HfApi()
  110. if token is None:
  111. token = api.token
  112. # All other fields are populated in __post_init__
  113. return cls(raw=raw, namespace=namespace, _token=token, _api=api)
  114. def __post_init__(self) -> None:
  115. """Populate fields from raw dictionary."""
  116. self._populate_from_raw()
  117. @property
  118. def client(self) -> "InferenceClient":
  119. """Returns a client to make predictions on this Inference Endpoint.
  120. Returns:
  121. [`InferenceClient`]: an inference client pointing to the deployed endpoint.
  122. Raises:
  123. [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
  124. """
  125. if self.url is None:
  126. raise InferenceEndpointError(
  127. "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
  128. "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
  129. )
  130. from .inference._client import InferenceClient
  131. return InferenceClient(
  132. model=self.url,
  133. token=self._token, # type: ignore # boolean token shouldn't be possible. In practice it's ok.
  134. )
  135. @property
  136. def async_client(self) -> "AsyncInferenceClient":
  137. """Returns a client to make predictions on this Inference Endpoint.
  138. Returns:
  139. [`AsyncInferenceClient`]: an asyncio-compatible inference client pointing to the deployed endpoint.
  140. Raises:
  141. [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
  142. """
  143. if self.url is None:
  144. raise InferenceEndpointError(
  145. "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
  146. "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
  147. )
  148. from .inference._generated._async_client import AsyncInferenceClient
  149. return AsyncInferenceClient(
  150. model=self.url,
  151. token=self._token, # type: ignore # boolean token shouldn't be possible. In practice it's ok.
  152. )
  153. def wait(self, timeout: int | None = None, refresh_every: int = 5) -> "InferenceEndpoint":
  154. """Wait for the Inference Endpoint to be deployed.
  155. Information from the server will be fetched every 1s. If the Inference Endpoint is not deployed after `timeout`
  156. seconds, a [`InferenceEndpointTimeoutError`] will be raised. The [`InferenceEndpoint`] will be mutated in place with the latest
  157. data.
  158. Args:
  159. timeout (`int`, *optional*):
  160. The maximum time to wait for the Inference Endpoint to be deployed, in seconds. If `None`, will wait
  161. indefinitely.
  162. refresh_every (`int`, *optional*):
  163. The time to wait between each fetch of the Inference Endpoint status, in seconds. Defaults to 5s.
  164. Returns:
  165. [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
  166. Raises:
  167. [`InferenceEndpointError`]
  168. If the Inference Endpoint ended up in a failed state.
  169. [`InferenceEndpointTimeoutError`]
  170. If the Inference Endpoint is not deployed after `timeout` seconds.
  171. """
  172. if timeout is not None and timeout < 0:
  173. raise ValueError("`timeout` cannot be negative.")
  174. if refresh_every <= 0:
  175. raise ValueError("`refresh_every` must be positive.")
  176. start = time.time()
  177. while True:
  178. if self.status == InferenceEndpointStatus.FAILED:
  179. raise InferenceEndpointError(
  180. f"Inference Endpoint {self.name} failed to deploy. Please check the logs for more information."
  181. )
  182. if self.status == InferenceEndpointStatus.UPDATE_FAILED:
  183. raise InferenceEndpointError(
  184. f"Inference Endpoint {self.name} failed to update. Please check the logs for more information."
  185. )
  186. if self.status == InferenceEndpointStatus.RUNNING and self.url is not None:
  187. # Verify the endpoint is actually reachable
  188. _health_url = f"{self.url.rstrip('/')}/{self.health_route.lstrip('/')}"
  189. response = get_session().get(_health_url, headers=self._api._build_hf_headers(token=self._token))
  190. if response.status_code == 200:
  191. logger.info("Inference Endpoint is ready to be used.")
  192. return self
  193. if timeout is not None:
  194. if time.time() - start > timeout:
  195. raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.")
  196. logger.info(f"Inference Endpoint is not deployed yet ({self.status}). Waiting {refresh_every}s...")
  197. time.sleep(refresh_every)
  198. self.fetch()
  199. def fetch(self) -> "InferenceEndpoint":
  200. """Fetch latest information about the Inference Endpoint.
  201. Returns:
  202. [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
  203. """
  204. obj = self._api.get_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type]
  205. self.raw = obj.raw
  206. self._populate_from_raw()
  207. return self
  208. def update(
  209. self,
  210. *,
  211. # Compute update
  212. accelerator: str | None = None,
  213. instance_size: str | None = None,
  214. instance_type: str | None = None,
  215. min_replica: int | None = None,
  216. max_replica: int | None = None,
  217. scale_to_zero_timeout: int | None = None,
  218. # Model update
  219. repository: str | None = None,
  220. framework: str | None = None,
  221. revision: str | None = None,
  222. task: str | None = None,
  223. custom_image: dict | None = None,
  224. secrets: dict[str, str] | None = None,
  225. ) -> "InferenceEndpoint":
  226. """Update the Inference Endpoint.
  227. This method allows the update of either the compute configuration, the deployed model, or both. All arguments are
  228. optional but at least one must be provided.
  229. This is an alias for [`HfApi.update_inference_endpoint`]. The current object is mutated in place with the
  230. latest data from the server.
  231. Args:
  232. accelerator (`str`, *optional*):
  233. The hardware accelerator to be used for inference (e.g. `"cpu"`).
  234. instance_size (`str`, *optional*):
  235. The size or type of the instance to be used for hosting the model (e.g. `"x4"`).
  236. instance_type (`str`, *optional*):
  237. The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`).
  238. min_replica (`int`, *optional*):
  239. The minimum number of replicas (instances) to keep running for the Inference Endpoint.
  240. max_replica (`int`, *optional*):
  241. The maximum number of replicas (instances) to scale to for the Inference Endpoint.
  242. scale_to_zero_timeout (`int`, *optional*):
  243. The duration in minutes before an inactive endpoint is scaled to zero.
  244. repository (`str`, *optional*):
  245. The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
  246. framework (`str`, *optional*):
  247. The machine learning framework used for the model (e.g. `"custom"`).
  248. revision (`str`, *optional*):
  249. The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
  250. task (`str`, *optional*):
  251. The task on which to deploy the model (e.g. `"text-classification"`).
  252. custom_image (`dict`, *optional*):
  253. A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an
  254. Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples).
  255. secrets (`dict[str, str]`, *optional*):
  256. Secret values to inject in the container environment.
  257. Returns:
  258. [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
  259. """
  260. # Make API call
  261. obj = self._api.update_inference_endpoint(
  262. name=self.name,
  263. namespace=self.namespace,
  264. accelerator=accelerator,
  265. instance_size=instance_size,
  266. instance_type=instance_type,
  267. min_replica=min_replica,
  268. max_replica=max_replica,
  269. scale_to_zero_timeout=scale_to_zero_timeout,
  270. repository=repository,
  271. framework=framework,
  272. revision=revision,
  273. task=task,
  274. custom_image=custom_image,
  275. secrets=secrets,
  276. token=self._token, # type: ignore [arg-type]
  277. )
  278. # Mutate current object
  279. self.raw = obj.raw
  280. self._populate_from_raw()
  281. return self
  282. def pause(self) -> "InferenceEndpoint":
  283. """Pause the Inference Endpoint.
  284. A paused Inference Endpoint will not be charged. It can be resumed at any time using [`InferenceEndpoint.resume`].
  285. This is different from scaling the Inference Endpoint to zero with [`InferenceEndpoint.scale_to_zero`], which
  286. would be automatically restarted when a request is made to it.
  287. This is an alias for [`HfApi.pause_inference_endpoint`]. The current object is mutated in place with the
  288. latest data from the server.
  289. Returns:
  290. [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
  291. """
  292. obj = self._api.pause_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type]
  293. self.raw = obj.raw
  294. self._populate_from_raw()
  295. return self
  296. def resume(self, running_ok: bool = True) -> "InferenceEndpoint":
  297. """Resume the Inference Endpoint.
  298. This is an alias for [`HfApi.resume_inference_endpoint`]. The current object is mutated in place with the
  299. latest data from the server.
  300. Args:
  301. running_ok (`bool`, *optional*):
  302. If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to
  303. `True`.
  304. Returns:
  305. [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
  306. """
  307. obj = self._api.resume_inference_endpoint(
  308. name=self.name, namespace=self.namespace, running_ok=running_ok, token=self._token
  309. ) # type: ignore [arg-type]
  310. self.raw = obj.raw
  311. self._populate_from_raw()
  312. return self
  313. def scale_to_zero(self) -> "InferenceEndpoint":
  314. """Scale Inference Endpoint to zero.
  315. An Inference Endpoint scaled to zero will not be charged. It will be resumed on the next request to it, with a
  316. cold start delay. This is different from pausing the Inference Endpoint with [`InferenceEndpoint.pause`], which
  317. would require a manual resume with [`InferenceEndpoint.resume`].
  318. This is an alias for [`HfApi.scale_to_zero_inference_endpoint`]. The current object is mutated in place with the
  319. latest data from the server.
  320. Returns:
  321. [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
  322. """
  323. obj = self._api.scale_to_zero_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type]
  324. self.raw = obj.raw
  325. self._populate_from_raw()
  326. return self
  327. def delete(self) -> None:
  328. """Delete the Inference Endpoint.
  329. This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable
  330. to pause it with [`InferenceEndpoint.pause`] or scale it to zero with [`InferenceEndpoint.scale_to_zero`].
  331. This is an alias for [`HfApi.delete_inference_endpoint`].
  332. """
  333. self._api.delete_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type]
  334. def _populate_from_raw(self) -> None:
  335. """Populate fields from raw dictionary.
  336. Called in __post_init__ + each time the Inference Endpoint is updated.
  337. """
  338. # Repr fields
  339. self.name = self.raw["name"]
  340. self.repository = self.raw["model"]["repository"]
  341. self.status = self.raw["status"]["state"]
  342. self.url = self.raw["status"].get("url")
  343. self.health_route = self.raw["healthRoute"]
  344. # Other fields
  345. self.framework = self.raw["model"]["framework"]
  346. self.revision = self.raw["model"]["revision"]
  347. self.task = self.raw["model"]["task"]
  348. self.created_at = parse_datetime(self.raw["status"]["createdAt"])
  349. self.updated_at = parse_datetime(self.raw["status"]["updatedAt"])
  350. self.type = self.raw["type"]