from __future__ import annotations import asyncio import logging import math from collections.abc import Awaitable from typing import Callable from typing_extensions import override from wandb.proto import wandb_server_pb2 as spb from wandb.sdk.lib import asyncio_manager from .mailbox_handle import HandleAbandonedError, MailboxHandle, ServerResponseError _logger = logging.getLogger(__name__) class MailboxResponseHandle(MailboxHandle[spb.ServerResponse]): """A general handle for any ServerResponse.""" def __init__( self, address: str, *, asyncer: asyncio_manager.AsyncioManager, cancel: Callable[[str], Awaitable[None]], ) -> None: super().__init__(asyncer) self._address = address self._cancel_fn = cancel self._abandoned = False self._response: spb.ServerResponse | None = None # Initialized on first use in the asyncio thread. self._done_event: asyncio.Event | None = None async def deliver(self, response: spb.ServerResponse) -> None: if self._abandoned: return if self._response: raise ValueError( f"A response has already been delivered to {self._address}." ) self._response = response if not self._done_event: self._done_event = asyncio.Event() self._done_event.set() @override def cancel(self) -> None: # Cancel on a best-effort basis and ignore exceptions. async def impl() -> None: try: await self._cancel_fn(self._address) except Exception: _logger.exception("Failed to cancel request %r", self._address) try: self.abandon() self.asyncer.run_soon(impl) except Exception: _logger.exception( "Failed to abandon and cancel request %r", self._address, ) def abandon(self) -> None: """Indicate the handle will not receive a response. This causes any code blocked on `wait_or` or `wait_async` to raise a `HandleAbandonedError` after a short time. """ async def impl() -> None: self._abandoned = True if not self._done_event: self._done_event = asyncio.Event() self._done_event.set() self.asyncer.run_soon(impl) @override def wait_or(self, *, timeout: float | None) -> spb.ServerResponse: return self.asyncer.run(lambda: self.wait_async(timeout=timeout)) @override async def wait_async(self, *, timeout: float | None) -> spb.ServerResponse: if timeout is not None and not math.isfinite(timeout): raise ValueError("Timeout must be finite or None.") if not self._done_event: self._done_event = asyncio.Event() try: await asyncio.wait_for(self._done_event.wait(), timeout=timeout) except (asyncio.TimeoutError, TimeoutError) as e: if response := self._response_or_error(): return response elif self._abandoned: raise HandleAbandonedError() else: self.cancel() raise TimeoutError( f"Timed out waiting for response on {self._address}" ) from e except: self.cancel() raise else: if response := self._response_or_error(): return response assert self._abandoned raise HandleAbandonedError() def _response_or_error(self) -> spb.ServerResponse | None: """Returns self._response, raising on ServerErrorResponse.""" if not self._response: return None if self._response.HasField("error_response"): raise ServerResponseError(self._response.error_response.message) return self._response