mailbox.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. from __future__ import annotations
  2. import logging
  3. import secrets
  4. import string
  5. import threading
  6. from collections.abc import Awaitable
  7. from typing import Callable
  8. from wandb.proto import wandb_internal_pb2 as pb
  9. from wandb.proto import wandb_server_pb2 as spb
  10. from wandb.sdk.lib import asyncio_manager
  11. from .mailbox_handle import MailboxHandle
  12. from .response_handle import MailboxResponseHandle
  13. _logger = logging.getLogger(__name__)
  14. class MailboxClosedError(Exception):
  15. """The mailbox has been closed and cannot be used."""
  16. class Mailbox:
  17. """Matches service responses to requests.
  18. The mailbox can set an address on a server request and create a handle for
  19. waiting for a response to that record. Responses are delivered by calling
  20. `deliver()`. The `close()` method abandons all handles in case the
  21. service process becomes unreachable.
  22. """
  23. def __init__(
  24. self,
  25. asyncer: asyncio_manager.AsyncioManager,
  26. cancel: Callable[[str], Awaitable[None]],
  27. ) -> None:
  28. """Create a mailbox.
  29. Args:
  30. asyncer: Asyncio runner for scheduling async operations.
  31. cancel: A callback that can be used to cancel a request by ID.
  32. """
  33. self._asyncer = asyncer
  34. self._cancel = cancel
  35. self._handles: dict[str, MailboxResponseHandle] = {}
  36. self._handles_lock = threading.Lock()
  37. self._closed = False
  38. def require_response(
  39. self,
  40. request: spb.ServerRequest | pb.Record,
  41. ) -> MailboxHandle[spb.ServerResponse]:
  42. """Set a response address on a request.
  43. Args:
  44. request: The request on which to set a request ID or mailbox slot.
  45. This is mutated. An address must not already be set.
  46. Returns:
  47. A handle for waiting for the response to the request.
  48. Raises:
  49. MailboxClosedError: If the mailbox has been closed, in which case
  50. no new responses are expected to be delivered and new handles
  51. cannot be created.
  52. """
  53. if isinstance(request, spb.ServerRequest):
  54. if (address := request.request_id) or (
  55. address := request.record_publish.control.mailbox_slot
  56. ):
  57. raise ValueError(f"Request already has an address ({address})")
  58. address = self._new_address()
  59. request.request_id = address
  60. if request.HasField("record_publish"):
  61. request.record_publish.control.mailbox_slot = address
  62. if request.HasField("record_communicate"):
  63. request.record_communicate.control.mailbox_slot = address
  64. else:
  65. if address := request.control.mailbox_slot:
  66. raise ValueError(f"Request already has an address ({address})")
  67. address = self._new_address()
  68. request.control.mailbox_slot = address
  69. with self._handles_lock:
  70. if self._closed:
  71. raise MailboxClosedError()
  72. handle = MailboxResponseHandle(
  73. address,
  74. asyncer=self._asyncer,
  75. cancel=self._cancel,
  76. )
  77. self._handles[address] = handle
  78. return handle
  79. def _new_address(self) -> str:
  80. """Returns an unused address for a request.
  81. Assumes `_handles_lock` is held.
  82. """
  83. def generate():
  84. return "".join(
  85. secrets.choice(string.ascii_lowercase + string.digits)
  86. for _ in range(12)
  87. )
  88. address = generate()
  89. # Being extra cautious. This loop will almost never be entered.
  90. while address in self._handles:
  91. address = generate()
  92. return address
  93. async def deliver(self, response: spb.ServerResponse) -> None:
  94. """Deliver a response from the service.
  95. If the response address is invalid, this does nothing.
  96. It is a no-op if the mailbox has been closed.
  97. """
  98. address = response.request_id
  99. if not address:
  100. kind: str | None = response.WhichOneof("server_response_type")
  101. if kind == "result_communicate":
  102. result_type = response.result_communicate.WhichOneof("result_type")
  103. kind = f"result_communicate.{result_type}"
  104. _logger.error(f"Received response with no mailbox slot: {kind}")
  105. return
  106. with self._handles_lock:
  107. # NOTE: If the mailbox is closed, this returns None because
  108. # we clear the dict.
  109. handle = self._handles.pop(address, None)
  110. # It is not an error if there is no handle for the address:
  111. # handles can be abandoned if the result is no longer needed.
  112. if handle:
  113. await handle.deliver(response)
  114. def close(self) -> None:
  115. """Indicate no further responses will be delivered.
  116. Abandons all handles.
  117. """
  118. with self._handles_lock:
  119. self._closed = True
  120. _logger.info(
  121. f"Closing mailbox, abandoning {len(self._handles)} handles.",
  122. )
  123. for handle in self._handles.values():
  124. handle.abandon()
  125. self._handles.clear()