| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418 |
- import time
- from dataclasses import dataclass, field
- from datetime import datetime
- from enum import Enum
- from typing import TYPE_CHECKING, Optional
- from huggingface_hub.errors import InferenceEndpointError, InferenceEndpointTimeoutError
- from .utils import get_session, logging, parse_datetime
- if TYPE_CHECKING:
- from .hf_api import HfApi
- from .inference._client import InferenceClient
- from .inference._generated._async_client import AsyncInferenceClient
# Module-level logger from the package's own `logging` helper (imported from `.utils`);
# used by `InferenceEndpoint.wait` to report deployment progress.
logger = logging.get_logger(__name__)
class InferenceEndpointStatus(str, Enum):
    """Possible values of `InferenceEndpoint.status`, i.e. the `raw["status"]["state"]` string sent by the API.

    Members subclass `str`, so they compare equal to the raw string values (e.g. `"running"`).
    """

    # Transitional states
    PENDING = "pending"
    INITIALIZING = "initializing"
    UPDATING = "updating"
    # Terminal error state after an update (`InferenceEndpoint.wait` raises on it)
    UPDATE_FAILED = "updateFailed"
    # Operational / idle states
    RUNNING = "running"
    PAUSED = "paused"
    # Terminal error state after a deployment (`InferenceEndpoint.wait` raises on it)
    FAILED = "failed"
    # Set via `InferenceEndpoint.scale_to_zero`
    SCALED_TO_ZERO = "scaledToZero"
class InferenceEndpointType(str, Enum):
    """Visibility type of an Inference Endpoint (public, protected, private).

    Members subclass `str`, so they compare equal to the raw string values returned by the API.
    """

    # NOTE: the historical member name is misspelled ("PUBlIC", lowercase L). It is kept as the canonical
    # member so existing code relying on it (including `.name`) keeps working.
    PUBlIC = "public"
    # Correctly spelled alias: same value, so Enum makes it an alias of `PUBlIC` (same object, not a new member).
    PUBLIC = "public"
    PROTECTED = "protected"
    PRIVATE = "private"
class InferenceEndpointScalingMetric(str, Enum):
    """String-valued scaling metric identifiers.

    Not referenced elsewhere in this module; presumably the metric used to drive replica autoscaling
    (NOTE(review): confirm against the API that consumes it).
    """

    PENDING_REQUESTS = "pendingRequests"
    HARDWARE_USAGE = "hardwareUsage"
