| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192 |
- from __future__ import annotations
- import abc
- import dataclasses
- import pathlib
- import requests
- import requests.auth
- from typing_extensions import final, override
- from wandb.errors import AuthenticationError
- from wandb.sdk.lib import credentials
- from . import validation
- from .host_url import HostUrl
# We use an abstract base class instead of a union because:
#   (1) All Auth subtypes have a 'host' property and a safe repr.
#   (2) Auth should be treated as an open union, meaning typecheckers should
#       not consider any list of isinstance() checks exhaustive and should
#       always require a fallback case.
- class Auth(abc.ABC):
- """Credentials that give access to a W&B server."""
- @abc.abstractmethod
- def __init__(self, *, host: str | HostUrl) -> None:
- if isinstance(host, str):
- host = HostUrl(host)
- self._host = host
- @property
- def host(self) -> HostUrl:
- """The W&B server for which the credentials are valid."""
- return self._host
- @abc.abstractmethod
- def as_requests_auth(self) -> requests.auth.AuthBase:
- """Return a requests-compatible auth handler for this credential.
- Returns a callable that implements the requests library's AuthBase
- interface. This can be passed directly to requests.Session.auth for
- automatic authentication on each request.
- For token-based auth (e.g., identity tokens), the returned handler
- will automatically refresh expired tokens on each request.
- Returns:
- A requests.auth.AuthBase instance that applies this credential
- to HTTP requests.
- """
- @final
- @override
- def __repr__(self) -> str:
- return f"<{type(self).__name__} host={self.host.url!r}>"
- @final
- @override
- def __str__(self) -> str:
- return repr(self)
- @final
- class AuthApiKey(Auth):
- """An API key for connecting to a W&B server."""
- @override
- def __init__(self, *, host: str | HostUrl, api_key: str) -> None:
- """Initialize AuthApiKey.
- Args:
- host: The W&B server URL.
- api_key: The API key.
- Raises:
- ValueError: If the host is invalid.
- AuthenticationError: If the API key is in an invalid format.
- """
- super().__init__(host=host)
- if problems := validation.check_api_key(api_key):
- raise AuthenticationError(problems)
- self._api_key = api_key
- @property
- def api_key(self) -> str:
- """The API key."""
- return self._api_key
- @override
- def as_requests_auth(self) -> requests.auth.AuthBase:
- """Return a requests auth handler using HTTP Basic Auth.
- Returns:
- An auth handler that sets the Authorization header with
- Basic auth using "api" as the username and the API key
- as the password.
- """
- return requests.auth.HTTPBasicAuth("api", self._api_key)
- class _IdentityTokenAuth(requests.auth.AuthBase):
- """Requests auth handler for identity token (JWT) authentication."""
- def __init__(self, auth: AuthIdentityTokenFile) -> None:
- self._auth = auth
- def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
- token = self._auth.fetch_access_token()
- r.headers["Authorization"] = f"Bearer {token}"
- return r
- @final
- class AuthIdentityTokenFile(Auth):
- """A path to a file storing a JWT with OIDC credentials."""
- @override
- def __init__(
- self,
- *,
- host: str | HostUrl,
- path: str,
- credentials_file: str,
- ) -> None:
- """Initialize AuthIdentityTokenFile.
- Args:
- host: The W&B server URL.
- path: Path to the identity token file containing a JWT.
- credentials_file: Path to the credentials file for caching access tokens.
- """
- super().__init__(host=host)
- self._identity_token_file = pathlib.Path(path)
- self._credentials_path = pathlib.Path(credentials_file)
- @property
- def path(self) -> pathlib.Path:
- """Path to a file storing a JWT identity token."""
- return self._identity_token_file
- @property
- def credentials_path(self) -> pathlib.Path:
- """Path to the credentials file for caching access tokens."""
- return self._credentials_path
- def fetch_access_token(self) -> str:
- """Fetch an access token for authenticating with the W&B server.
- Retrieves a valid access token from the credentials file. If no token
- exists or the existing token has expired, exchanges the identity token
- (JWT) from the configured token file for a new access token from the
- server and caches it in the credentials file.
- Returns:
- A valid access token string that can be used for Bearer authentication
- with the W&B API.
- Raises:
- FileNotFoundError: If the identity token file does not exist.
- OSError: If there is an error reading the identity token file.
- AuthenticationError: If the server rejects the identity token or
- fails to provide an access token.
- """
- base_url = str(self.host.url)
- return credentials.access_token(base_url, self.path, self.credentials_path)
- @override
- def as_requests_auth(self) -> requests.auth.AuthBase:
- """Return a requests auth handler using Bearer token authentication.
- The returned handler calls fetch_access_token() on each request,
- ensuring that expired tokens are automatically refreshed.
- Returns:
- An auth handler that sets the Authorization header with
- a Bearer token fetched (and refreshed as needed) from the
- identity token file.
- """
- return _IdentityTokenAuth(self)
- @dataclasses.dataclass(frozen=True)
- class AuthWithSource:
- """Credentials with information about where they came from."""
- auth: Auth
- source: str
- """A file path or environment variable."""
|