config_manager.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. """Manager to read and modify config data in JSON files."""
  2. # Copyright (c) Jupyter Development Team.
  3. # Distributed under the terms of the Modified BSD License.
  4. from __future__ import annotations
  5. import copy
  6. import errno
  7. import glob
  8. import json
  9. import os
  10. import typing as t
  11. from traitlets.config import LoggingConfigurable
  12. from traitlets.traitlets import Bool, Unicode
  13. StrDict = dict[str, t.Any]
  14. def recursive_update(target: StrDict, new: StrDict) -> None:
  15. """Recursively update one dictionary using another.
  16. None values will delete their keys.
  17. """
  18. for k, v in new.items():
  19. if isinstance(v, dict):
  20. if k not in target:
  21. target[k] = {}
  22. recursive_update(target[k], v)
  23. if not target[k]:
  24. # Prune empty subdicts
  25. del target[k]
  26. elif v is None:
  27. target.pop(k, None)
  28. else:
  29. target[k] = v
  30. def remove_defaults(data: StrDict, defaults: StrDict) -> None:
  31. """Recursively remove items from dict that are already in defaults"""
  32. # copy the iterator, since data will be modified
  33. for key, value in list(data.items()):
  34. if key in defaults:
  35. if isinstance(value, dict):
  36. remove_defaults(data[key], defaults[key])
  37. if not data[key]: # prune empty subdicts
  38. del data[key]
  39. elif value == defaults[key]:
  40. del data[key]
  41. class BaseJSONConfigManager(LoggingConfigurable):
  42. """General JSON config manager
  43. Deals with persisting/storing config in a json file with optionally
  44. default values in a {section_name}.d directory.
  45. """
  46. config_dir = Unicode(".")
  47. read_directory = Bool(True)
  48. def ensure_config_dir_exists(self) -> None:
  49. """Will try to create the config_dir directory."""
  50. try:
  51. os.makedirs(self.config_dir, 0o755)
  52. except OSError as e:
  53. if e.errno != errno.EEXIST:
  54. raise
  55. def file_name(self, section_name: str) -> str:
  56. """Returns the json filename for the section_name: {config_dir}/{section_name}.json"""
  57. return os.path.join(self.config_dir, section_name + ".json")
  58. def directory(self, section_name: str) -> str:
  59. """Returns the directory name for the section name: {config_dir}/{section_name}.d"""
  60. return os.path.join(self.config_dir, section_name + ".d")
  61. def get(self, section_name: str, include_root: bool = True) -> dict[str, t.Any]:
  62. """Retrieve the config data for the specified section.
  63. Returns the data as a dictionary, or an empty dictionary if the file
  64. doesn't exist.
  65. When include_root is False, it will not read the root .json file,
  66. effectively returning the default values.
  67. """
  68. paths = [self.file_name(section_name)] if include_root else []
  69. if self.read_directory:
  70. pattern = os.path.join(self.directory(section_name), "*.json")
  71. # These json files should be processed first so that the
  72. # {section_name}.json take precedence.
  73. # The idea behind this is that installing a Python package may
  74. # put a json file somewhere in the a .d directory, while the
  75. # .json file is probably a user configuration.
  76. paths = sorted(glob.glob(pattern)) + paths
  77. self.log.debug(
  78. "Paths used for configuration of %s: \n\t%s",
  79. section_name,
  80. "\n\t".join(paths),
  81. )
  82. data: dict[str, t.Any] = {}
  83. for path in paths:
  84. if os.path.isfile(path) and os.path.getsize(path):
  85. with open(path, encoding="utf-8") as f:
  86. try:
  87. recursive_update(data, json.load(f))
  88. except json.decoder.JSONDecodeError:
  89. self.log.warning("Invalid JSON in %s, skipping", path)
  90. return data
  91. def set(self, section_name: str, data: t.Any) -> None:
  92. """Store the given config data."""
  93. filename = self.file_name(section_name)
  94. self.ensure_config_dir_exists()
  95. if self.read_directory:
  96. # we will modify data in place, so make a copy
  97. data = copy.deepcopy(data)
  98. defaults = self.get(section_name, include_root=False)
  99. remove_defaults(data, defaults)
  100. # Generate the JSON up front, since it could raise an exception,
  101. # in order to avoid writing half-finished corrupted data to disk.
  102. json_content = json.dumps(data, indent=2)
  103. with open(filename, "w", encoding="utf-8") as f:
  104. f.write(json_content)
  105. def update(self, section_name: str, new_data: t.Any) -> dict[str, t.Any]:
  106. """Modify the config section by recursively updating it with new_data.
  107. Returns the modified config data as a dictionary.
  108. """
  109. data = self.get(section_name)
  110. recursive_update(data, new_data)
  111. self.set(section_name, data)
  112. return data