files.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import json
  2. import os
  3. from typing import Any
  4. import yaml
  5. from ..errors import LaunchError
  6. FILE_OVERRIDE_ENV_VAR = "WANDB_LAUNCH_FILE_OVERRIDES"
  7. class FileOverrides:
  8. """Singleton that read file overrides json from environment variables."""
  9. _instance = None
  10. def __new__(cls):
  11. if cls._instance is None:
  12. cls._instance = object.__new__(cls)
  13. cls._instance.overrides = {}
  14. cls._instance.load()
  15. return cls._instance
  16. def load(self) -> None:
  17. """Load overrides from an environment variable."""
  18. overrides = os.environ.get(FILE_OVERRIDE_ENV_VAR)
  19. if overrides is None and f"{FILE_OVERRIDE_ENV_VAR}_0" in os.environ:
  20. overrides = ""
  21. idx = 0
  22. while f"{FILE_OVERRIDE_ENV_VAR}_{idx}" in os.environ:
  23. overrides += os.environ[f"{FILE_OVERRIDE_ENV_VAR}_{idx}"]
  24. idx += 1
  25. if overrides:
  26. try:
  27. contents = json.loads(overrides)
  28. if not isinstance(contents, dict):
  29. raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}")
  30. self.overrides = contents
  31. except json.JSONDecodeError:
  32. raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}")
  33. def config_path_is_valid(path: str) -> None:
  34. """Validate a config file path.
  35. This function checks if a given config file path is valid. A valid path
  36. should meet the following criteria:
  37. - The path must be expressed as a relative path without any upwards path
  38. traversal, e.g. `../config.json`.
  39. - The file specified by the path must exist.
  40. - The file must have a supported extension (`.json`, `.yaml`, or `.yml`).
  41. Args:
  42. path (str): The path to validate.
  43. Raises:
  44. LaunchError: If the path is not valid.
  45. """
  46. if os.path.isabs(path):
  47. raise LaunchError(
  48. f"Invalid config path: {path}. Please provide a relative path."
  49. )
  50. if ".." in path:
  51. raise LaunchError(
  52. f"Invalid config path: {path}. Please provide a relative path "
  53. "without any upward path traversal, e.g. `../config.json`."
  54. )
  55. path = os.path.normpath(path)
  56. if not os.path.exists(path):
  57. raise LaunchError(f"Invalid config path: {path}. File does not exist.")
  58. if not any(path.endswith(ext) for ext in [".json", ".yaml", ".yml"]):
  59. raise LaunchError(
  60. f"Invalid config path: {path}. Only JSON and YAML files are supported."
  61. )
  62. def override_file(path: str) -> None:
  63. """Check for file overrides in the environment and apply them if found."""
  64. file_overrides = FileOverrides()
  65. if path in file_overrides.overrides:
  66. overrides = file_overrides.overrides.get(path)
  67. if overrides is not None:
  68. config = _read_config_file(path)
  69. _update_dict(config, overrides)
  70. _write_config_file(path, config)
  71. def _write_config_file(path: str, config: Any) -> None:
  72. """Write a config file to disk.
  73. Args:
  74. path (str): The path to the config file.
  75. config (Any): The contents of the config file as a Python object.
  76. Raises:
  77. LaunchError: If the file extension is not supported.
  78. """
  79. _, ext = os.path.splitext(path)
  80. if ext == ".json":
  81. with open(path, "w") as f:
  82. json.dump(config, f, indent=2)
  83. elif ext in [".yaml", ".yml"]:
  84. with open(path, "w") as f:
  85. yaml.safe_dump(config, f)
  86. else:
  87. raise LaunchError(f"Unsupported file extension: {ext}")
  88. def _read_config_file(path: str) -> Any:
  89. """Read a config file from disk.
  90. Args:
  91. path (str): The path to the config file.
  92. Returns:
  93. Any: The contents of the config file as a Python object.
  94. """
  95. _, ext = os.path.splitext(path)
  96. if ext == ".json":
  97. with open(
  98. path,
  99. ) as f:
  100. return json.load(f)
  101. elif ext in [".yaml", ".yml"]:
  102. with open(
  103. path,
  104. ) as f:
  105. return yaml.safe_load(f)
  106. else:
  107. raise LaunchError(f"Unsupported file extension: {ext}")
  108. def _update_dict(target: dict, source: dict) -> None:
  109. """Update a dictionary with the contents of another dictionary.
  110. Args:
  111. target (Dict): The dictionary to update.
  112. source (Dict): The dictionary to update from.
  113. """
  114. for key, value in source.items():
  115. if isinstance(value, dict):
  116. if key not in target:
  117. target[key] = {}
  118. _update_dict(target[key], value)
  119. else:
  120. target[key] = value