"""Handle Manager.""" from __future__ import annotations import json import logging import math import numbers import time from collections import defaultdict from collections.abc import Iterable, Sequence from queue import Queue from threading import Event from typing import TYPE_CHECKING, Any, Callable, cast from wandb.errors.links import url_registry from wandb.proto.wandb_internal_pb2 import ( HistoryRecord, InternalMessages, MetricRecord, Record, Result, RunRecord, SampledHistoryItem, SummaryItem, SummaryRecord, SummaryRecordRequest, ) from ..interface.interface_queue import InterfaceQueue from ..lib import handler_util, proto_util from . import context, sample, tb_watcher from .settings_static import SettingsStatic if TYPE_CHECKING: from wandb.proto.wandb_internal_pb2 import MetricSummary SummaryDict = dict[str, Any] logger = logging.getLogger(__name__) # Update (March 5, 2024): Since ~2020/2021, when constructing the summary # object, we had replaced the artifact path for media types with the latest # artifact path. The primary purpose of this was to support live updating of # media objects in the UI (since the default artifact path was fully qualified # and would not update). However, in March of 2024, a bug was discovered with # this approach which causes this path to be incorrect in cases where the media # object is logged to another artifact before being logged to the run. Setting # this to `False` disables this copy behavior. The impact is that users will # need to refresh to see updates. Ironically, this updating behavior is not # currently supported in the UI, so the impact of this change is minimal. 
REPLACE_SUMMARY_ART_PATH_WITH_LATEST = False


def _dict_nested_set(target: dict[str, Any], key_list: Sequence[str], v: Any) -> None:
    """Write ``v`` into ``target`` at the nested path ``key_list``.

    Missing intermediate levels are created as empty dicts. An existing
    intermediate value is reused as-is, so it must already be a dict.

    Args:
        target: Dictionary modified in place.
        key_list: Non-empty path of keys; the last element names the leaf.
        v: Value stored at the leaf.
    """
    # recurse down the dictionary structure, creating levels on demand:
    for k in key_list[:-1]:
        target = target.setdefault(k, {})
    # use the last element of the key to write the leaf:
    target[key_list[-1]] = v


