| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- from __future__ import annotations
- import asyncio
- import logging
- import math
- from collections.abc import Awaitable
- from typing import Callable
- from typing_extensions import override
- from wandb.proto import wandb_server_pb2 as spb
- from wandb.sdk.lib import asyncio_manager
- from .mailbox_handle import HandleAbandonedError, MailboxHandle, ServerResponseError
# Module logger; used below to report best-effort cancellation failures.
_logger = logging.getLogger(name=__name__)
class MailboxResponseHandle(MailboxHandle[spb.ServerResponse]):
    """A general handle for any ServerResponse.

    The handle is resolved either by `deliver` (a response arrived) or by
    `abandon`/`cancel` (no response will be accepted). Waiters block on an
    `asyncio.Event` that is created lazily, on first use, from coroutines
    run on the asyncio thread.
    """

    def __init__(
        self,
        address: str,
        *,
        asyncer: asyncio_manager.AsyncioManager,
        cancel: Callable[[str], Awaitable[None]],
    ) -> None:
        """Create a handle for the request identified by `address`.

        Args:
            address: Identifier of the request. It is used in error and log
                messages and is passed to `cancel`.
            asyncer: Manager used to schedule and run coroutines.
            cancel: Async callback invoked with `address` when the request
                is cancelled.
        """
        super().__init__(asyncer)
        self._address = address
        self._cancel_fn = cancel
        self._abandoned = False
        self._response: spb.ServerResponse | None = None
        # Initialized on first use in the asyncio thread.
        self._done_event: asyncio.Event | None = None

    def _ensure_done_event(self) -> asyncio.Event:
        """Return the completion event, creating it on first use.

        Only called from coroutines, so creation happens in the asyncio thread.
        """
        if self._done_event is None:
            self._done_event = asyncio.Event()
        return self._done_event

    async def deliver(self, response: spb.ServerResponse) -> None:
        """Store `response` and wake any waiters.

        Deliveries to an abandoned handle are silently dropped.

        Raises:
            ValueError: If a response was already delivered.
        """
        if self._abandoned:
            return
        # Explicit None check: message truthiness is not a documented contract.
        if self._response is not None:
            raise ValueError(
                f"A response has already been delivered to {self._address}."
            )
        self._response = response
        self._ensure_done_event().set()

    @override
    def cancel(self) -> None:
        """Abandon the handle and ask for the request to be cancelled.

        Best-effort: any exception is logged rather than raised.
        """

        async def impl() -> None:
            try:
                await self._cancel_fn(self._address)
            except Exception:
                _logger.exception("Failed to cancel request %r", self._address)

        try:
            self.abandon()
            self.asyncer.run_soon(impl)
        except Exception:
            _logger.exception(
                "Failed to abandon and cancel request %r",
                self._address,
            )

    def abandon(self) -> None:
        """Indicate the handle will not receive a response.

        This causes any code blocked on `wait_or` or `wait_async` to raise
        a `HandleAbandonedError` after a short time.
        """

        # Scheduled rather than run inline, so the state change and the
        # event creation both happen on the asyncio side.
        async def impl() -> None:
            self._abandoned = True
            self._ensure_done_event().set()

        self.asyncer.run_soon(impl)

    @override
    def wait_or(self, *, timeout: float | None) -> spb.ServerResponse:
        """Run `wait_async` through the asyncio manager and return its result."""
        return self.asyncer.run(lambda: self.wait_async(timeout=timeout))

    @override
    async def wait_async(self, *, timeout: float | None) -> spb.ServerResponse:
        """Wait for the response to be delivered.

        Args:
            timeout: Seconds to wait, or None to wait indefinitely.

        Returns:
            The delivered response.

        Raises:
            ValueError: If `timeout` is infinite or NaN.
            TimeoutError: If nothing arrives in time. The request is also
                cancelled.
            HandleAbandonedError: If the handle was abandoned.
            ServerResponseError: If the delivered response is an error.
        """
        if timeout is not None and not math.isfinite(timeout):
            raise ValueError("Timeout must be finite or None.")

        done_event = self._ensure_done_event()
        try:
            await asyncio.wait_for(done_event.wait(), timeout=timeout)
        except (asyncio.TimeoutError, TimeoutError) as e:
            # Delivery or abandonment may race with the timeout, so inspect
            # the state before reporting a timeout.
            if (response := self._response_or_error()) is not None:
                return response
            if self._abandoned:
                raise HandleAbandonedError()
            self.cancel()
            raise TimeoutError(
                f"Timed out waiting for response on {self._address}"
            ) from e
        except BaseException:
            # Anything else (including asyncio.CancelledError) means the
            # caller stopped waiting: cancel the request, then propagate.
            self.cancel()
            raise

        if (response := self._response_or_error()) is not None:
            return response
        # The event is only set by `deliver` or `abandon`.
        assert self._abandoned
        raise HandleAbandonedError()

    def _response_or_error(self) -> spb.ServerResponse | None:
        """Return the delivered response, or None if none was delivered.

        Raises:
            ServerResponseError: If the response has an `error_response`.
        """
        response = self._response
        if response is None:
            return None
        if response.HasField("error_response"):
            raise ServerResponseError(response.error_response.message)
        return response
|