| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147 |
- import json
- import os
- from typing import Any
- import yaml
- from ..errors import LaunchError
FILE_OVERRIDE_ENV_VAR = "WANDB_LAUNCH_FILE_OVERRIDES"


class FileOverrides:
    """Singleton that reads file overrides JSON from environment variables.

    The JSON is taken from ``FILE_OVERRIDE_ENV_VAR``. If that variable is not
    set, it is instead assembled by concatenating the numbered chunks
    ``FILE_OVERRIDE_ENV_VAR_0``, ``FILE_OVERRIDE_ENV_VAR_1``, ... in order,
    stopping at the first missing index. The decoded value must be a JSON
    object; it is stored on the instance as ``overrides`` (``{}`` when no
    overrides are provided).
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = object.__new__(cls)
            instance.overrides = {}
            # Load before caching: if loading raises, the next call retries
            # instead of silently handing out an instance with no overrides.
            instance.load()
            cls._instance = instance
        return cls._instance

    def load(self) -> None:
        """Load overrides from environment variables into ``self.overrides``.

        Raises:
            LaunchError: If the overrides are not valid JSON or do not decode
                to a JSON object. ``self.overrides`` is left unchanged then.
        """
        overrides = os.environ.get(FILE_OVERRIDE_ENV_VAR)
        if overrides is None and f"{FILE_OVERRIDE_ENV_VAR}_0" in os.environ:
            # The payload may be split across several variables (e.g. to stay
            # under per-variable size limits); stitch the chunks back together.
            chunks = []
            idx = 0
            while f"{FILE_OVERRIDE_ENV_VAR}_{idx}" in os.environ:
                chunks.append(os.environ[f"{FILE_OVERRIDE_ENV_VAR}_{idx}"])
                idx += 1
            overrides = "".join(chunks)
        if not overrides:
            return
        try:
            contents = json.loads(overrides)
        except json.JSONDecodeError as err:
            raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}") from err
        if not isinstance(contents, dict):
            raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}")
        self.overrides = contents
def config_path_is_valid(path: str) -> None:
    """Validate a config file path.

    A valid path must meet all of the following criteria:

    - It is relative and contains no ``..`` anywhere in the string, so upward
      traversal such as ``../config.json`` is rejected.
    - It points to an existing regular file (directories are rejected).
    - It has a supported extension: ``.json``, ``.yaml`` or ``.yml``.

    Args:
        path (str): The path to validate.

    Raises:
        LaunchError: If the path is not valid.
    """
    if os.path.isabs(path):
        raise LaunchError(
            f"Invalid config path: {path}. Please provide a relative path."
        )
    # Deliberately conservative substring check: it also rejects names that
    # merely contain "..", but never lets an upward traversal through.
    if ".." in path:
        raise LaunchError(
            f"Invalid config path: {path}. Please provide a relative path "
            "without any upward path traversal, e.g. `../config.json`."
        )
    path = os.path.normpath(path)
    # isfile rather than exists: a directory named e.g. "conf.json" would
    # otherwise pass here and only fail later when it is opened for reading.
    if not os.path.isfile(path):
        raise LaunchError(f"Invalid config path: {path}. File does not exist.")
    if not path.endswith((".json", ".yaml", ".yml")):
        raise LaunchError(
            f"Invalid config path: {path}. Only JSON and YAML files are supported."
        )
def override_file(path: str) -> None:
    """Apply environment-provided overrides to the config file at ``path``.

    Looks ``path`` up in the ``FileOverrides`` singleton. If an override
    mapping is present for it, the file is read, the mapping is merged into
    its contents recursively, and the result is written back to ``path``.
    Paths with no entry (or an entry of ``null``) are left untouched.

    Args:
        path (str): Path of the config file; must match the key used in the
            overrides JSON.

    Raises:
        LaunchError: If the override for ``path`` is not a JSON object, or the
            file's parsed contents are not a mapping.
    """
    overrides = FileOverrides().overrides.get(path)
    if overrides is None:
        return
    if not isinstance(overrides, dict):
        raise LaunchError(
            f"Invalid overrides for {path}: expected a JSON object."
        )
    config = _read_config_file(path)
    if config is None:
        # An empty YAML document parses to None; treat it as an empty mapping
        # instead of crashing inside _update_dict.
        config = {}
    if not isinstance(config, dict):
        raise LaunchError(
            f"Cannot apply overrides to {path}: file does not contain a mapping."
        )
    _update_dict(config, overrides)
    _write_config_file(path, config)
- def _write_config_file(path: str, config: Any) -> None:
- """Write a config file to disk.
- Args:
- path (str): The path to the config file.
- config (Any): The contents of the config file as a Python object.
- Raises:
- LaunchError: If the file extension is not supported.
- """
- _, ext = os.path.splitext(path)
- if ext == ".json":
- with open(path, "w") as f:
- json.dump(config, f, indent=2)
- elif ext in [".yaml", ".yml"]:
- with open(path, "w") as f:
- yaml.safe_dump(config, f)
- else:
- raise LaunchError(f"Unsupported file extension: {ext}")
- def _read_config_file(path: str) -> Any:
- """Read a config file from disk.
- Args:
- path (str): The path to the config file.
- Returns:
- Any: The contents of the config file as a Python object.
- """
- _, ext = os.path.splitext(path)
- if ext == ".json":
- with open(
- path,
- ) as f:
- return json.load(f)
- elif ext in [".yaml", ".yml"]:
- with open(
- path,
- ) as f:
- return yaml.safe_load(f)
- else:
- raise LaunchError(f"Unsupported file extension: {ext}")
- def _update_dict(target: dict, source: dict) -> None:
- """Update a dictionary with the contents of another dictionary.
- Args:
- target (Dict): The dictionary to update.
- source (Dict): The dictionary to update from.
- """
- for key, value in source.items():
- if isinstance(value, dict):
- if key not in target:
- target[key] = {}
- _update_dict(target[key], value)
- else:
- target[key] = value
|