config.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from __future__ import annotations
  2. import json
  3. import os
  4. import re
  5. import warnings
  6. from typing import Any
  7. from . import files as sm_files
  8. def is_using_sagemaker() -> bool:
  9. """Returns whether we're in a SageMaker environment."""
  10. return (
  11. os.path.exists(sm_files.SM_PARAM_CONFIG) #
  12. or "SM_TRAINING_ENV" in os.environ
  13. )
  14. def parse_sm_config() -> dict[str, Any]:
  15. """Parses SageMaker configuration.
  16. Returns:
  17. A dictionary of SageMaker config keys/values
  18. or an empty dict if not found.
  19. SM_TRAINING_ENV is a json string of the
  20. training environment variables set by SageMaker
  21. and is only available when running in SageMaker,
  22. but not in local mode.
  23. SM_TRAINING_ENV is set by the SageMaker container and
  24. contains arguments such as hyperparameters
  25. and arguments passed to the training job.
  26. """
  27. conf = {}
  28. if os.path.exists(sm_files.SM_PARAM_CONFIG):
  29. conf["sagemaker_training_job_name"] = os.getenv("TRAINING_JOB_NAME")
  30. # Hyperparameter searches quote configs...
  31. with open(sm_files.SM_PARAM_CONFIG) as fid:
  32. for key, val in json.load(fid).items():
  33. cast = val.strip('"')
  34. if re.match(r"^-?[\d]+$", cast):
  35. cast = int(cast)
  36. elif re.match(r"^-?[.\d]+$", cast):
  37. cast = float(cast)
  38. conf[key] = cast
  39. if env := os.environ.get("SM_TRAINING_ENV"):
  40. try:
  41. conf.update(json.loads(env))
  42. except json.JSONDecodeError:
  43. warnings.warn(
  44. "Failed to parse SM_TRAINING_ENV not valid JSON string",
  45. stacklevel=2,
  46. )
  47. return conf