| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- from __future__ import annotations
- import contextlib
- import logging
- import uuid
- import weakref
- from typing import cast
- from wandb.proto import wandb_internal_pb2 as pb
- from wandb.proto.wandb_api_pb2 import ApiRequest, ApiResponse, FeaturesRequest
- from wandb.sdk import wandb_settings, wandb_setup
- from wandb.sdk.lib.service.service_connection import (
- ServiceConnection,
- WandbApiFailedError,
- )
- from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
- _logger = logging.getLogger(__name__)
def _cleanup(connection: ServiceConnection | None, api_id: str) -> None:
    """Release the service-side API resources registered under ``api_id``.

    Used as a finalizer callback, so it is strictly best-effort: any
    exception raised by the cleanup request is swallowed.

    Args:
        connection: The service connection to send the cleanup request on,
            or None, in which case nothing is done.
        api_id: The API id whose resources should be released.
    """
    if connection is None:
        return

    # A failed cleanup must never propagate out of a finalizer.
    with contextlib.suppress(Exception):
        connection.api_cleanup_request(api_id)
class ServiceApi:
    """A lazy initialized handle to the wandb-core service for handling API requests."""

    def __init__(
        self,
        settings: wandb_settings.Settings,
    ):
        """Store the settings; no connection is made until the first request.

        Args:
            settings: Settings whose proto form is sent to the service when
                the API resources are initialized.
        """
        self._settings = settings
        self._service_connection: ServiceConnection | None = None
        # Placeholder id; replaced by the id returned from the service's
        # API init response once a connection is established.
        self._api_id = str(uuid.uuid4())

    def _get_service_connection(self) -> ServiceConnection:
        """Connects to the service and initializes resources for handling API requests.

        The connection is created and initialized on the first call and
        reused afterwards. If initialization raises, no state is recorded on
        this object, so the next call attempts initialization again.

        Returns:
            The initialized service connection.
        """
        if self._service_connection is None:
            connection = wandb_setup.singleton().ensure_service()
            response = connection.api_init_request(
                self._settings.to_proto(),
            )

            # Publish the connection only after the init request succeeded.
            # Otherwise a failed init would leave a connection cached with
            # an API id the service never issued and no cleanup registered.
            self._api_id = response.api_id
            self._service_connection = connection

            # Ask the service to release the API resources once this object
            # is garbage collected.
            weakref.finalize(
                self,
                _cleanup,
                connection,
                self._api_id,
            )

        return self._service_connection

    def send_api_request(
        self,
        request: ApiRequest,
        timeout: float | None = None,
    ) -> ApiResponse:
        """Send an API request to the backend service.

        Creates the backend service connection if it has not been created yet.

        Args:
            request: The Api request to send. Its ``api_id`` field is
                overwritten with this handle's API id.
            timeout: The timeout forwarded to the service connection.

        Returns:
            The response from the service connection.
        """
        conn = self._get_service_connection()
        request.api_id = self._api_id
        return conn.api_request(request, timeout=timeout)

    async def send_api_request_async(
        self,
        request: ApiRequest,
    ) -> MailboxHandle[ApiResponse]:
        """Send an API request to the backend service asynchronously.

        Creates the backend service connection if it has not been created yet.

        Args:
            request: The Api request to send. Its ``api_id`` field is
                overwritten with this handle's API id.

        Returns:
            The handle produced by the service connection for the request.
        """
        conn = self._get_service_connection()
        request.api_id = self._api_id
        return await conn.api_request_async(request)

    def feature_enabled(
        self,
        feature: pb.ServerFeature | str,
        *,
        timeout: float = 10,
    ) -> bool:
        """Returns whether a single server feature is enabled.

        On timeout or normal error, this logs and returns False.

        Args:
            feature: The enum constant or name of the boolean feature to
                check. Prefer to use the enum constants when possible, since
                they have better type-checking. For unknown or incorrect names,
                this returns False.
            timeout: The timeout to use. Defaults to 10 seconds.
        """
        if isinstance(feature, str):
            try:
                # NOTE: pb.ServerFeature is not an actual runtime type.
                #
                # All protobuf enums are represented as integers.
                # It is guaranteed that the return value of Value
                # is a valid enum (if it exists), hence the cast.
                feature = cast(pb.ServerFeature, pb.ServerFeature.Value(feature))
            except ValueError:
                # The name is not a known feature; treat it as disabled.
                return False

        req = ApiRequest(features_request=FeaturesRequest(features=[feature]))

        try:
            resp = self.send_api_request(req, timeout=timeout)
        except WandbApiFailedError:
            # NOTE: The feature's integer value is logged here.
            _logger.exception("Failed to load feature %s", feature)
            return False

        return feature in resp.features_response.enabled
|