| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 |
- from __future__ import annotations
- import abc
- import asyncio
- import os
- import re
- from typing_extensions import final, override
- from wandb import env
- from wandb.sdk.lib import asyncio_manager
- from wandb.sdk.lib.service import ipc_support
- from .service_client import ServiceClient
- _CURRENT_VERSION = "3"
- # Token formats:
- _UNIX_TOKEN_RE = re.compile(rf"{_CURRENT_VERSION}-(\d+)-unix-(.+)")
- _TCP_TOKEN_RE = re.compile(rf"{_CURRENT_VERSION}-(\d+)-tcp-localhost-(\d+)")
- class WandbServiceConnectionError(Exception):
- """Failed to connect to the service process."""
- def clear_service_in_env() -> None:
- """Clear the environment variable that stores the service token."""
- os.environ.pop(env.SERVICE, None)
- def from_env() -> ServiceToken | None:
- """Read the token from environment variables.
- Returns:
- The token if the correct environment variable is set, or None.
- Raises:
- ValueError: If the environment variable is set but cannot be
- parsed.
- """
- token = os.environ.get(env.SERVICE)
- if not token:
- return None
- if unix_token := UnixServiceToken.from_env_string(token):
- return unix_token
- if tcp_token := TCPServiceToken.from_env_string(token):
- return tcp_token
- raise ValueError(f"Failed to parse {env.SERVICE}={token!r}")
- class ServiceToken(abc.ABC):
- """A way of connecting to a running service process."""
- @abc.abstractmethod
- def connect(
- self,
- *,
- asyncer: asyncio_manager.AsyncioManager,
- ) -> ServiceClient:
- """Connect to the service process.
- Args:
- asyncer: A started AsyncioManager for asyncio operations.
- Returns:
- A socket object for communicating with the service.
- Raises:
- WandbServiceConnectionError: on failure to connect.
- """
- def save_to_env(self) -> None:
- """Save the token in this process's environment variables."""
- os.environ[env.SERVICE] = self.env_value
- @property
- def env_value(self) -> str:
- """Value to assign to the WANDB_SERVICE environment variable."""
- return self._as_env_string()
- @abc.abstractmethod
- def _as_env_string(self) -> str:
- """Returns a string representation of this token."""
- @final
- class UnixServiceToken(ServiceToken):
- """Connects to the service using a Unix domain socket."""
- def __init__(self, *, parent_pid: int, path: str) -> None:
- self._parent_pid = parent_pid
- self._path = path
- @override
- def connect(
- self,
- *,
- asyncer: asyncio_manager.AsyncioManager,
- ) -> ServiceClient:
- if not ipc_support.SUPPORTS_UNIX:
- raise WandbServiceConnectionError("AF_UNIX socket not supported")
- try:
- # TODO: This may block indefinitely if the service is unhealthy.
- reader, writer = asyncer.run(
- lambda: asyncio.open_unix_connection(self._path),
- )
- except Exception as e:
- raise WandbServiceConnectionError(
- f"Failed to connect to service on socket {self._path}",
- ) from e
- return ServiceClient(asyncer, reader, writer)
- @override
- def _as_env_string(self) -> str:
- return "-".join(
- (
- _CURRENT_VERSION,
- str(self._parent_pid),
- "unix",
- str(self._path),
- )
- )
- @staticmethod
- def from_env_string(token: str) -> UnixServiceToken | None:
- """Returns a Unix service token parsed from the env var."""
- match = _UNIX_TOKEN_RE.fullmatch(token)
- if not match:
- return None
- parent_pid, path = match.groups()
- return UnixServiceToken(parent_pid=int(parent_pid), path=path)
- @final
- class TCPServiceToken(ServiceToken):
- """Connects to the service using TCP over a localhost socket."""
- def __init__(self, *, parent_pid: int, port: int) -> None:
- self._parent_pid = parent_pid
- self._port = port
- @override
- def connect(
- self,
- *,
- asyncer: asyncio_manager.AsyncioManager,
- ) -> ServiceClient:
- try:
- # TODO: This may block indefinitely if the service is unhealthy.
- reader, writer = asyncer.run(
- lambda: asyncio.open_connection("localhost", self._port),
- )
- except Exception as e:
- raise WandbServiceConnectionError(
- f"Failed to connect to service on port {self._port}",
- ) from e
- return ServiceClient(asyncer, reader, writer)
- @override
- def _as_env_string(self) -> str:
- return "-".join(
- (
- _CURRENT_VERSION,
- str(self._parent_pid),
- "tcp",
- "localhost",
- str(self._port),
- )
- )
- @staticmethod
- def from_env_string(token: str) -> TCPServiceToken | None:
- """Returns a TCP service token parsed from the env var."""
- match = _TCP_TOKEN_RE.fullmatch(token)
- if not match:
- return None
- parent_pid, port = match.groups()
- return TCPServiceToken(parent_pid=int(parent_pid), port=int(port))
|