# response_handle.py (3.9 KB)
  1. from __future__ import annotations
  2. import asyncio
  3. import logging
  4. import math
  5. from collections.abc import Awaitable
  6. from typing import Callable
  7. from typing_extensions import override
  8. from wandb.proto import wandb_server_pb2 as spb
  9. from wandb.sdk.lib import asyncio_manager
  10. from .mailbox_handle import HandleAbandonedError, MailboxHandle, ServerResponseError
  11. _logger = logging.getLogger(__name__)
  12. class MailboxResponseHandle(MailboxHandle[spb.ServerResponse]):
  13. """A general handle for any ServerResponse."""
  14. def __init__(
  15. self,
  16. address: str,
  17. *,
  18. asyncer: asyncio_manager.AsyncioManager,
  19. cancel: Callable[[str], Awaitable[None]],
  20. ) -> None:
  21. super().__init__(asyncer)
  22. self._address = address
  23. self._cancel_fn = cancel
  24. self._abandoned = False
  25. self._response: spb.ServerResponse | None = None
  26. # Initialized on first use in the asyncio thread.
  27. self._done_event: asyncio.Event | None = None
  28. async def deliver(self, response: spb.ServerResponse) -> None:
  29. if self._abandoned:
  30. return
  31. if self._response:
  32. raise ValueError(
  33. f"A response has already been delivered to {self._address}."
  34. )
  35. self._response = response
  36. if not self._done_event:
  37. self._done_event = asyncio.Event()
  38. self._done_event.set()
  39. @override
  40. def cancel(self) -> None:
  41. # Cancel on a best-effort basis and ignore exceptions.
  42. async def impl() -> None:
  43. try:
  44. await self._cancel_fn(self._address)
  45. except Exception:
  46. _logger.exception("Failed to cancel request %r", self._address)
  47. try:
  48. self.abandon()
  49. self.asyncer.run_soon(impl)
  50. except Exception:
  51. _logger.exception(
  52. "Failed to abandon and cancel request %r",
  53. self._address,
  54. )
  55. def abandon(self) -> None:
  56. """Indicate the handle will not receive a response.
  57. This causes any code blocked on `wait_or` or `wait_async` to raise
  58. a `HandleAbandonedError` after a short time.
  59. """
  60. async def impl() -> None:
  61. self._abandoned = True
  62. if not self._done_event:
  63. self._done_event = asyncio.Event()
  64. self._done_event.set()
  65. self.asyncer.run_soon(impl)
  66. @override
  67. def wait_or(self, *, timeout: float | None) -> spb.ServerResponse:
  68. return self.asyncer.run(lambda: self.wait_async(timeout=timeout))
  69. @override
  70. async def wait_async(self, *, timeout: float | None) -> spb.ServerResponse:
  71. if timeout is not None and not math.isfinite(timeout):
  72. raise ValueError("Timeout must be finite or None.")
  73. if not self._done_event:
  74. self._done_event = asyncio.Event()
  75. try:
  76. await asyncio.wait_for(self._done_event.wait(), timeout=timeout)
  77. except (asyncio.TimeoutError, TimeoutError) as e:
  78. if response := self._response_or_error():
  79. return response
  80. elif self._abandoned:
  81. raise HandleAbandonedError()
  82. else:
  83. self.cancel()
  84. raise TimeoutError(
  85. f"Timed out waiting for response on {self._address}"
  86. ) from e
  87. except:
  88. self.cancel()
  89. raise
  90. else:
  91. if response := self._response_or_error():
  92. return response
  93. assert self._abandoned
  94. raise HandleAbandonedError()
  95. def _response_or_error(self) -> spb.ServerResponse | None:
  96. """Returns self._response, raising on ServerErrorResponse."""
  97. if not self._response:
  98. return None
  99. if self._response.HasField("error_response"):
  100. raise ServerResponseError(self._response.error_response.message)
  101. return self._response