| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- from __future__ import annotations
- import os
- import secrets
- import socket
- import string
- import wandb
- from . import config
- from . import files as sm_files
def set_run_id(run_settings: wandb.Settings) -> bool:
    """Derive the run ID and run group from the SageMaker environment.

    The group is taken from the ``TRAINING_JOB_NAME`` environment variable.
    The ID is ``<group>-<6 random lowercase alphanumerics>-<host>``, where
    the host comes from ``CURRENT_HOST`` and falls back to the machine's
    hostname.

    Returns:
        True if ``run_settings`` was updated, False if it was left untouched
        (``WANDB_RUN_ID`` is set, or ``TRAINING_JOB_NAME`` is missing/empty).
    """
    job_name = os.environ.get("TRAINING_JOB_NAME")

    # Added in https://github.com/wandb/wandb/pull/3290.
    #
    # A run ID configured through the environment must not be overridden
    # by SageMaker. Note, however, that a run ID passed explicitly to
    # `wandb.init()` is still overridden.
    if os.environ.get("WANDB_RUN_ID") or not job_name:
        return False

    charset = string.ascii_lowercase + string.digits
    suffix = "".join([secrets.choice(charset) for _ in range(6)])
    host = os.environ.get("CURRENT_HOST", socket.gethostname())

    run_settings.run_id = "-".join((job_name, suffix, host))
    run_settings.run_group = job_name
    return True
def set_global_settings(settings: wandb.Settings) -> None:
    """Apply SageMaker-provided values to the global W&B settings.

    Values parsed from the SageMaker secrets are applied first (when there
    are any), followed by the ``wandb_api_key`` entry of the SageMaker
    config (when it is present and non-empty).
    """
    secret_vars = parse_sm_secrets()
    if secret_vars:
        settings.update_from_env_vars(secret_vars)

    # An API key in the SageMaker config is applied last, so it takes
    # precedence over the value in the secrets. It's unclear whether this
    # is by design or by accident; it is kept for backward compatibility
    # for now.
    api_key = config.parse_sm_config().get("wandb_api_key")
    if api_key:
        settings.api_key = api_key
def parse_sm_secrets() -> dict[str, str]:
    """Read environment-style settings (e.g. the API key) from SageMaker secrets.

    The file at ``sm_files.SM_SECRETS`` is read line by line. Each line is
    stripped of surrounding whitespace and split on the first ``=`` into a
    key and a value, so values may themselves contain ``=``.

    Returns:
        A mapping from key to value. The mapping is empty when the secrets
        file does not exist. Blank lines and lines without ``=`` are
        ignored.
    """
    env_dict: dict[str, str] = {}

    if not os.path.exists(sm_files.SM_SECRETS):
        return env_dict

    # Use a context manager so the file handle is closed deterministically
    # rather than whenever the garbage collector gets to it.
    with open(sm_files.SM_SECRETS) as secrets_file:
        for line in secrets_file:
            key, sep, val = line.strip().partition("=")
            # No separator means a blank line (e.g. a trailing newline) or
            # a malformed entry; skip it instead of failing with ValueError
            # while unpacking the split result.
            if not sep:
                continue
            env_dict[key] = val

    return env_dict
|