| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- import os
- import shutil
- import tempfile
- from contextlib import contextmanager
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Dict, Optional
- from ray._raylet import AuthenticationTokenLoader, Config
# Names of the environment variables that carry authentication settings.
_AUTH_ENV_VARS = (
    "RAY_AUTH_MODE",
    "RAY_AUTH_TOKEN",
    "RAY_AUTH_TOKEN_PATH",
)
# Location of the default token file, relative to the user's home directory.
_DEFAULT_AUTH_TOKEN_RELATIVE_PATH = Path(".ray", "auth_token")
def reset_auth_token_state() -> None:
    """Clear the cached authentication token and re-initialize the Ray config.

    Calls ``reset_cache()`` on the ``AuthenticationTokenLoader`` singleton and
    then ``Config.initialize("")``.
    """
    token_loader = AuthenticationTokenLoader.instance()
    token_loader.reset_cache()
    # NOTE(review): presumably initializing with an empty string makes the
    # config pick up the current RAY_AUTH_MODE again -- confirm against Config.
    Config.initialize("")
def set_auth_mode(mode: str) -> None:
    """Export ``mode`` as the ``RAY_AUTH_MODE`` environment variable.

    Args:
        mode: Value to store; it is written verbatim, without validation.
    """
    os.environ.update(RAY_AUTH_MODE=mode)
def set_env_auth_token(token: str) -> None:
    """Supply the authentication token through ``RAY_AUTH_TOKEN``.

    Any ``RAY_AUTH_TOKEN_PATH`` setting is removed so that the inline token is
    the only token source left in the environment.

    Args:
        token: Token text to place in ``RAY_AUTH_TOKEN``.
    """
    env = os.environ
    env["RAY_AUTH_TOKEN"] = token
    if "RAY_AUTH_TOKEN_PATH" in env:
        del env["RAY_AUTH_TOKEN_PATH"]
def set_auth_token_path(token: Optional[str], path: Path) -> None:
    """Point ``RAY_AUTH_TOKEN_PATH`` at ``path``, optionally writing ``token`` there.

    Args:
        token: Token text to write to ``path``. When ``None``, nothing is
            written and no directories are created; only the environment is
            updated, so ``path`` may refer to a file that does not exist.
        path: Location of the token file. Anything accepted by
            ``pathlib.Path`` works.
    """
    token_path = Path(path)
    if token is not None:
        token_path.parent.mkdir(parents=True, exist_ok=True)
        token_path.write_text(token)
    os.environ["RAY_AUTH_TOKEN_PATH"] = str(token_path)
    # Drop the inline token so the path is the only token source left in the
    # environment.
    os.environ.pop("RAY_AUTH_TOKEN", None)
def set_default_auth_token(token: str) -> Path:
    """Store ``token`` in the default token file under the home directory.

    The file lives at ``Path.home()`` joined with
    ``_DEFAULT_AUTH_TOKEN_RELATIVE_PATH``; missing parent directories are
    created first.

    Args:
        token: Token text to write.

    Returns:
        The path of the token file that was written.
    """
    token_file = Path.home().joinpath(_DEFAULT_AUTH_TOKEN_RELATIVE_PATH)
    ray_dir = token_file.parent
    ray_dir.mkdir(parents=True, exist_ok=True)
    token_file.write_text(token)
    return token_file
def clear_auth_token_sources(remove_default: bool = False) -> None:
    """Unset the token environment variables and optionally delete the default file.

    ``RAY_AUTH_TOKEN`` and ``RAY_AUTH_TOKEN_PATH`` are always removed from the
    environment; ``RAY_AUTH_MODE`` is not touched.

    Args:
        remove_default: When true, also delete the default token file under
            the home directory (a missing file is not an error).
    """
    os.environ.pop("RAY_AUTH_TOKEN", None)
    os.environ.pop("RAY_AUTH_TOKEN_PATH", None)
    if not remove_default:
        return
    (Path.home() / _DEFAULT_AUTH_TOKEN_RELATIVE_PATH).unlink(missing_ok=True)
