| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- from __future__ import annotations
- import json
- import os
- import re
- import warnings
- from typing import Any
- from . import files as sm_files
- def is_using_sagemaker() -> bool:
- """Returns whether we're in a SageMaker environment."""
- return (
- os.path.exists(sm_files.SM_PARAM_CONFIG) #
- or "SM_TRAINING_ENV" in os.environ
- )
- def parse_sm_config() -> dict[str, Any]:
- """Parses SageMaker configuration.
- Returns:
- A dictionary of SageMaker config keys/values
- or an empty dict if not found.
- SM_TRAINING_ENV is a json string of the
- training environment variables set by SageMaker
- and is only available when running in SageMaker,
- but not in local mode.
- SM_TRAINING_ENV is set by the SageMaker container and
- contains arguments such as hyperparameters
- and arguments passed to the training job.
- """
- conf = {}
- if os.path.exists(sm_files.SM_PARAM_CONFIG):
- conf["sagemaker_training_job_name"] = os.getenv("TRAINING_JOB_NAME")
- # Hyperparameter searches quote configs...
- with open(sm_files.SM_PARAM_CONFIG) as fid:
- for key, val in json.load(fid).items():
- cast = val.strip('"')
- if re.match(r"^-?[\d]+$", cast):
- cast = int(cast)
- elif re.match(r"^-?[.\d]+$", cast):
- cast = float(cast)
- conf[key] = cast
- if env := os.environ.get("SM_TRAINING_ENV"):
- try:
- conf.update(json.loads(env))
- except json.JSONDecodeError:
- warnings.warn(
- "Failed to parse SM_TRAINING_ENV not valid JSON string",
- stacklevel=2,
- )
- return conf
|