credentials.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import json
  2. import os
  3. from datetime import datetime, timedelta
  4. from pathlib import Path
  5. from wandb.errors import AuthenticationError
  6. DEFAULT_WANDB_CREDENTIALS_FILE = Path(
  7. os.path.expanduser("~/.config/wandb/credentials.json")
  8. )
  9. _expires_at_fmt = "%Y-%m-%d %H:%M:%S"
  10. def access_token(base_url: str, token_file: Path, credentials_file: Path) -> str:
  11. """Retrieve an access token from the credentials file.
  12. If no access token exists, create a new one by exchanging the identity
  13. token from the token file, and save it to the credentials file.
  14. Args:
  15. base_url (str): The base URL of the server
  16. token_file (pathlib.Path): The path to the file containing the
  17. identity token
  18. credentials_file (pathlib.Path): The path to file used to save
  19. temporary access tokens
  20. Returns:
  21. str: The access token
  22. """
  23. if not credentials_file.exists():
  24. _write_credentials_file(base_url, token_file, credentials_file)
  25. data = _fetch_credentials(base_url, token_file, credentials_file)
  26. return data["access_token"]
  27. def _write_credentials_file(base_url: str, token_file: Path, credentials_file: Path):
  28. """Obtain an access token from the server and write it to the credentials file.
  29. Args:
  30. base_url (str): The base URL of the server
  31. token_file (pathlib.Path): The path to the file containing the
  32. identity token
  33. credentials_file (pathlib.Path): The path to file used to save
  34. temporary access tokens
  35. """
  36. credentials = _create_access_token(base_url, token_file)
  37. data = {"credentials": {base_url: credentials}}
  38. with open(credentials_file, "w") as file:
  39. json.dump(data, file, indent=4)
  40. # Set file permissions to be read/write by the owner only
  41. os.chmod(credentials_file, 0o600)
  42. def _fetch_credentials(base_url: str, token_file: Path, credentials_file: Path) -> dict:
  43. """Fetch the access token from the credentials file.
  44. If the access token has expired, fetch a new one from the server and save it
  45. to the credentials file.
  46. Args:
  47. base_url (str): The base URL of the server
  48. token_file (pathlib.Path): The path to the file containing the
  49. identity token
  50. credentials_file (pathlib.Path): The path to file used to save
  51. temporary access tokens
  52. Returns:
  53. dict: The credentials including the access token.
  54. """
  55. creds = {}
  56. with open(credentials_file) as file:
  57. data = json.load(file)
  58. if "credentials" not in data:
  59. data["credentials"] = {}
  60. if base_url in data["credentials"]:
  61. creds = data["credentials"][base_url]
  62. expires_at = datetime.utcnow()
  63. if "expires_at" in creds:
  64. expires_at = datetime.strptime(creds["expires_at"], _expires_at_fmt)
  65. if expires_at <= datetime.utcnow():
  66. creds = _create_access_token(base_url, token_file)
  67. with open(credentials_file, "w") as file:
  68. data["credentials"][base_url] = creds
  69. json.dump(data, file, indent=4)
  70. return creds
  71. def _create_access_token(base_url: str, token_file: Path) -> dict:
  72. """Exchange an identity token for an access token from the server.
  73. Args:
  74. base_url (str): The base URL of the server.
  75. token_file (pathlib.Path): The path to the file containing the
  76. identity token
  77. Returns:
  78. dict: The access token and its expiration.
  79. Raises:
  80. FileNotFoundError: If the token file is not found.
  81. OSError: If there is an issue reading the token file.
  82. AuthenticationError: If the server fails to provide an access token.
  83. """
  84. import requests
  85. try:
  86. with open(token_file) as file:
  87. token = file.read().strip()
  88. except FileNotFoundError as e:
  89. raise FileNotFoundError(f"Identity token file not found: {token_file}") from e
  90. except OSError as e:
  91. raise OSError(
  92. f"Failed to read the identity token from file: {token_file}"
  93. ) from e
  94. url = f"{base_url}/oidc/token"
  95. data = {
  96. "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
  97. "assertion": token,
  98. }
  99. headers = {"Content-Type": "application/x-www-form-urlencoded"}
  100. response = requests.post(url, data=data, headers=headers)
  101. if response.status_code != 200:
  102. raise AuthenticationError(
  103. f"Failed to retrieve access token: {response.status_code}, {response.text}"
  104. )
  105. resp_json = response.json()
  106. expires_at = datetime.utcnow() + timedelta(seconds=float(resp_json["expires_in"]))
  107. resp_json["expires_at"] = expires_at.strftime(_expires_at_fmt)
  108. del resp_json["expires_in"]
  109. return resp_json