@dataclass
class InferenceEndpoint:
    """
    Contains information about a deployed Inference Endpoint.

    Args:
        name (`str`):
            The unique name of the Inference Endpoint.
        namespace (`str`):
            The namespace where the Inference Endpoint is located.
        repository (`str`):
            The name of the model repository deployed on this Inference Endpoint.
        status ([`InferenceEndpointStatus`]):
            The current status of the Inference Endpoint.
        health_route (`str`):
            The route, relative to `url`, that [`InferenceEndpoint.wait`] queries to check that the Inference
            Endpoint is actually reachable.
        url (`str`, *optional*):
            The URL of the Inference Endpoint, if available. Only a deployed Inference Endpoint will have a URL.
        framework (`str`):
            The machine learning framework used for the model.
        revision (`str`):
            The specific model revision deployed on the Inference Endpoint.
        task (`str`):
            The task associated with the deployed model.
        created_at (`datetime.datetime`):
            The timestamp when the Inference Endpoint was created.
        updated_at (`datetime.datetime`):
            The timestamp of the last update of the Inference Endpoint.
        type ([`InferenceEndpointType`]):
            The type of the Inference Endpoint (public, protected, private).
        raw (`dict`):
            The raw dictionary data returned from the API.
        token (`str` or `bool`, *optional*):
            Authentication token for the Inference Endpoint, if set when requesting the API. Will default to the
            locally saved token if not provided. Pass `token=False` if you don't want to send your token to the server.

    Example:
    ```python
    >>> from huggingface_hub import get_inference_endpoint
    >>> endpoint = get_inference_endpoint("my-text-to-image")
    >>> endpoint
    InferenceEndpoint(name='my-text-to-image', ...)

    # Get status
    >>> endpoint.status
    'running'
    >>> endpoint.url
    'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud'

    # Run inference
    >>> endpoint.client.text_to_image(...)

    # Pause endpoint to save $$$
    >>> endpoint.pause()

    # ...
    # Resume and wait for deployment
    >>> endpoint.resume()
    >>> endpoint.wait()
    >>> endpoint.client.text_to_image(...)
    ```
    """

    # Field in __repr__
    name: str = field(init=False)
    namespace: str
    repository: str = field(init=False)
    status: InferenceEndpointStatus = field(init=False)
    health_route: str = field(init=False)
    url: str | None = field(init=False)

    # Other fields
    framework: str = field(repr=False, init=False)
    revision: str = field(repr=False, init=False)
    task: str = field(repr=False, init=False)
    created_at: datetime = field(repr=False, init=False)
    updated_at: datetime = field(repr=False, init=False)
    type: InferenceEndpointType = field(repr=False, init=False)

    # Raw dict from the API
    raw: dict = field(repr=False)

    # Internal fields
    _token: str | bool | None = field(repr=False, compare=False)
    _api: "HfApi" = field(repr=False, compare=False)

    @classmethod
    def from_raw(
        cls, raw: dict, namespace: str, token: str | bool | None = None, api: Optional["HfApi"] = None
    ) -> "InferenceEndpoint":
        """Initialize object from raw dictionary.

        Args:
            raw (`dict`):
                Raw dictionary returned by the Inference Endpoints API.
            namespace (`str`):
                The namespace where the Inference Endpoint is located.
            token (`str` or `bool`, *optional*):
                Token used for subsequent requests made by this object. Defaults to `api.token`.
            api ([`HfApi`], *optional*):
                Client used for subsequent requests. A default [`HfApi`] is created if not provided.

        Returns:
            [`InferenceEndpoint`]: the new object, with all fields populated from `raw`.
        """
        if api is None:
            # Local import to avoid a circular import with `hf_api`.
            from .hf_api import HfApi

            api = HfApi()
        if token is None:
            token = api.token

        # All other fields are populated in __post_init__
        return cls(raw=raw, namespace=namespace, _token=token, _api=api)

    def __post_init__(self) -> None:
        """Populate fields from raw dictionary."""
        self._populate_from_raw()

    @property
    def client(self) -> "InferenceClient":
        """Returns a client to make predictions on this Inference Endpoint.

        Returns:
            [`InferenceClient`]: an inference client pointing to the deployed endpoint.

        Raises:
            [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
        """
        url = self._get_deployed_url()
        from .inference._client import InferenceClient

        return InferenceClient(
            model=url,
            token=self._token,  # type: ignore # boolean token shouldn't be possible. In practice it's ok.
        )

    @property
    def async_client(self) -> "AsyncInferenceClient":
        """Returns a client to make predictions on this Inference Endpoint.

        Returns:
            [`AsyncInferenceClient`]: an asyncio-compatible inference client pointing to the deployed endpoint.

        Raises:
            [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
        """
        url = self._get_deployed_url()
        from .inference._generated._async_client import AsyncInferenceClient

        return AsyncInferenceClient(
            model=url,
            token=self._token,  # type: ignore # boolean token shouldn't be possible. In practice it's ok.
        )

    def wait(self, timeout: int | None = None, refresh_every: int = 5) -> "InferenceEndpoint":
        """Wait for the Inference Endpoint to be deployed.

        Information from the server will be fetched every `refresh_every` seconds (5s by default). If the Inference
        Endpoint is not deployed after `timeout` seconds, a [`InferenceEndpointTimeoutError`] will be raised. The
        [`InferenceEndpoint`] will be mutated in place with the latest data.

        The Inference Endpoint is considered deployed once its status is `"running"`, it has a URL, and a GET
        request to its health route returns HTTP 200.

        Args:
            timeout (`int`, *optional*):
                The maximum time to wait for the Inference Endpoint to be deployed, in seconds. If `None`, will wait
                indefinitely.
            refresh_every (`int`, *optional*):
                The time to wait between each fetch of the Inference Endpoint status, in seconds. Defaults to 5s.

        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.

        Raises:
            [`InferenceEndpointError`]
                If the Inference Endpoint ended up in a failed state.
            [`InferenceEndpointTimeoutError`]
                If the Inference Endpoint is not deployed after `timeout` seconds.
            `ValueError`
                If `timeout` is negative or `refresh_every` is not positive.
        """
        if timeout is not None and timeout < 0:
            raise ValueError("`timeout` cannot be negative.")
        if refresh_every <= 0:
            raise ValueError("`refresh_every` must be positive.")

        # Monotonic clock: elapsed time must not be skewed by system clock adjustments (NTP, DST, manual changes).
        start = time.monotonic()
        while True:
            if self.status == InferenceEndpointStatus.FAILED:
                raise InferenceEndpointError(
                    f"Inference Endpoint {self.name} failed to deploy. Please check the logs for more information."
                )
            if self.status == InferenceEndpointStatus.UPDATE_FAILED:
                raise InferenceEndpointError(
                    f"Inference Endpoint {self.name} failed to update. Please check the logs for more information."
                )
            if self.status == InferenceEndpointStatus.RUNNING and self.url is not None:
                # A "running" status is not enough: verify the endpoint is actually reachable on its health route.
                health_url = f"{self.url.rstrip('/')}/{self.health_route.lstrip('/')}"
                response = get_session().get(health_url, headers=self._api._build_hf_headers(token=self._token))
                if response.status_code == 200:
                    logger.info("Inference Endpoint is ready to be used.")
                    return self

            if timeout is not None and time.monotonic() - start > timeout:
                raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.")
            logger.info(f"Inference Endpoint is not deployed yet ({self.status}). Waiting {refresh_every}s...")
            time.sleep(refresh_every)
            self.fetch()

    def fetch(self) -> "InferenceEndpoint":
        """Fetch latest information about the Inference Endpoint.

        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
        """
        obj = self._api.get_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)  # type: ignore [arg-type]
        return self._update_from(obj)

    def update(
        self,
        *,
        # Compute update
        accelerator: str | None = None,
        instance_size: str | None = None,
        instance_type: str | None = None,
        min_replica: int | None = None,
        max_replica: int | None = None,
        scale_to_zero_timeout: int | None = None,
        # Model update
        repository: str | None = None,
        framework: str | None = None,
        revision: str | None = None,
        task: str | None = None,
        custom_image: dict | None = None,
        secrets: dict[str, str] | None = None,
    ) -> "InferenceEndpoint":
        """Update the Inference Endpoint.

        This method allows the update of either the compute configuration, the deployed model, or both. All arguments are
        optional but at least one must be provided.

        This is an alias for [`HfApi.update_inference_endpoint`]. The current object is mutated in place with the
        latest data from the server.

        Args:
            accelerator (`str`, *optional*):
                The hardware accelerator to be used for inference (e.g. `"cpu"`).
            instance_size (`str`, *optional*):
                The size or type of the instance to be used for hosting the model (e.g. `"x4"`).
            instance_type (`str`, *optional*):
                The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`).
            min_replica (`int`, *optional*):
                The minimum number of replicas (instances) to keep running for the Inference Endpoint.
            max_replica (`int`, *optional*):
                The maximum number of replicas (instances) to scale to for the Inference Endpoint.
            scale_to_zero_timeout (`int`, *optional*):
                The duration in minutes before an inactive endpoint is scaled to zero.

            repository (`str`, *optional*):
                The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
            framework (`str`, *optional*):
                The machine learning framework used for the model (e.g. `"custom"`).
            revision (`str`, *optional*):
                The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
            task (`str`, *optional*):
                The task on which to deploy the model (e.g. `"text-classification"`).
            custom_image (`dict`, *optional*):
                A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an
                Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples).
            secrets (`dict[str, str]`, *optional*):
                Secret values to inject in the container environment.

        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
        """
        # Make API call
        obj = self._api.update_inference_endpoint(
            name=self.name,
            namespace=self.namespace,
            accelerator=accelerator,
            instance_size=instance_size,
            instance_type=instance_type,
            min_replica=min_replica,
            max_replica=max_replica,
            scale_to_zero_timeout=scale_to_zero_timeout,
            repository=repository,
            framework=framework,
            revision=revision,
            task=task,
            custom_image=custom_image,
            secrets=secrets,
            token=self._token,  # type: ignore [arg-type]
        )

        # Mutate current object
        return self._update_from(obj)

    def pause(self) -> "InferenceEndpoint":
        """Pause the Inference Endpoint.

        A paused Inference Endpoint will not be charged. It can be resumed at any time using [`InferenceEndpoint.resume`].
        This is different from scaling the Inference Endpoint to zero with [`InferenceEndpoint.scale_to_zero`], which
        would be automatically restarted when a request is made to it.

        This is an alias for [`HfApi.pause_inference_endpoint`]. The current object is mutated in place with the
        latest data from the server.

        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
        """
        obj = self._api.pause_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)  # type: ignore [arg-type]
        return self._update_from(obj)

    def resume(self, running_ok: bool = True) -> "InferenceEndpoint":
        """Resume the Inference Endpoint.

        This is an alias for [`HfApi.resume_inference_endpoint`]. The current object is mutated in place with the
        latest data from the server.

        Args:
            running_ok (`bool`, *optional*):
                If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to
                `True`.

        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
        """
        obj = self._api.resume_inference_endpoint(
            name=self.name,
            namespace=self.namespace,
            running_ok=running_ok,
            # The ignore must sit on the offending line for mypy to honor it.
            token=self._token,  # type: ignore [arg-type]
        )
        return self._update_from(obj)

    def scale_to_zero(self) -> "InferenceEndpoint":
        """Scale Inference Endpoint to zero.

        An Inference Endpoint scaled to zero will not be charged. It will be resumed on the next request to it, with a
        cold start delay. This is different from pausing the Inference Endpoint with [`InferenceEndpoint.pause`], which
        would require a manual resume with [`InferenceEndpoint.resume`].

        This is an alias for [`HfApi.scale_to_zero_inference_endpoint`]. The current object is mutated in place with the
        latest data from the server.

        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
        """
        obj = self._api.scale_to_zero_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)  # type: ignore [arg-type]
        return self._update_from(obj)

    def delete(self) -> None:
        """Delete the Inference Endpoint.

        This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable
        to pause it with [`InferenceEndpoint.pause`] or scale it to zero with [`InferenceEndpoint.scale_to_zero`].

        This is an alias for [`HfApi.delete_inference_endpoint`].
        """
        self._api.delete_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)  # type: ignore [arg-type]

    def _get_deployed_url(self) -> str:
        """Return `self.url`, raising [`InferenceEndpointError`] if the Inference Endpoint has no URL yet."""
        if self.url is None:
            raise InferenceEndpointError(
                "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
                "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
            )
        return self.url

    def _update_from(self, obj: "InferenceEndpoint") -> "InferenceEndpoint":
        """Adopt the raw data of `obj` (freshly returned by the API), re-populate all fields and return `self`."""
        self.raw = obj.raw
        self._populate_from_raw()
        return self

    def _populate_from_raw(self) -> None:
        """Populate fields from raw dictionary.

        Called in __post_init__ + each time the Inference Endpoint is updated.
        Raises `KeyError` if a required key is missing from `self.raw`.
        """
        # Repr fields
        self.name = self.raw["name"]
        self.repository = self.raw["model"]["repository"]
        # NOTE: stored as the raw string, not converted to `InferenceEndpointStatus`. Comparisons against enum
        # members (as in `wait`) still work because `InferenceEndpointStatus` subclasses `str`.
        self.status = self.raw["status"]["state"]
        # `url` is optional: only present once the Inference Endpoint is deployed.
        self.url = self.raw["status"].get("url")
        self.health_route = self.raw["healthRoute"]

        # Other fields
        self.framework = self.raw["model"]["framework"]
        self.revision = self.raw["model"]["revision"]
        self.task = self.raw["model"]["task"]
        self.created_at = parse_datetime(self.raw["status"]["createdAt"])
        self.updated_at = parse_datetime(self.raw["status"]["updatedAt"])
        # Same as `status`: raw string, compatible with `InferenceEndpointType` through its `str` base.
        self.type = self.raw["type"]
|