@dataclass
class AuthenticationEnvSnapshot:
    """Captured authentication environment state that can be restored later.

    Build an instance with :meth:`capture` before mutating the authentication
    environment and call :meth:`restore` afterwards to put back the
    environment variables, ``HOME``, and the default token file.
    """

    # Value of each variable in _AUTH_ENV_VARS at capture time (None = unset).
    original_env: Dict[str, Optional[str]]
    # Value of HOME at capture time, or None when it was not set.
    original_home: Optional[str]
    # Whether HOME was present in the environment at capture time.
    home_was_set: bool
    # Temporary directory exported as HOME by capture(), if one was needed.
    temp_home: Optional[Path]
    # Absolute location of the default token file for the HOME in effect.
    default_token_path: Path
    # Whether the default token file existed at capture time.
    default_token_exists: bool
    # Contents of the default token file at capture time, if it existed.
    default_token_contents: Optional[str]

    @classmethod
    def capture(cls) -> "AuthenticationEnvSnapshot":
        """Capture current authentication-related environment state.

        If ``HOME`` is not set, a ``ray_test_home`` directory is created under
        ``TEST_TMPDIR`` (or the system temp directory) and exported as
        ``HOME`` so that the default token path can be resolved.

        Returns:
            A snapshot describing the state to reinstate in :meth:`restore`.
        """
        original_env = {var: os.environ.get(var) for var in _AUTH_ENV_VARS}
        home_was_set = "HOME" in os.environ
        original_home = os.environ.get("HOME")
        temp_home: Optional[Path] = None
        if not home_was_set:
            # In CI $HOME may not be set, which can cause issues with tests
            # related to the default auth token file.
            test_tmpdir = os.environ.get("TEST_TMPDIR")
            base_dir = Path(test_tmpdir) if test_tmpdir else Path(tempfile.gettempdir())
            temp_home = base_dir / "ray_test_home"
            temp_home.mkdir(parents=True, exist_ok=True)
            os.environ["HOME"] = str(temp_home)
        default_token_path = Path.home() / _DEFAULT_AUTH_TOKEN_RELATIVE_PATH
        default_token_exists = default_token_path.exists()
        default_token_contents = (
            default_token_path.read_text() if default_token_exists else None
        )
        return cls(
            original_env=original_env,
            original_home=original_home,
            home_was_set=home_was_set,
            temp_home=temp_home,
            default_token_path=default_token_path,
            default_token_exists=default_token_exists,
            default_token_contents=default_token_contents,
        )

    def clear_default_token(self) -> None:
        """Remove the default token file for the current HOME."""
        self.default_token_path.unlink(missing_ok=True)

    def restore(self) -> None:
        """Restore the captured environment, HOME, and default token file state."""
        # Delete any custom token file that was configured after capture().
        # A path that was already in RAY_AUTH_TOKEN_PATH at capture time is
        # left alone: that file predates the snapshot and is not ours to remove.
        custom_token_path = os.environ.get("RAY_AUTH_TOKEN_PATH")
        if (
            custom_token_path is not None
            and custom_token_path != self.original_env.get("RAY_AUTH_TOKEN_PATH")
        ):
            custom_token_file = Path(custom_token_path)
            if custom_token_file.exists():
                custom_token_file.unlink(missing_ok=True)

        # Reinstate the authentication variables exactly as captured.
        for var, value in self.original_env.items():
            if value is None:
                os.environ.pop(var, None)
            else:
                os.environ[var] = value

        if self.home_was_set:
            if self.original_home is None:
                os.environ.pop("HOME", None)
            else:
                os.environ["HOME"] = self.original_home

        # Put the default token file back to its captured state.
        if self.default_token_exists:
            self.default_token_path.parent.mkdir(parents=True, exist_ok=True)
            self.default_token_path.write_text(self.default_token_contents or "")
        else:
            self.default_token_path.unlink(missing_ok=True)

        if not self.home_was_set:
            # Undo the temporary HOME from capture(), but only unset HOME if
            # it still points at that temporary directory.
            current_home = os.environ.get("HOME")
            if self.temp_home is not None and current_home == str(self.temp_home):
                os.environ.pop("HOME", None)
            if self.temp_home is not None and self.temp_home.exists():
                shutil.rmtree(self.temp_home, ignore_errors=True)
@contextmanager
def authentication_env_guard():
    """Snapshot the authentication environment and restore it on exit.

    Yields:
        The ``AuthenticationEnvSnapshot`` captured on entry. Its ``restore()``
        runs when the ``with`` block exits, whether or not the body raised.
    """
    saved_state = AuthenticationEnvSnapshot.capture()
    try:
        yield saved_state
    finally:
        saved_state.restore()
|