gcp_environment.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. """Implementation of the GCP environment for wandb launch."""
  2. from __future__ import annotations
  3. import logging
  4. import os
  5. import subprocess
  6. from wandb.sdk.launch.errors import LaunchError
  7. from wandb.util import get_module
  8. from ..utils import GCS_URI_RE, event_loop_thread_exec
  9. from .abstract import AbstractEnvironment
  10. google = get_module(
  11. "google",
  12. required="Google Cloud Platform support requires the google package. Please"
  13. " install it with `pip install wandb[launch]`.",
  14. )
  15. google.cloud.compute_v1 = get_module(
  16. "google.cloud.compute_v1",
  17. required="Google Cloud Platform support requires the google-cloud-compute package. "
  18. "Please install it with `pip install wandb[launch]`.",
  19. )
  20. google.auth.credentials = get_module(
  21. "google.auth.credentials",
  22. required="Google Cloud Platform support requires google-auth. "
  23. "Please install it with `pip install wandb[launch]`.",
  24. )
  25. google.auth.transport.requests = get_module(
  26. "google.auth.transport.requests",
  27. required="Google Cloud Platform support requires google-auth. "
  28. "Please install it with `pip install wandb[launch]`.",
  29. )
  30. google.api_core.exceptions = get_module(
  31. "google.api_core.exceptions",
  32. required="Google Cloud Platform support requires google-api-core. "
  33. "Please install it with `pip install wandb[launch]`.",
  34. )
  35. google.cloud.storage = get_module(
  36. "google.cloud.storage",
  37. required="Google Cloud Platform support requires google-cloud-storage. "
  38. "Please install it with `pip install wandb[launch].",
  39. )
  40. _logger = logging.getLogger(__name__)
  41. GCP_REGION_ENV_VAR = "GOOGLE_CLOUD_REGION"
  42. class GcpEnvironment(AbstractEnvironment):
  43. """GCP Environment.
  44. Attributes:
  45. region: The GCP region.
  46. """
  47. region: str
  48. def __init__(
  49. self,
  50. region: str,
  51. ) -> None:
  52. """Initialize the GCP environment.
  53. Arguments:
  54. region: The GCP region.
  55. verify: Whether to verify the credentials, region, and project.
  56. Raises:
  57. LaunchError: If verify is True and the environment is not properly
  58. configured.
  59. """
  60. super().__init__()
  61. _logger.info(f"Initializing GcpEnvironment in region {region}")
  62. self.region: str = region
  63. self._project = ""
  64. @classmethod
  65. def from_config(cls, config: dict) -> GcpEnvironment:
  66. """Create a GcpEnvironment from a config dictionary.
  67. Arguments:
  68. config: The config dictionary.
  69. Returns:
  70. GcpEnvironment: The GcpEnvironment.
  71. """
  72. if config.get("type") != "gcp":
  73. raise LaunchError(
  74. f"Could not create GcpEnvironment from config. Expected type 'gcp' "
  75. f"but got '{config.get('type')}'."
  76. )
  77. region = config.get("region")
  78. if not region:
  79. raise LaunchError(
  80. "Could not create GcpEnvironment from config. Missing 'region' field."
  81. )
  82. return cls(region=region)
  83. @classmethod
  84. def from_default(
  85. cls,
  86. ) -> GcpEnvironment:
  87. """Create a GcpEnvironment from the default configuration.
  88. Returns:
  89. GcpEnvironment: The GcpEnvironment.
  90. """
  91. region = get_default_region()
  92. if region is None:
  93. raise LaunchError(
  94. "Could not create GcpEnvironment from user's gcloud configuration. "
  95. "Please set the default region with `gcloud config set compute/region` "
  96. "or set the environment variable {GCP_REGION_ENV_VAR}. "
  97. "Alternatively, you may specify the region explicitly in your "
  98. "wandb launch configuration at `$HOME/.config/wandb/launch-config.yaml`. "
  99. "See https://docs.wandb.ai/platform/launch/run-agent#environments for more information."
  100. )
  101. return cls(region=region)
  102. @property
  103. def project(self) -> str:
  104. """Get the name of the gcp project associated with the credentials.
  105. Returns:
  106. str: The name of the gcp project.
  107. Raises:
  108. LaunchError: If the launch environment cannot be verified.
  109. """
  110. return self._project
  111. async def get_credentials(self) -> google.auth.credentials.Credentials: # type: ignore
  112. """Get the GCP credentials.
  113. Uses google.auth.default() to get the credentials. If the credentials
  114. are invalid, this method will refresh them. If the credentials are
  115. still invalid after refreshing, this method will raise an error.
  116. Returns:
  117. google.auth.credentials.Credentials: The GCP credentials.
  118. Raises:
  119. LaunchError: If the GCP credentials are invalid.
  120. """
  121. _logger.debug("Getting GCP credentials")
  122. # TODO: Figure out a minimal set of scopes.
  123. try:
  124. google_auth_default = event_loop_thread_exec(google.auth.default)
  125. creds, project = await google_auth_default()
  126. if not self._project:
  127. self._project = project
  128. _logger.debug("Refreshing GCP credentials")
  129. await event_loop_thread_exec(creds.refresh)(
  130. google.auth.transport.requests.Request()
  131. )
  132. except google.auth.exceptions.DefaultCredentialsError as e:
  133. raise LaunchError(
  134. "No Google Cloud Platform credentials found. Please run "
  135. "`gcloud auth application-default login` or set the environment "
  136. "variable GOOGLE_APPLICATION_CREDENTIALS to the path of a valid "
  137. "service account key file."
  138. ) from e
  139. except google.auth.exceptions.RefreshError as e:
  140. raise LaunchError(
  141. "Could not refresh Google Cloud Platform credentials. Please run "
  142. "`gcloud auth application-default login` or set the environment "
  143. "variable GOOGLE_APPLICATION_CREDENTIALS to the path of a valid "
  144. "service account key file."
  145. ) from e
  146. if not creds.valid:
  147. raise LaunchError(
  148. "Invalid Google Cloud Platform credentials. Please run "
  149. "`gcloud auth application-default login` or set the environment "
  150. "variable GOOGLE_APPLICATION_CREDENTIALS to the path of a valid "
  151. "service account key file."
  152. )
  153. return creds
  154. async def verify(self) -> None:
  155. """Verify the credentials, region, and project.
  156. Credentials and region are verified by calling get_credentials(). The
  157. region and is verified by calling the compute API.
  158. Raises:
  159. LaunchError: If the credentials, region, or project are invalid.
  160. Returns:
  161. None
  162. """
  163. _logger.debug("Verifying GCP environment")
  164. await self.get_credentials()
  165. async def verify_storage_uri(self, uri: str) -> None:
  166. """Verify that a storage URI is valid.
  167. Arguments:
  168. uri: The storage URI.
  169. Raises:
  170. LaunchError: If the storage URI is invalid.
  171. """
  172. match = GCS_URI_RE.match(uri)
  173. if not match:
  174. raise LaunchError(f"Invalid GCS URI: {uri}")
  175. bucket = match.group(1)
  176. cloud_storage_client = event_loop_thread_exec(google.cloud.storage.Client)
  177. try:
  178. credentials = await self.get_credentials()
  179. storage_client = await cloud_storage_client(credentials=credentials)
  180. bucket = await event_loop_thread_exec(storage_client.get_bucket)(bucket)
  181. except google.api_core.exceptions.GoogleAPICallError as e:
  182. raise LaunchError(
  183. f"Failed verifying storage uri {uri}: bucket {bucket} does not exist."
  184. ) from e
  185. except google.api_core.exceptions.Forbidden as e:
  186. raise LaunchError(
  187. f"Failed verifying storage uri {uri}: bucket {bucket} is not accessible. Please check your permissions and try again."
  188. ) from e
  189. async def upload_file(self, source: str, destination: str) -> None:
  190. """Upload a file to GCS.
  191. Arguments:
  192. source: The path to the local file.
  193. destination: The path to the GCS file.
  194. Raises:
  195. LaunchError: If the file cannot be uploaded.
  196. """
  197. _logger.debug(f"Uploading file {source} to {destination}")
  198. _err_prefix = f"Could not upload file {source} to GCS destination {destination}"
  199. if not os.path.isfile(source):
  200. raise LaunchError(f"{_err_prefix}: File {source} does not exist.")
  201. match = GCS_URI_RE.match(destination)
  202. if not match:
  203. raise LaunchError(f"{_err_prefix}: Invalid GCS URI: {destination}")
  204. bucket = match.group(1)
  205. key = match.group(2).lstrip("/")
  206. google_storage_client = event_loop_thread_exec(google.cloud.storage.Client)
  207. credentials = await self.get_credentials()
  208. try:
  209. storage_client = await google_storage_client(credentials=credentials)
  210. bucket = await event_loop_thread_exec(storage_client.bucket)(bucket)
  211. blob = await event_loop_thread_exec(bucket.blob)(key)
  212. await event_loop_thread_exec(blob.upload_from_filename)(source)
  213. except google.api_core.exceptions.GoogleAPICallError as e:
  214. resp = e.response
  215. assert resp is not None
  216. try:
  217. message = resp.json()["error"]["message"]
  218. except Exception:
  219. message = str(resp)
  220. raise LaunchError(f"{_err_prefix}: {message}") from e
  221. async def upload_dir(self, source: str, destination: str) -> None:
  222. """Upload a directory to GCS.
  223. Arguments:
  224. source: The path to the local directory.
  225. destination: The path to the GCS directory.
  226. Raises:
  227. LaunchError: If the directory cannot be uploaded.
  228. """
  229. _logger.debug(f"Uploading directory {source} to {destination}")
  230. _err_prefix = (
  231. f"Could not upload directory {source} to GCS destination {destination}"
  232. )
  233. if not os.path.isdir(source):
  234. raise LaunchError(f"{_err_prefix}: Directory {source} does not exist.")
  235. match = GCS_URI_RE.match(destination)
  236. if not match:
  237. raise LaunchError(f"{_err_prefix}: Invalid GCS URI: {destination}")
  238. bucket = match.group(1)
  239. key = match.group(2).lstrip("/")
  240. google_storage_client = event_loop_thread_exec(google.cloud.storage.Client)
  241. credentials = await self.get_credentials()
  242. try:
  243. storage_client = await google_storage_client(credentials=credentials)
  244. bucket = await event_loop_thread_exec(storage_client.bucket)(bucket)
  245. for root, _, files in os.walk(source):
  246. for file in files:
  247. local_path = os.path.join(root, file)
  248. gcs_path = os.path.join(
  249. key, os.path.relpath(local_path, source)
  250. ).replace("\\", "/")
  251. blob = await event_loop_thread_exec(bucket.blob)(gcs_path)
  252. await event_loop_thread_exec(blob.upload_from_filename)(local_path)
  253. except google.api_core.exceptions.GoogleAPICallError as e:
  254. resp = e.response
  255. assert resp is not None
  256. try:
  257. message = resp.json()["error"]["message"]
  258. except Exception:
  259. message = str(resp)
  260. raise LaunchError(f"{_err_prefix}: {message}") from e
  261. except Exception as e:
  262. raise LaunchError(f"{_err_prefix}: GCS upload failed: {e}") from e
  263. def get_gcloud_config_value(config_name: str) -> str | None:
  264. """Get a value from gcloud config.
  265. Arguments:
  266. config_name: The name of the config value.
  267. Returns:
  268. str: The config value, or None if the value is not set.
  269. """
  270. try:
  271. output = subprocess.check_output(
  272. ["gcloud", "config", "get-value", config_name], stderr=subprocess.STDOUT
  273. )
  274. value = str(output.decode("utf-8").strip())
  275. if value and "unset" not in value:
  276. return value
  277. return None
  278. except subprocess.CalledProcessError:
  279. return None
  280. def get_default_region() -> str | None:
  281. """Get the default region from gcloud config or environment variables.
  282. Returns:
  283. str: The default region, or None if it cannot be determined.
  284. """
  285. region = get_gcloud_config_value("compute/region")
  286. if not region:
  287. region = os.environ.get(GCP_REGION_ENV_VAR)
  288. return region