service_connection.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. from __future__ import annotations
  2. import atexit
  3. import pathlib
  4. from typing import Callable
  5. from wandb.proto import wandb_api_pb2, wandb_settings_pb2, wandb_sync_pb2
  6. from wandb.proto import wandb_server_pb2 as spb
  7. from wandb.sdk import wandb_settings
  8. from wandb.sdk.interface.interface import InterfaceBase
  9. from wandb.sdk.interface.interface_sock import InterfaceSock
  10. from wandb.sdk.lib import asyncio_manager
  11. from wandb.sdk.lib.exit_hooks import ExitHooks
  12. from wandb.sdk.lib.service.service_client import ServiceClient
  13. from wandb.sdk.mailbox import HandleAbandonedError, MailboxClosedError
  14. from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
  15. from . import service_process, service_token
  16. class WandbAttachFailedError(Exception):
  17. """Failed to attach to a run."""
  18. class WandbApiFailedError(Exception):
  19. """Failed to execute an API request to wandb-core."""
  20. def __init__(
  21. self,
  22. message: str,
  23. response: wandb_api_pb2.ApiErrorResponse | None = None,
  24. ):
  25. super().__init__(message)
  26. self.response = response
  27. def connect_to_service(
  28. asyncer: asyncio_manager.AsyncioManager,
  29. settings: wandb_settings.Settings,
  30. ) -> ServiceConnection:
  31. """Connect to the service process, starting one up if necessary."""
  32. token = service_token.from_env()
  33. if token:
  34. return ServiceConnection(
  35. asyncer=asyncer,
  36. client=token.connect(asyncer=asyncer),
  37. proc=None,
  38. )
  39. else:
  40. return _start_and_connect_service(asyncer, settings)
  41. def _start_and_connect_service(
  42. asyncer: asyncio_manager.AsyncioManager,
  43. settings: wandb_settings.Settings,
  44. ) -> ServiceConnection:
  45. """Start a service process and returns a connection to it.
  46. An atexit hook is registered to tear down the service process and wait for
  47. it to complete. The hook does not run in processes started using the
  48. multiprocessing module.
  49. """
  50. proc = service_process.start(settings)
  51. client = proc.token.connect(asyncer=asyncer)
  52. proc.token.save_to_env()
  53. hooks = ExitHooks()
  54. hooks.hook()
  55. def teardown_atexit():
  56. conn.teardown(hooks.exit_code)
  57. conn = ServiceConnection(
  58. asyncer=asyncer,
  59. client=client,
  60. proc=proc,
  61. cleanup=lambda: atexit.unregister(teardown_atexit),
  62. )
  63. atexit.register(teardown_atexit)
  64. return conn
  65. class ServiceConnection:
  66. """A connection to the W&B internal service process.
  67. None of the synchronous methods may be called in an asyncio context.
  68. """
  69. def __init__(
  70. self,
  71. asyncer: asyncio_manager.AsyncioManager,
  72. client: ServiceClient,
  73. proc: service_process.ServiceProcess | None,
  74. cleanup: Callable[[], None] | None = None,
  75. ):
  76. """Returns a new ServiceConnection.
  77. Args:
  78. asyncer: An asyncio runner.
  79. client: A client for communicating with the service over a socket.
  80. proc: The service process if we own it, or None otherwise.
  81. cleanup: A callback to run on teardown before doing anything.
  82. """
  83. self._asyncer = asyncer
  84. self._client = client
  85. self._proc = proc
  86. self._torn_down = False
  87. self._cleanup = cleanup
  88. @property
  89. def owns_service(self) -> bool:
  90. """Whether this connection owns the wandb-core service process.
  91. If True, teardown() will shut down the service process. If False, the
  92. service process is externally managed (e.g. via WANDB_SERVICE).
  93. """
  94. return self._proc is not None
  95. def make_interface(self, stream_id: str) -> InterfaceBase:
  96. """Returns an interface for communicating with the service."""
  97. return InterfaceSock(
  98. self._asyncer,
  99. self._client,
  100. stream_id=stream_id,
  101. )
  102. async def init_sync(
  103. self,
  104. paths: set[pathlib.Path],
  105. settings: wandb_settings.Settings,
  106. *,
  107. cwd: pathlib.Path | None,
  108. live: bool,
  109. entity: str,
  110. project: str,
  111. run_id: str,
  112. job_type: str,
  113. tag_replacements: dict[str, str],
  114. ) -> MailboxHandle[wandb_sync_pb2.ServerInitSyncResponse]:
  115. """Send a ServerInitSyncRequest."""
  116. init_sync = wandb_sync_pb2.ServerInitSyncRequest(
  117. path=(str(path) for path in paths),
  118. cwd=str(cwd) if cwd else "",
  119. live=live,
  120. settings=settings.to_proto(),
  121. new_entity=entity,
  122. new_project=project,
  123. new_run_id=run_id,
  124. new_job_type=job_type,
  125. tag_replacements=tag_replacements,
  126. )
  127. request = spb.ServerRequest(init_sync=init_sync)
  128. handle = await self._client.deliver(request)
  129. return handle.map(lambda r: r.init_sync_response)
  130. async def sync(
  131. self,
  132. id: str,
  133. *,
  134. parallelism: int,
  135. ) -> MailboxHandle[wandb_sync_pb2.ServerSyncResponse]:
  136. """Send a ServerSyncRequest."""
  137. sync = wandb_sync_pb2.ServerSyncRequest(id=id, parallelism=parallelism)
  138. request = spb.ServerRequest(sync=sync)
  139. handle = await self._client.deliver(request)
  140. return handle.map(lambda r: r.sync_response)
  141. def api_init_request(
  142. self,
  143. settings: wandb_settings_pb2.Settings,
  144. ) -> wandb_api_pb2.ServerApiInitResponse:
  145. """Tells wandb-core to initialize resources for handling API requests."""
  146. api_init_request = wandb_api_pb2.ServerApiInitRequest(
  147. settings=settings,
  148. )
  149. request = spb.ServerRequest(api_init_request=api_init_request)
  150. handle = self._asyncer.run(lambda: self._client.deliver(request))
  151. try:
  152. response = handle.wait_or(timeout=10)
  153. except (MailboxClosedError, HandleAbandonedError):
  154. raise WandbApiFailedError(
  155. "Failed to initialize API resources:"
  156. + " the service process is not running.",
  157. ) from None
  158. except TimeoutError:
  159. raise WandbApiFailedError(
  160. "Failed to initialize API resources:"
  161. + " the service process is busy and did not respond in time.",
  162. ) from None
  163. if response.api_init_response.error_message:
  164. raise WandbApiFailedError(response.api_init_response.error_message)
  165. return response.api_init_response
  166. def api_cleanup_request(self, api_id: str) -> None:
  167. """Tells wandb-core to cleanup API resources."""
  168. api_cleanup_request = wandb_api_pb2.ServerApiCleanupRequest(
  169. api_id=api_id,
  170. )
  171. request = spb.ServerRequest(api_cleanup_request=api_cleanup_request)
  172. self._asyncer.run(lambda: self._client.publish(request))
  173. async def sync_status(
  174. self,
  175. id: str,
  176. ) -> MailboxHandle[wandb_sync_pb2.ServerSyncStatusResponse]:
  177. """Send a ServerSyncStatusRequest."""
  178. sync_status = wandb_sync_pb2.ServerSyncStatusRequest(id=id)
  179. request = spb.ServerRequest(sync_status=sync_status)
  180. handle = await self._client.deliver(request)
  181. return handle.map(lambda r: r.sync_status_response)
  182. async def api_request_async(
  183. self,
  184. api_request: wandb_api_pb2.ApiRequest,
  185. ) -> MailboxHandle[wandb_api_pb2.ApiResponse]:
  186. """Send an ApiRequest and return a handle to the response."""
  187. request = spb.ServerRequest()
  188. request.api_request.CopyFrom(api_request)
  189. handle = await self._client.deliver(request)
  190. return handle.map(lambda r: r.api_response)
  191. def api_request(
  192. self,
  193. api_request: wandb_api_pb2.ApiRequest,
  194. timeout: float | None = None,
  195. ) -> wandb_api_pb2.ApiResponse:
  196. """Send an ApiRequest and wait for a response."""
  197. handle = self._asyncer.run(lambda: self.api_request_async(api_request))
  198. try:
  199. response = handle.wait_or(timeout=timeout)
  200. except (MailboxClosedError, HandleAbandonedError):
  201. raise WandbApiFailedError(
  202. "Failed to initialize API resources:"
  203. + " the service process is not running.",
  204. ) from None
  205. except TimeoutError:
  206. raise WandbApiFailedError(
  207. "Failed to initialize API resources:"
  208. + " the service process is busy and did not respond in time.",
  209. ) from None
  210. if response.HasField("api_error_response"):
  211. raise WandbApiFailedError(
  212. response.api_error_response.message,
  213. response.api_error_response,
  214. )
  215. return response
  216. def api_publish(self, api_request: wandb_api_pb2.ApiRequest) -> None:
  217. """Publish an ApiRequest without waiting for a response."""
  218. request = spb.ServerRequest()
  219. request.api_request.CopyFrom(api_request)
  220. self._asyncer.run(lambda: self._client.publish(request))
  221. def inform_init(
  222. self,
  223. settings: wandb_settings_pb2.Settings,
  224. run_id: str,
  225. ) -> None:
  226. """Send an init request to the service."""
  227. request = spb.ServerInformInitRequest()
  228. request.settings.CopyFrom(settings)
  229. request._info.stream_id = run_id
  230. self._asyncer.run(
  231. lambda: self._client.publish(spb.ServerRequest(inform_init=request))
  232. )
  233. def inform_finish(self, run_id: str) -> None:
  234. """Send an finish request to the service."""
  235. request = spb.ServerInformFinishRequest()
  236. request._info.stream_id = run_id
  237. self._asyncer.run(
  238. lambda: self._client.publish(spb.ServerRequest(inform_finish=request))
  239. )
  240. def inform_attach(
  241. self,
  242. attach_id: str,
  243. ) -> wandb_settings_pb2.Settings:
  244. """Send an attach request to the service.
  245. Raises a WandbAttachFailedError if attaching is not possible.
  246. """
  247. request = spb.ServerRequest()
  248. request.inform_attach._info.stream_id = attach_id
  249. try:
  250. handle = self._asyncer.run(lambda: self._client.deliver(request))
  251. response = handle.wait_or(timeout=10)
  252. except (MailboxClosedError, HandleAbandonedError):
  253. raise WandbAttachFailedError(
  254. "Failed to attach: the service process is not running.",
  255. ) from None
  256. except TimeoutError:
  257. raise WandbAttachFailedError(
  258. "Failed to attach because the run does not belong to"
  259. + " the current service process, or because the service"
  260. + " process is busy (unlikely)."
  261. ) from None
  262. else:
  263. return response.inform_attach_response.settings
  264. def teardown(self, exit_code: int) -> int | None:
  265. """Close the connection.
  266. Stop reading responses on the connection, and if this connection owns
  267. the service process, send a teardown message and wait for it to shut
  268. down.
  269. This may only be called once.
  270. Returns:
  271. The exit code of the service process, or None if the process was
  272. not owned by this connection.
  273. """
  274. if self._torn_down:
  275. raise AssertionError("Already torn down.")
  276. self._torn_down = True
  277. if self._cleanup:
  278. self._cleanup()
  279. if not self._proc:
  280. return None
  281. # Clear the service token to prevent new connections to the process.
  282. service_token.clear_service_in_env()
  283. async def publish_teardown_and_close() -> None:
  284. await self._client.publish(
  285. spb.ServerRequest(
  286. inform_teardown=spb.ServerInformTeardownRequest(
  287. exit_code=exit_code,
  288. )
  289. ),
  290. )
  291. await self._client.close()
  292. self._asyncer.run(publish_teardown_and_close)
  293. return self._proc.join()