service_token.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. from __future__ import annotations
  2. import abc
  3. import asyncio
  4. import os
  5. import re
  6. from typing_extensions import final, override
  7. from wandb import env
  8. from wandb.sdk.lib import asyncio_manager
  9. from wandb.sdk.lib.service import ipc_support
  10. from .service_client import ServiceClient
  _CURRENT_VERSION = "3"
  # Token formats (the value stored in the service environment variable),
  # where <pid> is the parent process ID:
  #   Unix domain socket: "<version>-<pid>-unix-<socket path>"
  #   localhost TCP:      "<version>-<pid>-tcp-localhost-<port>"
  _UNIX_TOKEN_RE = re.compile(_CURRENT_VERSION + r"-(\d+)-unix-(.+)")
  _TCP_TOKEN_RE = re.compile(_CURRENT_VERSION + r"-(\d+)-tcp-localhost-(\d+)")
  class WandbServiceConnectionError(Exception):
      """Raised when a connection to the service process cannot be made."""
  def clear_service_in_env() -> None:
      """Remove the service token variable from this process's environment.

      Does nothing if the variable is not currently set.
      """
      if env.SERVICE in os.environ:
          del os.environ[env.SERVICE]
  def from_env() -> ServiceToken | None:
      """Parse the service token stored in the environment, if any.

      Each known token kind is tried in turn (Unix socket first, then TCP),
      and the first one that recognizes the value wins.

      Returns:
          The parsed token, or None if the environment variable is unset
          or set to an empty string.

      Raises:
          ValueError: If the environment variable has a value that no
              token kind can parse.
      """
      raw_value = os.environ.get(env.SERVICE)
      if not raw_value:
          return None

      parsers = (
          UnixServiceToken.from_env_string,
          TCPServiceToken.from_env_string,
      )
      for parse in parsers:
          parsed = parse(raw_value)
          if parsed:
              return parsed

      raise ValueError(f"Failed to parse {env.SERVICE}={raw_value!r}")
  class ServiceToken(abc.ABC):
      """A way of connecting to a running service process.

      Each concrete subclass describes one transport. A token can be turned
      into a string (``env_value``) and stored in this process's environment
      with ``save_to_env``.
      """

      @abc.abstractmethod
      def connect(
          self,
          *,
          asyncer: asyncio_manager.AsyncioManager,
      ) -> ServiceClient:
          """Connect to the service process.

          Args:
              asyncer: A started AsyncioManager for asyncio operations.

          Returns:
              A ServiceClient for communicating with the service.

          Raises:
              WandbServiceConnectionError: on failure to connect.
          """

      def save_to_env(self) -> None:
          """Save the token in this process's environment variables.

          Stores ``env_value`` in ``os.environ`` under the ``env.SERVICE`` key.
          """
          os.environ[env.SERVICE] = self.env_value

      @property
      def env_value(self) -> str:
          """Value to assign to the WANDB_SERVICE environment variable."""
          return self._as_env_string()

      @abc.abstractmethod
      def _as_env_string(self) -> str:
          """Return the string form of this token, as stored in the environment."""

      def __repr__(self) -> str:
          # Show the serialized token instead of an opaque object address.
          # That makes tokens readable in logs and error messages.
          return f"{type(self).__name__}({self.env_value!r})"
  @final
  class UnixServiceToken(ServiceToken):
      """Token for a service listening on a Unix domain socket.

      Environment string form: ``<version>-<parent_pid>-unix-<socket path>``.
      """

      def __init__(self, *, parent_pid: int, path: str) -> None:
          self._parent_pid = parent_pid
          self._path = path

      @override
      def connect(
          self,
          *,
          asyncer: asyncio_manager.AsyncioManager,
      ) -> ServiceClient:
          """Open a stream connection to the service's Unix socket."""
          if not ipc_support.SUPPORTS_UNIX:
              raise WandbServiceConnectionError("AF_UNIX socket not supported")

          def _open_socket():
              return asyncio.open_unix_connection(self._path)

          try:
              # TODO: This may block indefinitely if the service is unhealthy.
              stream_reader, stream_writer = asyncer.run(_open_socket)
          except Exception as exc:
              raise WandbServiceConnectionError(
                  f"Failed to connect to service on socket {self._path}",
              ) from exc

          return ServiceClient(asyncer, stream_reader, stream_writer)

      @override
      def _as_env_string(self) -> str:
          return f"{_CURRENT_VERSION}-{self._parent_pid!s}-unix-{self._path!s}"

      @staticmethod
      def from_env_string(token: str) -> UnixServiceToken | None:
          """Parse a Unix-socket token from its env-var form, or return None."""
          parsed = _UNIX_TOKEN_RE.fullmatch(token)
          if parsed is None:
              return None
          return UnixServiceToken(
              parent_pid=int(parsed.group(1)),
              path=parsed.group(2),
          )
  @final
  class TCPServiceToken(ServiceToken):
      """Token for a service listening on a localhost TCP port.

      Environment string form: ``<version>-<parent_pid>-tcp-localhost-<port>``.
      """

      def __init__(self, *, parent_pid: int, port: int) -> None:
          self._parent_pid = parent_pid
          self._port = port

      @override
      def connect(
          self,
          *,
          asyncer: asyncio_manager.AsyncioManager,
      ) -> ServiceClient:
          """Open a stream connection to the service on localhost."""

          def _open_socket():
              return asyncio.open_connection("localhost", self._port)

          try:
              # TODO: This may block indefinitely if the service is unhealthy.
              stream_reader, stream_writer = asyncer.run(_open_socket)
          except Exception as exc:
              raise WandbServiceConnectionError(
                  f"Failed to connect to service on port {self._port}",
              ) from exc

          return ServiceClient(asyncer, stream_reader, stream_writer)

      @override
      def _as_env_string(self) -> str:
          return (
              f"{_CURRENT_VERSION}-{self._parent_pid!s}"
              f"-tcp-localhost-{self._port!s}"
          )

      @staticmethod
      def from_env_string(token: str) -> TCPServiceToken | None:
          """Parse a TCP token from its env-var form, or return None."""
          parsed = _TCP_TOKEN_RE.fullmatch(token)
          if parsed is None:
              return None
          return TCPServiceToken(
              parent_pid=int(parsed.group(1)),
              port=int(parsed.group(2)),
          )