# service_client.py (5.0 KB)
  1. from __future__ import annotations
  2. import asyncio
  3. import logging
  4. import struct
  5. import sys
  6. from types import TracebackType
  7. from wandb.proto import wandb_server_pb2 as spb
  8. from wandb.sdk.lib import asyncio_manager
  9. from wandb.sdk.mailbox.mailbox import Mailbox
  10. from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
  11. _logger = logging.getLogger(__name__)
  12. _HEADER_BYTE_INT_LEN = 5
  13. _HEADER_BYTE_INT_FMT = "<BI"
  14. class ServiceClient:
  15. """Implements socket communication with the internal service."""
  16. def __init__(
  17. self,
  18. asyncer: asyncio_manager.AsyncioManager,
  19. reader: asyncio.StreamReader,
  20. writer: asyncio.StreamWriter,
  21. ) -> None:
  22. self._broken_exc: Exception | None = None
  23. self._broken_tb: TracebackType | None = None
  24. self._drain_lock: asyncio.Lock | None = None
  25. self._reader = reader
  26. self._writer = writer
  27. self._mailbox = Mailbox(asyncer, self._cancel_request)
  28. asyncer.run_soon(
  29. self._forward_responses,
  30. daemon=True,
  31. name="ServiceClient._forward_responses",
  32. )
  33. async def publish(self, request: spb.ServerRequest) -> None:
  34. """Send a request without waiting for a response."""
  35. await self._send_server_request(request)
  36. async def deliver(
  37. self,
  38. request: spb.ServerRequest,
  39. ) -> MailboxHandle[spb.ServerResponse]:
  40. """Send a request and return a handle to wait for a response.
  41. NOTE: This may mutate the request. The request should not be used
  42. after.
  43. Raises:
  44. MailboxClosedError: If used after the client is closed or has
  45. stopped due to an error.
  46. """
  47. handle = self._mailbox.require_response(request)
  48. await self._send_server_request(request)
  49. return handle
  50. async def _send_server_request(self, request: spb.ServerRequest) -> None:
  51. if self._broken_exc:
  52. # Use with_traceback() to reuse the original traceback.
  53. # The exception's __traceback__ is modified by every `raise`
  54. # statement, so we must reset it to the original value.
  55. # The caller will receive an exception whose traceback has this
  56. # `raise` statement, followed by the `await self._writer.drain()`
  57. # statement, followed by the traceback there.
  58. #
  59. # We do this because `StreamWriter` stores an exception and doesn't
  60. # correctly reset the traceback when reraising (at least in older
  61. # Python versions).
  62. #
  63. # See https://bugs.python.org/issue45924.
  64. raise self._broken_exc.with_traceback(self._broken_tb)
  65. header = struct.pack(_HEADER_BYTE_INT_FMT, ord("W"), request.ByteSize())
  66. self._writer.write(header)
  67. data = request.SerializeToString()
  68. self._writer.write(data)
  69. try:
  70. await self._drain_writer()
  71. except Exception as e:
  72. self._broken_exc = e
  73. self._broken_tb = e.__traceback__
  74. raise
  75. async def _drain_writer(self) -> None:
  76. """Wait for the socket's flow control."""
  77. if sys.version_info >= (3, 10):
  78. await self._writer.drain()
  79. return
  80. # Prior to 3.10, drain() incorrectly raised an AssertionError when the
  81. # write buffer was maxed out if called from more than one async task.
  82. self._drain_lock = self._drain_lock or asyncio.Lock()
  83. async with self._drain_lock:
  84. await self._writer.drain()
  85. async def _cancel_request(self, id: str, /) -> None:
  86. """Cancel a request by ID.
  87. Args:
  88. id: The request_id of a previously-sent ServerRequest.
  89. """
  90. await self.publish(
  91. spb.ServerRequest(
  92. cancel=spb.ServerCancelRequest(
  93. request_id=id,
  94. )
  95. )
  96. )
  97. async def close(self) -> None:
  98. """Flush and close the socket."""
  99. self._writer.close()
  100. await self._writer.wait_closed()
  101. async def _forward_responses(self) -> None:
  102. try:
  103. while response := await self._read_server_response():
  104. await self._mailbox.deliver(response)
  105. except Exception:
  106. _logger.exception("Error reading server response.")
  107. else:
  108. _logger.info("Reached EOF.")
  109. finally:
  110. self._mailbox.close()
  111. async def _read_server_response(self) -> spb.ServerResponse | None:
  112. try:
  113. header = await self._reader.readexactly(_HEADER_BYTE_INT_LEN)
  114. except asyncio.IncompleteReadError as e:
  115. if e.partial:
  116. raise
  117. else:
  118. return None
  119. magic, length = struct.unpack(_HEADER_BYTE_INT_FMT, header)
  120. if magic != ord("W"):
  121. raise ValueError(f"Bad header: {header.hex()}")
  122. data = await self._reader.readexactly(length)
  123. response = spb.ServerResponse()
  124. response.ParseFromString(data)
  125. return response