authenticate.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. from __future__ import annotations
  2. import os
  3. import threading
  4. from wandb import env
  5. from wandb.errors import AuthenticationError, UsageError, term
  6. from wandb.sdk import wandb_setup
  7. from . import prompt, wbnetrc
  8. from .auth import Auth, AuthApiKey, AuthIdentityTokenFile, AuthWithSource
  9. from .host_url import HostUrl
# Held around every read and write of the module-level `_session_auth`
# performed by the functions in this module.
_session_auth_lock = threading.Lock()

# Credentials currently configured for the session; None when nothing has
# been configured (or after the credentials were cleared).
_session_auth: Auth | None = None
  12. def session_credentials(*, host: str | HostUrl) -> Auth | None:
  13. """Returns the configured session credentials.
  14. Returns None if session credentials are configured for a different host.
  15. """
  16. with _session_auth_lock:
  17. if _session_auth and _session_auth.host.is_same_url(host):
  18. return _session_auth
  19. else:
  20. return None
  21. def _locked_set_session_auth(
  22. auth: Auth | None,
  23. *,
  24. update_settings: bool = True,
  25. ) -> None:
  26. """Update session credentials.
  27. Updates the global _session_auth variable and the global settings.
  28. This is a refactoring step to transition away from storing auth in settings.
  29. Args:
  30. update_settings: Defaults to true. If false, skips updating the global
  31. settings (which may cause them to be loaded).
  32. """
  33. global _session_auth
  34. _session_auth = auth
  35. if not update_settings:
  36. return
  37. settings = wandb_setup.singleton().settings
  38. if auth is None:
  39. settings.api_key = None
  40. settings.identity_token_file = None
  41. elif isinstance(auth, AuthApiKey):
  42. settings.api_key = auth.api_key
  43. settings.identity_token_file = None
  44. settings.base_url = str(auth.host)
  45. elif isinstance(auth, AuthIdentityTokenFile):
  46. settings.api_key = None
  47. settings.identity_token_file = str(auth.path)
  48. settings.credentials_file = str(auth.credentials_path)
  49. settings.base_url = str(auth.host)
  50. else:
  51. raise NotImplementedError(str(auth))
  52. def unauthenticate_session(*, update_settings: bool = True) -> Auth | None:
  53. """Clear the session credentials.
  54. Args:
  55. update_settings: Defaults to true. If false, skips updating the global
  56. settings (which may cause them to be loaded).
  57. Returns:
  58. The previous credentials, if any.
  59. """
  60. with _session_auth_lock:
  61. auth = _session_auth
  62. _locked_set_session_auth(None, update_settings=update_settings)
  63. return auth
  64. def authenticate_session(
  65. *,
  66. host: str | HostUrl,
  67. source: str,
  68. no_offline: bool = False,
  69. no_create: bool = False,
  70. input_timeout: float | None = None,
  71. referrer: str = "models",
  72. relogin: bool = False,
  73. ) -> Auth | None:
  74. """Returns or configures the session credentials.
  75. If the session credentials are already configured for the given host,
  76. returns them. Otherwise, uses system credentials or prompts interactively.
  77. The return value is only None if the user selected offline mode in
  78. the interactive prompt.
  79. Args:
  80. host: The W&B server URL.
  81. source: The source to include in printed messages,
  82. like "wandb.init()".
  83. no_offline: Whether to show an offline option in interactive prompts.
  84. no_create: Whether to show a new account option in interactive prompts.
  85. input_timeout: A timeout for interactive prompts to avoid hanging
  86. the process if we incorrectly identify it as interactive.
  87. referrer: Referrer parameter to add to printed URLs for analytics.
  88. relogin: If true, forces an interactive prompt.
  89. Raises:
  90. TimeoutError: If an interactive prompt is shown and input_timeout expires.
  91. AuthenticationError: If credentials are found but have an invalid format.
  92. UsageError: If interactive prompting is needed but unavailable.
  93. """
  94. if not isinstance(host, HostUrl):
  95. host = HostUrl(host)
  96. if not relogin and (auth := session_credentials(host=host)):
  97. return auth
  98. if not relogin and (auth := _use_system_auth(host=host, source=source)):
  99. return auth
  100. try:
  101. return _use_prompted_auth(
  102. host=host,
  103. no_offline=no_offline,
  104. no_create=no_create,
  105. referrer=referrer,
  106. input_timeout=input_timeout,
  107. )
  108. except term.NotATerminalError:
  109. raise UsageError(
  110. "No API key configured. Use `wandb login` to log in."
  111. ) from None
  112. def use_explicit_auth(auth: Auth, *, source: str) -> None:
  113. """Use explicitly given credentials in the session.
  114. Args:
  115. auth: Credentials to use.
  116. source: The source to include in the printed message,
  117. like "wandb.init()".
  118. """
  119. with _session_auth_lock:
  120. if _session_auth == auth:
  121. return
  122. if _session_auth:
  123. term.termwarn(
  124. f"[{source}] Changing session credentials to explicit value"
  125. + f" for {auth.host}."
  126. )
  127. else:
  128. term.termlog(
  129. f"[{source}] Using explicit session credentials for {auth.host}."
  130. )
  131. _locked_set_session_auth(auth)
  132. def _use_system_auth(*, host: HostUrl, source: str) -> Auth | None:
  133. """Load (or reload) session credentials from external sources.
  134. Loads credentials from environment variables or the .netrc file.
  135. If no credentials are found, the session credentials are unchanged.
  136. Args:
  137. host: The W&B server URL.
  138. source: The source to include in the printed message,
  139. like "wandb.init()".
  140. Raises:
  141. AuthenticationError: If a source of credentials is found but has an
  142. invalid format.
  143. Returns:
  144. The new credentials, if any.
  145. """
  146. auth = (
  147. _try_env_auth(host=host) #
  148. or wbnetrc.read_netrc_auth_with_source(host=host)
  149. )
  150. with _session_auth_lock:
  151. if auth:
  152. term.termlog(
  153. f"[{source}] Loaded credentials for {auth.auth.host}"
  154. + f" from {auth.source}."
  155. )
  156. _locked_set_session_auth(auth.auth)
  157. return _session_auth
  158. def _try_env_auth(*, host: HostUrl) -> AuthWithSource | None:
  159. """Returns credentials from environment variables, if set.
  160. Raises an authentication error if an invalid combination of environment
  161. variables is set.
  162. """
  163. api_key = os.getenv(env.API_KEY)
  164. identity_token_file = os.getenv(env.IDENTITY_TOKEN_FILE)
  165. if api_key and identity_token_file:
  166. raise AuthenticationError(
  167. f"Both {env.API_KEY} and {env.IDENTITY_TOKEN_FILE} are set,"
  168. + " which is not allowed."
  169. )
  170. if api_key:
  171. try:
  172. return AuthWithSource(
  173. auth=AuthApiKey(host=host, api_key=api_key),
  174. source=env.API_KEY,
  175. )
  176. except AuthenticationError as e:
  177. raise AuthenticationError(f"{env.API_KEY} invalid: {e}") from None
  178. elif identity_token_file:
  179. return AuthWithSource(
  180. auth=AuthIdentityTokenFile(
  181. host=host,
  182. path=identity_token_file,
  183. credentials_file=wandb_setup.singleton().settings.credentials_file,
  184. ),
  185. source=env.IDENTITY_TOKEN_FILE,
  186. )
  187. return None
  188. def _use_prompted_auth(
  189. *,
  190. host: HostUrl,
  191. no_offline: bool,
  192. no_create: bool,
  193. referrer: str,
  194. input_timeout: float | None = None,
  195. ) -> Auth | None:
  196. """Prompt interactively to set session credentials.
  197. May clear session credentials if the user selects offline mode.
  198. Args:
  199. host: The W&B server URL.
  200. no_offline: If true, do not show an option to skip logging in.
  201. no_create: If true, do not show an option to create a new account.
  202. referrer: Referrer parameter to include in printed URLs for analytics.
  203. input_timeout: How long to wait for user input before timing out.
  204. Raises:
  205. NotATerminalError: If interactive prompting is not possible.
  206. TimeoutError: If input_timeout expires.
  207. """
  208. api_key = prompt.prompt_and_save_api_key(
  209. host=host,
  210. no_offline=no_offline,
  211. no_create=no_create,
  212. referrer=referrer,
  213. input_timeout=input_timeout,
  214. )
  215. with _session_auth_lock:
  216. if api_key:
  217. _locked_set_session_auth(AuthApiKey(host=host, api_key=api_key))
  218. else:
  219. # Offline mode selected.
  220. _locked_set_session_auth(None)
  221. return _session_auth