sender_config.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. from __future__ import annotations
  2. import json
  3. from collections.abc import Sequence
  4. from typing import Any, NewType
  5. from wandb.proto import wandb_internal_pb2
  6. from wandb.sdk.lib import proto_util, telemetry
  7. BackendConfigDict = NewType("BackendConfigDict", dict[str, Any])
  8. """Run config dictionary in the format used by the backend."""
  9. _WANDB_INTERNAL_KEY = "_wandb"
  10. class ConfigState:
  11. """The configuration of a run."""
  12. def __init__(self, tree: dict[str, Any] | None = None) -> None:
  13. self._tree: dict[str, Any] = tree or {}
  14. """A tree with string-valued nodes and JSON leaves.
  15. Leaves are Python objects that are valid JSON values:
  16. * Primitives like strings and numbers
  17. * Dictionaries from strings to JSON objects
  18. * Lists of JSON objects
  19. """
  20. def non_internal_config(self) -> dict[str, Any]:
  21. """Returns the config settings minus "_wandb"."""
  22. return {k: v for k, v in self._tree.items() if k != _WANDB_INTERNAL_KEY}
  23. def update_from_proto(
  24. self,
  25. config_record: wandb_internal_pb2.ConfigRecord,
  26. ) -> None:
  27. """Applies update and remove commands."""
  28. for config_item in config_record.update:
  29. self._update_at_path(
  30. _key_path(config_item),
  31. json.loads(config_item.value_json),
  32. )
  33. for config_item in config_record.remove:
  34. self._delete_at_path(_key_path(config_item))
  35. def merge_resumed_config(self, old_config_tree: dict[str, Any]) -> None:
  36. """Merges the config from a run that's being resumed."""
  37. # Add any top-level keys that aren't already set.
  38. self._add_unset_keys_from_subtree(old_config_tree, [])
  39. # When resuming a run, we want to ensure the some of the old configs keys
  40. # are maintained. So we have this logic here to add back
  41. # any keys that were in the old config but not in the new config
  42. for key in ["viz", "visualize", "mask/class_labels"]:
  43. self._add_unset_keys_from_subtree(
  44. old_config_tree,
  45. [_WANDB_INTERNAL_KEY, key],
  46. )
  47. def _add_unset_keys_from_subtree(
  48. self,
  49. old_config_tree: dict[str, Any],
  50. path: Sequence[str],
  51. ) -> None:
  52. """Uses the given subtree for keys that aren't already set."""
  53. old_subtree = _subtree(old_config_tree, path, create=False)
  54. if not old_subtree:
  55. return
  56. new_subtree = _subtree(self._tree, path, create=True)
  57. assert new_subtree is not None
  58. for key, value in old_subtree.items():
  59. if key not in new_subtree:
  60. new_subtree[key] = value
  61. def to_backend_dict(
  62. self,
  63. telemetry_record: telemetry.TelemetryRecord,
  64. framework: str | None,
  65. start_time_millis: int,
  66. metric_pbdicts: Sequence[dict[int, Any]],
  67. environment_record: wandb_internal_pb2.EnvironmentRecord,
  68. ) -> BackendConfigDict:
  69. """Returns a dictionary representation expected by the backend.
  70. The backend expects the configuration in a specific format, and the
  71. config is also used to store additional metadata about the run.
  72. Args:
  73. telemetry_record: Telemetry information to insert.
  74. framework: The detected framework used in the run (e.g. TensorFlow).
  75. start_time_millis: The run's start time in Unix milliseconds.
  76. metric_pbdicts: List of dict representations of metric protobuffers.
  77. """
  78. backend_dict = self._tree.copy()
  79. wandb_internal = backend_dict.setdefault(_WANDB_INTERNAL_KEY, {})
  80. ###################################################
  81. # Telemetry information
  82. ###################################################
  83. py_version = telemetry_record.python_version
  84. if py_version:
  85. wandb_internal["python_version"] = py_version
  86. cli_version = telemetry_record.cli_version
  87. if cli_version:
  88. wandb_internal["cli_version"] = cli_version
  89. if framework:
  90. wandb_internal["framework"] = framework
  91. huggingface_version = telemetry_record.huggingface_version
  92. if huggingface_version:
  93. wandb_internal["huggingface_version"] = huggingface_version
  94. wandb_internal["is_jupyter_run"] = telemetry_record.env.jupyter
  95. wandb_internal["is_kaggle_kernel"] = telemetry_record.env.kaggle
  96. wandb_internal["start_time"] = start_time_millis
  97. # The full telemetry record.
  98. wandb_internal["t"] = proto_util.proto_encode_to_dict(telemetry_record)
  99. ###################################################
  100. # Metrics
  101. ###################################################
  102. if metric_pbdicts:
  103. wandb_internal["m"] = metric_pbdicts
  104. ###################################################
  105. # Environment
  106. ###################################################
  107. writer_id = environment_record.writer_id
  108. if writer_id:
  109. environment_dict = proto_util.message_to_dict(environment_record)
  110. wandb_internal["e"] = {writer_id: environment_dict}
  111. return BackendConfigDict(
  112. {
  113. key: {
  114. # Configurations can be stored in a hand-written YAML file,
  115. # and users can add descriptions to their hyperparameters
  116. # there. However, we don't support a way to set descriptions
  117. # via code, so this is always None.
  118. "desc": None,
  119. "value": value,
  120. }
  121. for key, value in self._tree.items()
  122. }
  123. )
  124. def _update_at_path(
  125. self,
  126. key_path: Sequence[str],
  127. value: Any,
  128. ) -> None:
  129. """Sets the value at the path in the config tree."""
  130. subtree = _subtree(self._tree, key_path[:-1], create=True)
  131. assert subtree is not None
  132. subtree[key_path[-1]] = value
  133. def _delete_at_path(
  134. self,
  135. key_path: Sequence[str],
  136. ) -> None:
  137. """Removes the subtree at the path in the config tree."""
  138. subtree = _subtree(self._tree, key_path[:-1], create=False)
  139. if subtree:
  140. del subtree[key_path[-1]]
  141. def _key_path(config_item: wandb_internal_pb2.ConfigItem) -> Sequence[str]:
  142. """Returns the key path referenced by the config item."""
  143. if config_item.nested_key:
  144. return config_item.nested_key
  145. elif config_item.key:
  146. return [config_item.key]
  147. else:
  148. raise AssertionError(
  149. "Invalid ConfigItem: either key or nested_key must be set",
  150. )
  151. def _subtree(
  152. tree: dict[str, Any],
  153. key_path: Sequence[str],
  154. *,
  155. create: bool = False,
  156. ) -> dict[str, Any] | None:
  157. """Returns a subtree at the given path."""
  158. for key in key_path:
  159. subtree = tree.get(key)
  160. if not subtree:
  161. if create:
  162. subtree = {}
  163. tree[key] = subtree
  164. else:
  165. return None
  166. tree = subtree
  167. return tree