auth.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
from __future__ import annotations

import abc
import dataclasses
import os
import pathlib

import requests
import requests.auth
from typing_extensions import final, override

from wandb.errors import AuthenticationError
from wandb.sdk.lib import credentials

from . import validation
from .host_url import HostUrl
  12. # We use an abstract base class instead of a union because
  13. # (1) All Auth subtypes have a 'host' property and a safe repr
  14. # (2) Auth should be treated as an open union, meaning typecheckers should
  15. # not consider any list of isinstance() checks exhaustive and should
  16. # always require a fallback case
  17. class Auth(abc.ABC):
  18. """Credentials that give access to a W&B server."""
  19. @abc.abstractmethod
  20. def __init__(self, *, host: str | HostUrl) -> None:
  21. if isinstance(host, str):
  22. host = HostUrl(host)
  23. self._host = host
  24. @property
  25. def host(self) -> HostUrl:
  26. """The W&B server for which the credentials are valid."""
  27. return self._host
  28. @abc.abstractmethod
  29. def as_requests_auth(self) -> requests.auth.AuthBase:
  30. """Return a requests-compatible auth handler for this credential.
  31. Returns a callable that implements the requests library's AuthBase
  32. interface. This can be passed directly to requests.Session.auth for
  33. automatic authentication on each request.
  34. For token-based auth (e.g., identity tokens), the returned handler
  35. will automatically refresh expired tokens on each request.
  36. Returns:
  37. A requests.auth.AuthBase instance that applies this credential
  38. to HTTP requests.
  39. """
  40. @final
  41. @override
  42. def __repr__(self) -> str:
  43. return f"<{type(self).__name__} host={self.host.url!r}>"
  44. @final
  45. @override
  46. def __str__(self) -> str:
  47. return repr(self)
  48. @final
  49. class AuthApiKey(Auth):
  50. """An API key for connecting to a W&B server."""
  51. @override
  52. def __init__(self, *, host: str | HostUrl, api_key: str) -> None:
  53. """Initialize AuthApiKey.
  54. Args:
  55. host: The W&B server URL.
  56. api_key: The API key.
  57. Raises:
  58. ValueError: If the host is invalid.
  59. AuthenticationError: If the API key is in an invalid format.
  60. """
  61. super().__init__(host=host)
  62. if problems := validation.check_api_key(api_key):
  63. raise AuthenticationError(problems)
  64. self._api_key = api_key
  65. @property
  66. def api_key(self) -> str:
  67. """The API key."""
  68. return self._api_key
  69. @override
  70. def as_requests_auth(self) -> requests.auth.AuthBase:
  71. """Return a requests auth handler using HTTP Basic Auth.
  72. Returns:
  73. An auth handler that sets the Authorization header with
  74. Basic auth using "api" as the username and the API key
  75. as the password.
  76. """
  77. return requests.auth.HTTPBasicAuth("api", self._api_key)
  78. class _IdentityTokenAuth(requests.auth.AuthBase):
  79. """Requests auth handler for identity token (JWT) authentication."""
  80. def __init__(self, auth: AuthIdentityTokenFile) -> None:
  81. self._auth = auth
  82. def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
  83. token = self._auth.fetch_access_token()
  84. r.headers["Authorization"] = f"Bearer {token}"
  85. return r
  86. @final
  87. class AuthIdentityTokenFile(Auth):
  88. """A path to a file storing a JWT with OIDC credentials."""
  89. @override
  90. def __init__(
  91. self,
  92. *,
  93. host: str | HostUrl,
  94. path: str,
  95. credentials_file: str,
  96. ) -> None:
  97. """Initialize AuthIdentityTokenFile.
  98. Args:
  99. host: The W&B server URL.
  100. path: Path to the identity token file containing a JWT.
  101. credentials_file: Path to the credentials file for caching access tokens.
  102. """
  103. super().__init__(host=host)
  104. self._identity_token_file = pathlib.Path(path)
  105. self._credentials_path = pathlib.Path(credentials_file)
  106. @property
  107. def path(self) -> pathlib.Path:
  108. """Path to a file storing a JWT identity token."""
  109. return self._identity_token_file
  110. @property
  111. def credentials_path(self) -> pathlib.Path:
  112. """Path to the credentials file for caching access tokens."""
  113. return self._credentials_path
  114. def fetch_access_token(self) -> str:
  115. """Fetch an access token for authenticating with the W&B server.
  116. Retrieves a valid access token from the credentials file. If no token
  117. exists or the existing token has expired, exchanges the identity token
  118. (JWT) from the configured token file for a new access token from the
  119. server and caches it in the credentials file.
  120. Returns:
  121. A valid access token string that can be used for Bearer authentication
  122. with the W&B API.
  123. Raises:
  124. FileNotFoundError: If the identity token file does not exist.
  125. OSError: If there is an error reading the identity token file.
  126. AuthenticationError: If the server rejects the identity token or
  127. fails to provide an access token.
  128. """
  129. base_url = str(self.host.url)
  130. return credentials.access_token(base_url, self.path, self.credentials_path)
  131. @override
  132. def as_requests_auth(self) -> requests.auth.AuthBase:
  133. """Return a requests auth handler using Bearer token authentication.
  134. The returned handler calls fetch_access_token() on each request,
  135. ensuring that expired tokens are automatically refreshed.
  136. Returns:
  137. An auth handler that sets the Authorization header with
  138. a Bearer token fetched (and refreshed as needed) from the
  139. identity token file.
  140. """
  141. return _IdentityTokenAuth(self)
  142. @dataclasses.dataclass(frozen=True)
  143. class AuthWithSource:
  144. """Credentials with information about where they came from."""
  145. auth: Auth
  146. source: str
  147. """A file path or environment variable."""