# authentication_test_utils.py (5.5 KB)
  1. import os
  2. import shutil
  3. import tempfile
  4. from contextlib import contextmanager
  5. from dataclasses import dataclass
  6. from pathlib import Path
  7. from typing import Dict, Optional
  8. from ray._raylet import AuthenticationTokenLoader, Config
  9. _AUTH_ENV_VARS = ("RAY_AUTH_MODE", "RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH")
  10. _DEFAULT_AUTH_TOKEN_RELATIVE_PATH = Path(".ray") / "auth_token"
  11. def reset_auth_token_state() -> None:
  12. """Reset authentication token and AUTH_MODE ray config."""
  13. AuthenticationTokenLoader.instance().reset_cache()
  14. Config.initialize("")
  15. def set_auth_mode(mode: str) -> None:
  16. """Set the authentication mode environment variable."""
  17. os.environ["RAY_AUTH_MODE"] = mode
  18. def set_env_auth_token(token: str) -> None:
  19. """Configure the authentication token via environment variable."""
  20. os.environ["RAY_AUTH_TOKEN"] = token
  21. os.environ.pop("RAY_AUTH_TOKEN_PATH", None)
  22. def set_auth_token_path(token: str, path: Path) -> None:
  23. """Write the authentication token to a specific path and point the loader to it."""
  24. token_path = Path(path)
  25. if token is not None:
  26. token_path.parent.mkdir(parents=True, exist_ok=True)
  27. token_path.write_text(token)
  28. os.environ["RAY_AUTH_TOKEN_PATH"] = str(token_path)
  29. os.environ.pop("RAY_AUTH_TOKEN", None)
  30. def set_default_auth_token(token: str) -> Path:
  31. """Write the authentication token to the default ~/.ray/auth_token location."""
  32. default_path = Path.home() / _DEFAULT_AUTH_TOKEN_RELATIVE_PATH
  33. default_path.parent.mkdir(parents=True, exist_ok=True)
  34. default_path.write_text(token)
  35. return default_path
  36. def clear_auth_token_sources(remove_default: bool = False) -> None:
  37. """Clear authentication-related environment variables and optional default token file."""
  38. for var in ("RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH"):
  39. os.environ.pop(var, None)
  40. if remove_default:
  41. default_path = Path.home() / _DEFAULT_AUTH_TOKEN_RELATIVE_PATH
  42. default_path.unlink(missing_ok=True)
  43. @dataclass
  44. class AuthenticationEnvSnapshot:
  45. original_env: Dict[str, Optional[str]]
  46. original_home: Optional[str]
  47. home_was_set: bool
  48. temp_home: Optional[Path]
  49. default_token_path: Path
  50. default_token_exists: bool
  51. default_token_contents: Optional[str]
  52. @classmethod
  53. def capture(cls) -> "AuthenticationEnvSnapshot":
  54. """Capture current authentication-related environment state."""
  55. original_env = {var: os.environ.get(var) for var in _AUTH_ENV_VARS}
  56. home_was_set = "HOME" in os.environ
  57. original_home = os.environ.get("HOME")
  58. temp_home: Optional[Path] = None
  59. if not home_was_set:
  60. # in CI $HOME may not be set which can cause issues with tests related to default auth token file.
  61. test_tmpdir = os.environ.get("TEST_TMPDIR")
  62. base_dir = Path(test_tmpdir) if test_tmpdir else Path(tempfile.gettempdir())
  63. temp_home = base_dir / "ray_test_home"
  64. temp_home.mkdir(parents=True, exist_ok=True)
  65. os.environ["HOME"] = str(temp_home)
  66. default_token_path = Path.home() / _DEFAULT_AUTH_TOKEN_RELATIVE_PATH
  67. default_token_exists = default_token_path.exists()
  68. default_token_contents = (
  69. default_token_path.read_text() if default_token_exists else None
  70. )
  71. return cls(
  72. original_env=original_env,
  73. original_home=original_home,
  74. home_was_set=home_was_set,
  75. temp_home=temp_home,
  76. default_token_path=default_token_path,
  77. default_token_exists=default_token_exists,
  78. default_token_contents=default_token_contents,
  79. )
  80. def clear_default_token(self) -> None:
  81. """Remove the default token file for the current HOME."""
  82. self.default_token_path.unlink(missing_ok=True)
  83. def restore(self) -> None:
  84. """Restore the captured environment, HOME, and default token file state."""
  85. # delete any custom token files that may have been created during the test
  86. custom_token_path = os.environ.get("RAY_AUTH_TOKEN_PATH")
  87. if custom_token_path is not None:
  88. custom_token_path = Path(custom_token_path)
  89. if custom_token_path.exists():
  90. custom_token_path.unlink(missing_ok=True)
  91. for var, value in self.original_env.items():
  92. if value is None:
  93. os.environ.pop(var, None)
  94. else:
  95. os.environ[var] = value
  96. if self.home_was_set:
  97. if self.original_home is None:
  98. os.environ.pop("HOME", None)
  99. else:
  100. os.environ["HOME"] = self.original_home
  101. if self.default_token_exists:
  102. self.default_token_path.parent.mkdir(parents=True, exist_ok=True)
  103. self.default_token_path.write_text(self.default_token_contents or "")
  104. else:
  105. self.default_token_path.unlink(missing_ok=True)
  106. if not self.home_was_set:
  107. current_home = os.environ.get("HOME")
  108. if self.temp_home is not None and current_home == str(self.temp_home):
  109. os.environ.pop("HOME", None)
  110. if self.temp_home is not None and self.temp_home.exists():
  111. shutil.rmtree(self.temp_home, ignore_errors=True)
  112. @contextmanager
  113. def authentication_env_guard():
  114. """Context manager that restores authentication environment state on exit."""
  115. snapshot = AuthenticationEnvSnapshot.capture()
  116. try:
  117. yield snapshot
  118. finally:
  119. snapshot.restore()