resources.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from __future__ import annotations
  2. import os
  3. import secrets
  4. import socket
  5. import string
  6. import wandb
  7. from . import config
  8. from . import files as sm_files
  9. def set_run_id(run_settings: wandb.Settings) -> bool:
  10. """Set a run ID and group when using SageMaker.
  11. Returns whether the ID and group were updated.
  12. """
  13. # Added in https://github.com/wandb/wandb/pull/3290.
  14. #
  15. # Prevents SageMaker from overriding the run ID configured
  16. # in environment variables. Note, however, that it will still
  17. # override a run ID passed explicitly to `wandb.init()`.
  18. if os.getenv("WANDB_RUN_ID"):
  19. return False
  20. run_group = os.getenv("TRAINING_JOB_NAME")
  21. if not run_group:
  22. return False
  23. alphanumeric = string.ascii_lowercase + string.digits
  24. random = "".join(secrets.choice(alphanumeric) for _ in range(6))
  25. host = os.getenv("CURRENT_HOST", socket.gethostname())
  26. run_settings.run_id = f"{run_group}-{random}-{host}"
  27. run_settings.run_group = run_group
  28. return True
  29. def set_global_settings(settings: wandb.Settings) -> None:
  30. """Set global W&B settings based on the SageMaker environment."""
  31. if env := parse_sm_secrets():
  32. settings.update_from_env_vars(env)
  33. # The SageMaker config may contain an API key, in which case it
  34. # takes precedence over the value in the secrets. It's unclear
  35. # whether this is by design, or by accident; we keep it for
  36. # backward compatibility for now.
  37. sm_config = config.parse_sm_config()
  38. if api_key := sm_config.get("wandb_api_key"):
  39. settings.api_key = api_key
  40. def parse_sm_secrets() -> dict[str, str]:
  41. """We read our api_key from secrets.env in SageMaker."""
  42. env_dict = dict()
  43. # Set secret variables
  44. if os.path.exists(sm_files.SM_SECRETS):
  45. for line in open(sm_files.SM_SECRETS):
  46. key, val = line.strip().split("=", 1)
  47. env_dict[key] = val
  48. return env_dict