from __future__ import annotations

import contextlib
import logging
import uuid
import weakref
from typing import cast

from wandb.proto import wandb_internal_pb2 as pb
from wandb.proto.wandb_api_pb2 import ApiRequest, ApiResponse, FeaturesRequest
from wandb.sdk import wandb_settings, wandb_setup
from wandb.sdk.lib.service.service_connection import (
    ServiceConnection,
    WandbApiFailedError,
)
from wandb.sdk.mailbox.mailbox_handle import MailboxHandle

_logger = logging.getLogger(__name__)


def _cleanup(connection: ServiceConnection | None, api_id: str) -> None:
    """Clean up the api resources associated with the api id.

    Used as a ``weakref.finalize`` callback by ``ServiceApi``.

    Args:
        connection: The connection the api id was registered on, or None,
            in which case there is nothing to clean up.
        api_id: The api id returned by the service's api init request.
    """
    if connection is None:
        return
    # Best-effort: any failure of the cleanup request is deliberately ignored.
    with contextlib.suppress(Exception):
        connection.api_cleanup_request(api_id)


class ServiceApi:
    """A lazy initialized handle to the wandb-core service for handling API requests.

    No connection to the service is made until the first request is sent.
    """

    def __init__(
        self,
        settings: wandb_settings.Settings,
    ):
        self._settings = settings
        self._service_connection: ServiceConnection | None = None
        # Placeholder; replaced by the id the service assigns when the
        # connection is initialized in _get_service_connection.
        self._api_id = str(uuid.uuid4())

    def _get_service_connection(self) -> ServiceConnection:
        """Connects to the service and initializes resources for handling API requests.

        The connection is cached only after the service has successfully
        initialized API resources. If initialization fails, the exception
        propagates and the next call attempts initialization again. The
        handle is never left holding a connection whose api id was never
        registered.

        Returns:
            The initialized service connection.
        """
        if self._service_connection is not None:
            return self._service_connection

        connection = wandb_setup.singleton().ensure_service()
        response = connection.api_init_request(
            self._settings.to_proto(),
        )

        # Commit state only once initialization has succeeded.
        self._api_id = response.api_id
        self._service_connection = connection

        # Release service-side resources when this handle is garbage
        # collected. The callback and its arguments must not reference
        # `self`, otherwise the handle could never be collected.
        weakref.finalize(
            self,
            _cleanup,
            connection,
            self._api_id,
        )
        return connection

    def send_api_request(
        self,
        request: ApiRequest,
        timeout: float | None = None,
    ) -> ApiResponse:
        """Send an API request to the backend service.

        Creates the backend service connection if it has not been created yet.

        Args:
            request: The API request to send. Its ``api_id`` field is
                overwritten with this handle's api id.
            timeout: The timeout for the request, passed through to the
                connection's ``api_request``.

        Returns:
            The service's response to the request.
        """
        conn = self._get_service_connection()
        request.api_id = self._api_id
        return conn.api_request(request, timeout=timeout)

    async def send_api_request_async(
        self,
        request: ApiRequest,
    ) -> MailboxHandle[ApiResponse]:
        """Send an API request to the backend service asynchronously.

        Creates the backend service connection if it has not been created
        yet. That connection setup runs synchronously.

        Args:
            request: The API request to send. Its ``api_id`` field is
                overwritten with this handle's api id.

        Returns:
            A handle for the pending response.
        """
        conn = self._get_service_connection()
        request.api_id = self._api_id
        return await conn.api_request_async(request)

    def feature_enabled(
        self,
        feature: pb.ServerFeature | str,
        *,
        timeout: float = 10,
    ) -> bool:
        """Returns whether a single server feature is enabled.

        On timeout or normal error, this logs and returns False.

        Args:
            feature: The enum constant or name of the boolean feature to check.
                Prefer to use the enum constants when possible, since they
                have better type-checking. For unknown or incorrect names,
                this returns False.
            timeout: The timeout to use. Defaults to 10 seconds.

        Returns:
            True if the service reports the feature as enabled, else False.
        """
        if isinstance(feature, str):
            try:
                # NOTE: pb.ServerFeature is not an actual runtime type.
                #
                # All protobuf enums are represented as integers.
                # It is guaranteed that the return value of Value
                # is a valid enum (if it exists), hence the cast.
                feature = cast(pb.ServerFeature, pb.ServerFeature.Value(feature))
            except ValueError:
                # Unknown feature names are treated as disabled.
                return False

        req = ApiRequest(features_request=FeaturesRequest(features=[feature]))

        try:
            resp = self.send_api_request(req, timeout=timeout)
        except WandbApiFailedError:
            # NOTE: The feature's integer value is logged here.
            _logger.exception("Failed to load feature %s", feature)
            return False

        return feature in resp.features_response.enabled