class HandleManager:
    """Handle records between the record queue and the writer queue.

    Most records are forwarded unchanged to the writer queue. Along the way
    the handler maintains state derived from the stream: the consolidated
    run summary, sampled numeric history, metric definitions, the current
    step, and accumulated runtime. Some requests (status, attach, summary,
    sampled history, internal messages, ...) are answered directly on the
    result queue.
    """

    _consolidated_summary: SummaryDict
    _sampled_history: dict[str, sample.UniformSampleAccumulator]
    # history items buffered by partial-history requests until flushed
    _partial_history: dict[str, Any]
    _run_proto: RunRecord | None
    _settings: SettingsStatic
    _record_q: Queue[Record]
    _result_q: Queue[Result]
    _stopped: Event
    _writer_q: Queue[Record]
    _interface: InterfaceQueue
    _tb_watcher: tb_watcher.TBWatcher | None
    # fully named metric definitions, keyed by (escaped, dotted) metric name
    _metric_defines: dict[str, MetricRecord]
    # wildcard metric definitions, keyed by glob name
    _metric_globs: dict[str, MetricRecord]
    # running aggregates keyed by key path + suffix ("last", "max", "min",
    # "tot", "num")
    _metric_track: dict[tuple[str, ...], float]
    # last value seen for each top-level key; read back for step_sync
    _metric_copy: dict[tuple[str, ...], Any]
    # wall-clock time at which runtime tracking (re)started; None while paused
    _track_time: float | None
    # runtime accumulated before `_track_time`
    _accumulate_time: float
    _run_start_time: float | None
    _context_keeper: context.ContextKeeper
    # next step to assign to history rows that carry no explicit step
    _step: int
    # warnings returned (and cleared) by the internal_messages request
    _internal_messages: InternalMessages
    # True once the one-time non-monotonic-step warning has been queued
    _dropped_history: bool

    def __init__(
        self,
        settings: SettingsStatic,
        record_q: Queue[Record],
        result_q: Queue[Result],
        stopped: Event,
        writer_q: Queue[Record],
        interface: InterfaceQueue,
        context_keeper: context.ContextKeeper,
    ) -> None:
        """Store collaborators and initialize empty derived state.

        Args:
            settings: Static settings; ``_offline`` is consulted by this class.
            record_q: Queue that ``__next__`` reads records from.
            result_q: Queue that responses are put on.
            stopped: Event set when a shutdown request is handled.
            writer_q: Queue that dispatched records are put on.
            interface: Builds summary requests; also handed to the TB watcher.
            context_keeper: Registers incoming records and releases their
                context when a result is sent.
        """
        self._settings = settings
        self._record_q = record_q
        self._result_q = result_q
        self._stopped = stopped
        self._writer_q = writer_q
        self._interface = interface
        self._context_keeper = context_keeper

        self._tb_watcher = None
        self._step = 0

        self._track_time = None
        self._accumulate_time = 0
        self._run_start_time = None

        # keep track of summary from key/val updates
        self._consolidated_summary = {}
        self._sampled_history = defaultdict(sample.UniformSampleAccumulator)
        self._run_proto = None
        self._partial_history = {}
        self._metric_defines = defaultdict(MetricRecord)
        self._metric_globs = defaultdict(MetricRecord)
        self._metric_track = {}
        self._metric_copy = {}
        self._internal_messages = InternalMessages()

        self._dropped_history = False

    def __len__(self) -> int:
        """Return the (approximate) number of records waiting on the record queue."""
        return self._record_q.qsize()

    def handle(self, record: Record) -> None:
        """Route ``record`` to the ``handle_<record_type>`` method for its type."""
        self._context_keeper.add_from_record(record)
        record_type = record.WhichOneof("record_type")
        assert record_type
        handler_str = "handle_" + record_type
        handler: Callable[[Record], None] = getattr(self, handler_str, None)  # type: ignore
        assert handler, f"unknown handle: {handler_str}"  # type: ignore
        handler(record)

    def handle_request(self, record: Record) -> None:
        """Route a request record to ``handle_request_<request_type>``."""
        request_type = record.request.WhichOneof("request_type")
        assert request_type
        handler_str = "handle_request_" + request_type
        handler: Callable[[Record], None] = getattr(self, handler_str, None)  # type: ignore
        # network_status is kept out of the debug log (presumably because it
        # is requested often). Lazy %-args avoid formatting when debug is off.
        if request_type != "network_status":
            logger.debug("handle_request: %s", request_type)
        assert handler, f"unknown handle: {handler_str}"  # type: ignore
        handler(record)

    def _dispatch_record(self, record: Record, always_send: bool = False) -> None:
        """Forward ``record`` to the writer queue, optionally marking it always_send."""
        if always_send:
            record.control.always_send = True
        self._writer_q.put(record)

    def _respond_result(self, result: Result) -> None:
        """Release the result's context and put the result on the result queue."""
        context_id = context.context_id_from_result(result)
        self._context_keeper.release(context_id)
        self._result_q.put(result)

    def debounce(self) -> None:
        """Do nothing; kept for interface compatibility."""
        pass

    def handle_request_cancel(self, record: Record) -> None:
        """Forward a cancel request."""
        self._dispatch_record(record)

    def handle_request_defer(self, record: Record) -> None:
        """Perform this handler's part of a defer (shutdown) state, then forward it.

        FLUSH_TB stops the TensorBoard watcher, FLUSH_PARTIAL_HISTORY emits
        any buffered partial history, and FLUSH_SUM writes the consolidated
        summary as a full summary record. Other states are only forwarded.
        """
        defer = record.request.defer
        state = defer.state

        logger.info("handle defer: %s", state)
        if state == defer.FLUSH_TB:
            if self._tb_watcher:
                # shutdown tensorboard workers so we get all metrics flushed
                self._tb_watcher.finish()
                self._tb_watcher = None
        elif state == defer.FLUSH_PARTIAL_HISTORY:
            self._flush_partial_history()
        elif state == defer.FLUSH_SUM:
            self._save_summary(self._consolidated_summary, flush=True)

        # defer is used to drive the sender finish state machine
        self._dispatch_record(record, always_send=True)

    def handle_request_python_packages(self, record: Record) -> None:
        """Forward a python-packages request."""
        self._dispatch_record(record)

    def handle_run(self, record: Record) -> None:
        """Forward a run record; when offline, also cache it and respond immediately."""
        if self._settings._offline:
            self._run_proto = record.run
            result = proto_util._result_from_record(record)
            result.run_result.run.CopyFrom(record.run)
            self._respond_result(result)
        self._dispatch_record(record)

    def handle_stats(self, record: Record) -> None:
        """Forward a stats record."""
        self._dispatch_record(record)

    def handle_config(self, record: Record) -> None:
        """Forward a config record."""
        self._dispatch_record(record)

    def handle_output(self, record: Record) -> None:
        """Forward an output record."""
        self._dispatch_record(record)

    def handle_output_raw(self, record: Record) -> None:
        """Forward a raw output record."""
        self._dispatch_record(record)

    def handle_files(self, record: Record) -> None:
        """Forward a files record."""
        self._dispatch_record(record)

    def handle_request_link_artifact(self, record: Record) -> None:
        """Forward a link-artifact request."""
        self._dispatch_record(record)

    def handle_use_artifact(self, record: Record) -> None:
        """Forward a use-artifact record."""
        self._dispatch_record(record)

    def handle_artifact(self, record: Record) -> None:
        """Forward an artifact record."""
        self._dispatch_record(record)

    def handle_alert(self, record: Record) -> None:
        """Forward an alert record."""
        self._dispatch_record(record)

    def _save_summary(self, summary_dict: SummaryDict, flush: bool = False) -> None:
        """Emit ``summary_dict`` as a summary update.

        Args:
            summary_dict: Top-level keys and values; values are JSON-encoded.
            flush: If True, dispatch a full summary record. Otherwise the
                update is dispatched as a summary request, and only when
                not running offline.
        """
        summary = SummaryRecord()
        for k, v in summary_dict.items():
            update = summary.update.add()
            update.key = k
            update.value_json = json.dumps(v)
        if flush:
            record = Record(summary=summary)
            self._dispatch_record(record)
        elif not self._settings._offline:
            # Send this summary update as a request since we aren't persisting every update
            summary_record = SummaryRecordRequest(summary=summary)
            request_record = self._interface._make_request(
                summary_record=summary_record
            )
            self._dispatch_record(request_record)

    def _save_history(
        self,
        history: HistoryRecord,
    ) -> None:
        """Feed real-valued history items into their per-key sample accumulators."""
        for item in history.item:
            # TODO(jhr) save nested keys?
            k = item.key
            v = json.loads(item.value_json)
            if isinstance(v, numbers.Real):
                self._sampled_history[k].add(v)

    def _update_summary_metrics(
        self,
        s: MetricSummary,
        kl: list[str],
        v: numbers.Real,
        float_v: float,
        goal_max: bool | None,
    ) -> bool:
        """Apply the summary aggregations requested by ``s`` for one value.

        Aggregates are written under ``kl + [<aggregate>]`` in the
        consolidated summary. The running state is kept in ``_metric_track``.

        Args:
            s: Summary spec with flags none/copy/last/best/max/min/mean.
            kl: Key path of the metric.
            v: The value as logged; this is what gets stored.
            float_v: ``v`` converted to float, used for comparisons.
            goal_max: True to maximize, False/None to minimize for ``best``.

        Returns:
            True if any summary entry was written.
        """
        updated = False
        best_key: tuple[str, ...] | None = None
        if s.none:
            return False
        if s.copy and len(kl) > 1:
            # non-key list copy already done in _update_summary
            _dict_nested_set(self._consolidated_summary, kl, v)
            return True
        if s.last:
            last_key = tuple(kl + ["last"])
            old_last = self._metric_track.get(last_key)
            if old_last is None or float_v != old_last:
                self._metric_track[last_key] = float_v
                _dict_nested_set(self._consolidated_summary, last_key, v)
                updated = True
        if s.best:
            best_key = tuple(kl + ["best"])
        # the running max is tracked for an explicit "max" and for a
        # maximizing "best"
        if s.max or (best_key and goal_max):
            max_key = tuple(kl + ["max"])
            old_max = self._metric_track.get(max_key)
            if old_max is None or float_v > old_max:
                self._metric_track[max_key] = float_v
                if s.max:
                    _dict_nested_set(self._consolidated_summary, max_key, v)
                    updated = True
                if best_key:
                    _dict_nested_set(self._consolidated_summary, best_key, v)
                    updated = True
        # defaulting to minimize if goal is not specified
        if s.min or (best_key and not goal_max):
            min_key = tuple(kl + ["min"])
            old_min = self._metric_track.get(min_key)
            if old_min is None or float_v < old_min:
                self._metric_track[min_key] = float_v
                if s.min:
                    _dict_nested_set(self._consolidated_summary, min_key, v)
                    updated = True
                if best_key:
                    _dict_nested_set(self._consolidated_summary, best_key, v)
                    updated = True
        if s.mean:
            tot_key = tuple(kl + ["tot"])
            num_key = tuple(kl + ["num"])
            avg_key = tuple(kl + ["mean"])
            tot = self._metric_track.get(tot_key, 0.0)
            num = self._metric_track.get(num_key, 0)
            tot += float_v
            num += 1
            self._metric_track[tot_key] = tot
            self._metric_track[num_key] = num
            _dict_nested_set(self._consolidated_summary, avg_key, tot / num)
            updated = True
        return updated

    def _update_summary_leaf(
        self,
        kl: list[str],
        v: Any,
        d: MetricRecord | None = None,
    ) -> bool:
        """Apply summary rules for a single leaf value.

        For a top-level key whose value changed, the value is remembered in
        ``_metric_copy``. It is also copied into the summary when the metric
        has no summary spec or requests ``copy``. Otherwise, numeric non-NaN
        values of metrics with a summary spec are aggregated.

        Returns:
            True if the consolidated summary was modified.
        """
        has_summary = d and d.HasField("summary")
        if len(kl) == 1:
            copy_key = tuple(kl)
            old_copy = self._metric_copy.get(copy_key)
            if old_copy is None or v != old_copy:
                self._metric_copy[copy_key] = v
                # Store copy metric if not specified, or copy behavior
                if not has_summary or (d and d.summary.copy):
                    self._consolidated_summary[kl[0]] = v
                    return True
        if (
            not d
            or not has_summary
            or not isinstance(v, numbers.Real)
            or math.isnan(v)
        ):
            return False
        float_v = float(v)
        goal_max = None
        if d.goal:
            goal_max = d.goal == d.GOAL_MAXIMIZE
        return self._update_summary_metrics(
            d.summary, kl=kl, v=v, float_v=float_v, goal_max=goal_max
        )

    def _update_summary_list(
        self,
        kl: list[str],
        v: Any,
        d: MetricRecord | None = None,
    ) -> bool:
        """Recursively apply summary updates for value ``v`` at key path ``kl``.

        Plain dicts are descended into, and every child is processed (no
        short-circuit). Dicts recognized by ``handler_util.metric_is_wandb_dict``
        are treated as leaves. A metric defined for the exact path overrides
        the inherited definition ``d``.

        Returns:
            True if any leaf modified the consolidated summary.
        """
        metric_key = ".".join([k.replace(".", "\\.") for k in kl])
        d = self._metric_defines.get(metric_key, d)
        # if the dict has _type key, it's a wandb table object
        if isinstance(v, dict) and not handler_util.metric_is_wandb_dict(v):
            updated = False
            for nk, nv in v.items():
                if self._update_summary_list(kl=[*kl, nk], v=nv, d=d):
                    updated = True
            return updated
        # If the dict is a media object, update the pointer to the latest alias
        elif (
            REPLACE_SUMMARY_ART_PATH_WITH_LATEST
            and isinstance(v, dict)
            and handler_util.metric_is_wandb_dict(v)
        ):
            if "_latest_artifact_path" in v and "artifact_path" in v:
                # TODO: Make non-destructive?
                v["artifact_path"] = v["_latest_artifact_path"]
        return self._update_summary_leaf(kl=kl, v=v, d=d)

    def _update_summary_media_objects(self, v: dict[str, Any]) -> dict[str, Any]:
        """Point top-level media objects in ``v`` at their latest artifact path.

        Only active when ``REPLACE_SUMMARY_ART_PATH_WITH_LATEST`` is set.
        The dicts are mutated in place and ``v`` itself is returned.
        """
        if not REPLACE_SUMMARY_ART_PATH_WITH_LATEST:
            return v
        # For now, non-recursive - just top level
        for nv in v.values():
            if (
                isinstance(nv, dict)
                and handler_util.metric_is_wandb_dict(nv)
                and "_latest_artifact_path" in nv
                and "artifact_path" in nv
            ):
                # TODO: Make non-destructive?
                nv["artifact_path"] = nv["_latest_artifact_path"]
        return v

    def _update_summary(self, history_dict: dict[str, Any]) -> list[str]:
        """Fold a history row into the consolidated summary.

        Returns:
            The top-level keys whose summary entry was updated.
        """
        # keep old behavior fast path if no define metrics have been used
        if not self._metric_defines:
            history_dict = self._update_summary_media_objects(history_dict)
            self._consolidated_summary.update(history_dict)
            return list(history_dict.keys())
        updated_keys = []
        for k, v in history_dict.items():
            if self._update_summary_list(kl=[k], v=v):
                updated_keys.append(k)
        return updated_keys

    def _history_assign_step(
        self,
        history: HistoryRecord,
        history_dict: dict[str, Any],
    ) -> None:
        """Add ``_step`` to the row and advance the internal step counter.

        An explicit ``history.step`` wins and the counter continues after it.
        Otherwise the current counter value is used and then incremented.
        """
        has_step = history.HasField("step")
        item = history.item.add()
        item.key = "_step"
        if has_step:
            step = history.step.num
            history_dict["_step"] = step
            item.value_json = json.dumps(step)
            self._step = step + 1
        else:
            history_dict["_step"] = self._step
            item.value_json = json.dumps(self._step)
            self._step += 1

    def _history_define_metric(self, hkey: str) -> MetricRecord | None:
        """Check for hkey match in glob metrics and return the defined metric.

        Only globs ending in ``*`` are matched, as plain prefixes. Keys
        starting with ``_`` never match.
        """
        # Dont define metric for internal metrics
        if hkey.startswith("_"):
            return None
        for k, mglob in self._metric_globs.items():
            if k.endswith("*") and hkey.startswith(k[:-1]):
                m = MetricRecord()
                m.CopyFrom(mglob)
                m.ClearField("glob_name")
                m.options.defined = False
                m.name = hkey
                return m
        return None

    def _history_update_leaf(
        self,
        kl: list[str],
        v: Any,
        history_dict: dict[str, Any],
        update_history: dict[str, Any],
    ) -> None:
        """Resolve the metric for one history leaf and collect step-sync values.

        A metric that is not defined yet may be derived from a glob and
        registered. If the metric uses ``step_sync`` and its step metric is
        missing from this row, the last known step-metric value is added to
        ``update_history``.
        """
        hkey = ".".join([k.replace(".", "\\.") for k in kl])
        m = self._metric_defines.get(hkey)
        if not m:
            m = self._history_define_metric(hkey)
            if not m:
                return
            mr = Record()
            mr.metric.CopyFrom(m)
            mr.control.local = True  # Dont store this, just send it
            self._handle_defined_metric(mr)

        if m.options.step_sync and m.step_metric and m.step_metric not in history_dict:
            step = self._metric_copy.get((m.step_metric,))
            if step is not None:
                update_history[m.step_metric] = step

    def _history_update_list(
        self,
        kl: list[str],
        v: Any,
        history_dict: dict[str, Any],
        update_history: dict[str, Any],
    ) -> None:
        """Walk nested dicts in a history value and visit each leaf key path."""
        if isinstance(v, dict):
            for nk, nv in v.items():
                self._history_update_list(
                    kl=[*kl, nk],
                    v=nv,
                    history_dict=history_dict,
                    update_history=update_history,
                )
            return
        self._history_update_leaf(
            kl=kl, v=v, history_dict=history_dict, update_history=update_history
        )

    def _history_update(
        self,
        history: HistoryRecord,
        history_dict: dict[str, Any],
    ) -> None:
        """Assign ``_step`` if absent and append step-sync values to the row.

        Both ``history`` (the proto) and ``history_dict`` are updated.
        """
        # if syncing an old run, we can skip this logic
        if history_dict.get("_step") is None:
            self._history_assign_step(history, history_dict)

        update_history: dict[str, Any] = {}
        # Look for metric matches
        if self._metric_defines or self._metric_globs:
            for hkey, hval in history_dict.items():
                self._history_update_list([hkey], hval, history_dict, update_history)

        if update_history:
            history_dict.update(update_history)
            for k, v in update_history.items():
                item = history.item.add()
                item.key = k
                item.value_json = json.dumps(v)

    def handle_history(self, record: Record) -> None:
        """Process one complete history row.

        Injects ``_runtime`` when missing, assigns ``_step`` and step-sync
        values, forwards the record, samples numeric values, and pushes any
        resulting summary changes.
        """
        history_dict = proto_util.dict_from_proto_list(record.history.item)

        # Inject _runtime if it is not present
        if history_dict is not None and "_runtime" not in history_dict:
            self._history_assign_runtime(record.history, history_dict)

        self._history_update(record.history, history_dict)
        self._dispatch_record(record)
        self._save_history(record.history)
        # update summary from history
        updated_keys = self._update_summary(history_dict)
        if updated_keys:
            updated_items = {k: self._consolidated_summary[k] for k in updated_keys}
            self._save_summary(updated_items)

    def _flush_partial_history(
        self,
        step: int | None = None,
    ) -> None:
        """Emit buffered partial history as one history row, then clear the buffer.

        Args:
            step: Explicit step for the row; when None the next internal
                step is assigned. Nothing happens if the buffer is empty.
        """
        if not self._partial_history:
            return

        history = HistoryRecord()
        for k, v in self._partial_history.items():
            item = history.item.add()
            item.key = k
            item.value_json = json.dumps(v)
        if step is not None:
            history.step.num = step
        self.handle_history(Record(history=history))
        self._partial_history = {}

    def handle_request_sender_mark_report(self, record: Record) -> None:
        """Forward a sender-mark report request (always sent)."""
        self._dispatch_record(record, always_send=True)

    def handle_request_status_report(self, record: Record) -> None:
        """Forward a status report request (always sent)."""
        self._dispatch_record(record, always_send=True)

    def handle_request_partial_history(self, record: Record) -> None:
        """Buffer a partial history row, flushing rows as the step advances.

        - An explicit step below the current step drops the entry and queues
          warnings (a general guidance warning once, then one per entry).
        - An explicit step above the current step first flushes the pending
          buffer, then adopts the new step.
        - Without an explicit step, the row is flushed immediately unless the
          request carries an explicit flush action.
        """
        partial_history = record.request.partial_history

        flush = None
        if partial_history.HasField("action"):
            flush = partial_history.action.flush

        step = None
        if partial_history.HasField("step"):
            step = partial_history.step.num

        history_dict = proto_util.dict_from_proto_list(partial_history.item)
        if step is not None:
            if step < self._step:
                if not self._dropped_history:
                    message = (
                        "Step only supports monotonically increasing values, use define_metric to set a custom x "
                        f"axis. For details see: {url_registry.url('define-metric')}"
                    )
                    self._internal_messages.warning.append(message)
                    self._dropped_history = True
                message = (
                    f"(User provided step: {step} is less than current step: {self._step}. "
                    f"Dropping entry: {history_dict})."
                )
                self._internal_messages.warning.append(message)
                return
            elif step > self._step:
                self._flush_partial_history()
                self._step = step
        elif flush is None:
            flush = True

        self._partial_history.update(history_dict)

        if flush:
            self._flush_partial_history(self._step)

    def _summary_target(self, item: SummaryItem) -> tuple[dict[str, Any], str]:
        """Resolve a summary item to ``(parent dict, leaf key)`` in the summary.

        An item addresses its value either by ``key`` (top level) or by the
        path ``nested_key``, never both.

        Raises:
            KeyError: If an intermediate key of the path does not exist.
        """
        if len(item.nested_key) > 0:
            # we use either key or nested_key -- not both
            assert item.key == ""
            key = tuple(item.nested_key)
        else:
            # no counter-assertion here, because technically
            # summary[""] is valid
            key = (item.key,)

        target = self._consolidated_summary
        # recurse down the dictionary structure:
        for prop in key[:-1]:
            target = target[prop]
        return target, key[-1]

    def handle_summary(self, record: Record) -> None:
        """Apply explicit summary updates and removals, then push the full summary."""
        summary = record.summary
        for item in summary.update:
            target, leaf = self._summary_target(item)
            # use the last element of the key to write the leaf:
            target[leaf] = json.loads(item.value_json)

        for item in summary.remove:
            target, leaf = self._summary_target(item)
            # use the last element of the key to erase the leaf:
            del target[leaf]

        self._save_summary(self._consolidated_summary)

    def handle_exit(self, record: Record) -> None:
        """Stamp the accumulated runtime (truncated to int) on the exit record and forward it."""
        if self._track_time is not None:
            self._accumulate_time += time.time() - self._track_time
        record.exit.runtime = int(self._accumulate_time)
        self._dispatch_record(record, always_send=True)

    def handle_final(self, record: Record) -> None:
        """Forward the final record (always sent)."""
        self._dispatch_record(record, always_send=True)

    def handle_preempting(self, record: Record) -> None:
        """Forward a preempting record."""
        self._dispatch_record(record)

    def handle_header(self, record: Record) -> None:
        """Forward a header record."""
        self._dispatch_record(record)

    def handle_footer(self, record: Record) -> None:
        """Forward a footer record."""
        self._dispatch_record(record)

    def handle_metadata(self, record: Record) -> None:
        """Forward a metadata record."""
        self._dispatch_record(record)

    def handle_request_attach(self, record: Record) -> None:
        """Respond to an attach request with the cached run proto.

        Expects the run proto to have been cached already (by run_start, or
        by handle_run when offline).
        """
        result = proto_util._result_from_record(record)
        attach_id = record.request.attach.attach_id
        assert attach_id
        assert self._run_proto
        result.response.attach_response.run.CopyFrom(self._run_proto)
        self._respond_result(result)

    def handle_request_log_artifact(self, record: Record) -> None:
        """Forward a log-artifact request."""
        self._dispatch_record(record)

    def handle_telemetry(self, record: Record) -> None:
        """Forward a telemetry record."""
        self._dispatch_record(record)

    def handle_request_run_start(self, record: Record) -> None:
        """Initialize per-run state from a run-start request and respond.

        Caches the run proto, records the start time in seconds, starts
        runtime tracking (seeded from the previous runtime when resuming),
        creates the TensorBoard watcher, and adopts the starting step for
        resumed or forked runs.
        """
        run_start = record.request.run_start
        assert run_start
        assert run_start.run

        self._run_proto = run_start.run

        self._run_start_time = run_start.run.start_time.ToMicroseconds() / 1e6

        self._track_time = time.time()
        if run_start.run.resumed and run_start.run.runtime:
            self._accumulate_time = run_start.run.runtime
        else:
            self._accumulate_time = 0

        self._tb_watcher = tb_watcher.TBWatcher(
            self._settings, interface=self._interface, run_proto=run_start.run
        )

        if run_start.run.resumed or run_start.run.forked:
            self._step = run_start.run.starting_step
        result = proto_util._result_from_record(record)
        self._respond_result(result)

    def handle_request_resume(self, record: Record) -> None:
        """Resume runtime tracking, banking any time tracked so far."""
        if self._track_time is not None:
            self._accumulate_time += time.time() - self._track_time
        self._track_time = time.time()

    def handle_request_pause(self, record: Record) -> None:
        """Pause runtime tracking, banking the time tracked so far."""
        if self._track_time is not None:
            self._accumulate_time += time.time() - self._track_time
            self._track_time = None

    def handle_request_poll_exit(self, record: Record) -> None:
        """Forward a poll-exit request (always sent)."""
        self._dispatch_record(record, always_send=True)

    def handle_request_stop_status(self, record: Record) -> None:
        """Forward a stop-status request."""
        self._dispatch_record(record)

    def handle_request_network_status(self, record: Record) -> None:
        """Forward a network-status request."""
        self._dispatch_record(record)

    def handle_request_internal_messages(self, record: Record) -> None:
        """Respond with the queued internal messages and clear the queue."""
        result = proto_util._result_from_record(record)
        result.response.internal_messages_response.messages.CopyFrom(
            self._internal_messages
        )
        self._internal_messages.Clear()
        self._respond_result(result)

    def handle_request_status(self, record: Record) -> None:
        """Respond to a status request with an empty result."""
        result = proto_util._result_from_record(record)
        self._respond_result(result)

    def handle_request_get_summary(self, record: Record) -> None:
        """Respond with the top-level entries of the consolidated summary as JSON."""
        result = proto_util._result_from_record(record)
        for key, value in self._consolidated_summary.items():
            item = SummaryItem()
            item.key = key
            item.value_json = json.dumps(value)
            result.response.get_summary_response.item.append(item)
        self._respond_result(result)

    def handle_tbrecord(self, record: Record) -> None:
        """Register a TensorBoard log dir with the watcher (if any) and forward it."""
        logger.info("handling tbrecord: %s", record)
        if self._tb_watcher:
            tbrecord = record.tbrecord
            self._tb_watcher.add(tbrecord.log_dir, tbrecord.save, tbrecord.root_dir)
        self._dispatch_record(record)

    def _handle_defined_metric(self, record: Record) -> None:
        """Store or merge a named metric definition and forward it.

        If the metric references a step metric that is not defined yet, a
        bare definition for it is registered and sent first as a local-only
        record.
        """
        metric = record.metric
        if metric._control.overwrite:
            self._metric_defines[metric.name].CopyFrom(metric)
        else:
            self._metric_defines[metric.name].MergeFrom(metric)

        # before dispatching, make sure step_metric is defined, if not define it and
        # dispatch it locally first
        metric = self._metric_defines[metric.name]
        if metric.step_metric and metric.step_metric not in self._metric_defines:
            m = MetricRecord(name=metric.step_metric)
            self._metric_defines[metric.step_metric] = m
            mr = Record()
            mr.metric.CopyFrom(m)
            mr.control.local = True  # Don't store this, just send it
            self._dispatch_record(mr)

        self._dispatch_record(record)

    def _handle_glob_metric(self, record: Record) -> None:
        """Store or merge a wildcard metric definition and forward it."""
        metric = record.metric
        if metric._control.overwrite:
            self._metric_globs[metric.glob_name].CopyFrom(metric)
        else:
            self._metric_globs[metric.glob_name].MergeFrom(metric)
        self._dispatch_record(record)

    def handle_metric(self, record: Record) -> None:
        """Handle MetricRecord.

        Walkthrough of the life of a MetricRecord:

        Metric defined:
        - run.define_metric() parses arguments create wandb_metric.Metric
        - build MetricRecord publish to interface
        - handler (this function) keeps list of metrics published:
          - self._metric_defines: Fully defined metrics
          - self._metric_globs: metrics that have a wildcard
        - dispatch writer and sender thread
          - writer: records are saved to persistent store
          - sender: fully defined metrics get mapped into metadata for UI

        History logged:
        - handle_history
        - check if metric matches _metric_defines
        - if not, check if metric matches _metric_globs
        - if _metric globs match, generate defined metric and call _handle_metric

        Args:
            record (Record): Metric record to process
        """
        if record.metric.name:
            self._handle_defined_metric(record)
        elif record.metric.glob_name:
            self._handle_glob_metric(record)

    def handle_request_sampled_history(self, record: Record) -> None:
        """Respond with the sampled values for every history key.

        All-integral samples go to ``values_int`` (silently skipped if the
        proto rejects them with ValueError); otherwise all-real samples go to
        ``values_float``.
        """
        result = proto_util._result_from_record(record)
        for key, sampled in self._sampled_history.items():
            item = SampledHistoryItem()
            item.key = key
            values: Iterable[Any] = sampled.get()
            if all(isinstance(i, numbers.Integral) for i in values):
                try:
                    item.values_int.extend(values)
                except ValueError:
                    # it is safe to ignore these as this is for display information
                    pass
            elif all(isinstance(i, numbers.Real) for i in values):
                item.values_float.extend(values)
            result.response.sampled_history_response.item.append(item)
        self._respond_result(result)

    def handle_request_keepalive(self, record: Record) -> None:
        """Handle a keepalive request.

        Keepalive is a noop, we just want to verify transport is alive.
        """

    def handle_request_run_status(self, record: Record) -> None:
        """Forward a run-status request (always sent)."""
        self._dispatch_record(record, always_send=True)

    def handle_request_shutdown(self, record: Record) -> None:
        """Respond to a shutdown request, then set the stopped event."""
        # TODO(jhr): should we drain things and stop new requests from coming in?
        result = proto_util._result_from_record(record)
        self._respond_result(result)
        self._stopped.set()

    def handle_request_operations(self, record: Record) -> None:
        """Respond with an empty result; not implemented for the legacy-service."""
        self._respond_result(proto_util._result_from_record(record))

    def finish(self) -> None:
        """Shut down the TensorBoard watcher if one is running."""
        logger.info("shutting down handler")
        if self._tb_watcher:
            self._tb_watcher.finish()

    def __next__(self) -> Record:
        """Block until a record is available on the record queue and return it."""
        return self._record_q.get(block=True)

    next = __next__

    def _history_assign_runtime(
        self,
        history: HistoryRecord,
        history_dict: dict[str, Any],
    ) -> None:
        """Add ``_runtime`` (``_timestamp`` minus the run start time) to the row.

        Does nothing when the row has no ``_timestamp``.
        """
        # _runtime calculation is meaningless if there is no _timestamp
        if "_timestamp" not in history_dict:
            return
        # if it is offline sync, self._run_start_time is None
        # in that case set it to the first tfevent timestamp
        if self._run_start_time is None:
            self._run_start_time = history_dict["_timestamp"]
        history_dict["_runtime"] = history_dict["_timestamp"] - self._run_start_time
        item = history.item.add()
        item.key = "_runtime"
        item.value_json = json.dumps(history_dict[item.key])