service_api.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. from __future__ import annotations
  2. import contextlib
  3. import logging
  4. import uuid
  5. import weakref
  6. from typing import cast
  7. from wandb.proto import wandb_internal_pb2 as pb
  8. from wandb.proto.wandb_api_pb2 import ApiRequest, ApiResponse, FeaturesRequest
  9. from wandb.sdk import wandb_settings, wandb_setup
  10. from wandb.sdk.lib.service.service_connection import (
  11. ServiceConnection,
  12. WandbApiFailedError,
  13. )
  14. from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
  15. _logger = logging.getLogger(__name__)
  16. def _cleanup(connection: ServiceConnection | None, api_id: str) -> None:
  17. """Clean up the api resources associated with the api id."""
  18. if connection is not None:
  19. with contextlib.suppress(Exception):
  20. connection.api_cleanup_request(api_id)
  21. class ServiceApi:
  22. """A lazy initialized handle to the wandb-core service for handling API requests."""
  23. def __init__(
  24. self,
  25. settings: wandb_settings.Settings,
  26. ):
  27. self._settings = settings
  28. self._service_connection: ServiceConnection | None = None
  29. self._api_id = str(uuid.uuid4())
  30. def _get_service_connection(self) -> ServiceConnection:
  31. """Connects to the service and initializes resources for handling API requests."""
  32. if self._service_connection is None:
  33. self._service_connection = wandb_setup.singleton().ensure_service()
  34. response = self._service_connection.api_init_request(
  35. self._settings.to_proto(),
  36. )
  37. self._api_id = response.api_id
  38. weakref.finalize(
  39. self,
  40. _cleanup,
  41. self._service_connection,
  42. self._api_id,
  43. )
  44. return self._service_connection
  45. def send_api_request(
  46. self,
  47. request: ApiRequest,
  48. timeout: float | None = None,
  49. ) -> ApiResponse:
  50. """Send an API request to the backend service.
  51. Creates the backend service connection if it has not been created yet.
  52. """
  53. conn = self._get_service_connection()
  54. request.api_id = self._api_id
  55. return conn.api_request(request, timeout=timeout)
  56. async def send_api_request_async(
  57. self,
  58. request: ApiRequest,
  59. ) -> MailboxHandle[ApiResponse]:
  60. """Send an API request to the backend service asynchronously.
  61. Args:
  62. request: The Api request to send.
  63. timeout: The timeout for the request.
  64. """
  65. conn = self._get_service_connection()
  66. request.api_id = self._api_id
  67. return await conn.api_request_async(request)
  68. def feature_enabled(
  69. self,
  70. feature: pb.ServerFeature | str,
  71. *,
  72. timeout: float = 10,
  73. ) -> bool:
  74. """Returns whether a single server feature is enabled.
  75. On timeout or normal error, this logs and returns False.
  76. Args:
  77. feature: The enum constant or name of the boolean feature to
  78. check. Prefer to use the enum constants when possible, since
  79. they have better type-checking. For unknown or incorrect names,
  80. this returns False.
  81. timeout: The timeout to use. Defaults to 10 seconds.
  82. """
  83. if isinstance(feature, str):
  84. try:
  85. # NOTE: pb.ServerFeature is not an actual runtime type.
  86. #
  87. # All protobuf enums are represented as integers.
  88. # It is guaranteed that the return value of Value
  89. # is a valid enum (if it exists), hence the cast.
  90. feature = cast(pb.ServerFeature, pb.ServerFeature.Value(feature))
  91. except ValueError:
  92. # SERVER_FEATURE_UNSPECIFIED is always disabled.
  93. return False
  94. req = ApiRequest(features_request=FeaturesRequest(features=[feature]))
  95. try:
  96. resp = self.send_api_request(req, timeout=timeout)
  97. except WandbApiFailedError:
  98. # NOTE: The feature's integer value is logged here.
  99. _logger.exception("Failed to load feature %s", feature)
  100. return False
  101. return feature in resp.features_response.enabled