interface_sock.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. from __future__ import annotations
  2. import logging
  3. from typing import TYPE_CHECKING
  4. from typing_extensions import override
  5. from wandb.proto import wandb_server_pb2 as spb
  6. from wandb.sdk.lib import asyncio_manager
  7. from .interface_shared import InterfaceShared
  8. if TYPE_CHECKING:
  9. from wandb.proto import wandb_internal_pb2 as pb
  10. from wandb.sdk.lib.service.service_client import ServiceClient
  11. from wandb.sdk.mailbox import MailboxHandle
  12. logger = logging.getLogger("wandb")
  13. class InterfaceSock(InterfaceShared):
  14. def __init__(
  15. self,
  16. asyncer: asyncio_manager.AsyncioManager,
  17. client: ServiceClient,
  18. stream_id: str,
  19. ) -> None:
  20. super().__init__()
  21. self._asyncer = asyncer
  22. self._client = client
  23. self._stream_id = stream_id
  24. def _assign(self, record: pb.Record) -> None:
  25. record._info.stream_id = self._stream_id
  26. @override
  27. def _publish(self, record: pb.Record, *, nowait: bool = False) -> None:
  28. self._assign(record)
  29. request = spb.ServerRequest()
  30. request.record_publish.CopyFrom(record)
  31. if nowait:
  32. self._asyncer.run_soon(lambda: self._client.publish(request))
  33. else:
  34. self._asyncer.run(lambda: self._client.publish(request))
  35. @override
  36. def _deliver(self, record: pb.Record) -> MailboxHandle[pb.Result]:
  37. return self._asyncer.run(lambda: self.deliver_async(record))
  38. @override
  39. async def deliver_async(self, record: pb.Record) -> MailboxHandle[pb.Result]:
  40. self._assign(record)
  41. request = spb.ServerRequest()
  42. request.record_publish.CopyFrom(record)
  43. handle = await self._client.deliver(request)
  44. return handle.map(lambda response: response.result_communicate)