| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520 |
- """tensorboard watcher."""
- from __future__ import annotations
- import glob
- import logging
- import os
- import queue
- import socket
- import sys
- import threading
- import time
- from typing import TYPE_CHECKING, Any
- import wandb
- from wandb import util
- from wandb.plot import CustomChart
- from wandb.sdk.lib import filesystem
- from . import run as internal_run
- if TYPE_CHECKING:
- from queue import PriorityQueue
- from tensorboard.backend.event_processing.event_file_loader import EventFileLoader
- from tensorboard.compat.proto.event_pb2 import ProtoEvent
- from wandb.proto.wandb_internal_pb2 import RunRecord
- from wandb.sdk.lib.filesystem import FilesDict
- from ..interface.interface_queue import InterfaceQueue
- from .settings_static import SettingsStatic
# One history row: string keys mapped to arbitrary logged values.
HistoryDict = dict[str, Any]
# Give some time for tensorboard data to be flushed
# (seconds a directory watcher thread keeps polling once shutdown was requested).
SHUTDOWN_DELAY = 5
# Seconds a directory watcher sleeps after a recoverable load error.
ERROR_DELAY = 5
# Paths containing this token are treated as remote and are not linked into
# the run's files directory.
REMOTE_FILE_TOKEN = "://"
logger = logging.getLogger(__name__)
def _link_and_save_file(
    path: str, base_path: str, interface: InterfaceQueue, settings: SettingsStatic
) -> None:
    """Symlink ``path`` into the run's files directory and publish it.

    The link lives at ``settings.files_dir`` joined with ``path`` taken
    relative to ``base_path``; it always points at the absolute form of
    ``path``.

    Args:
        path: File to expose through the files directory.
        base_path: Directory that ``path`` is made relative to.
        interface: Receives the ``publish_files`` call for the linked name.
        settings: Provides ``files_dir``.
    """
    # TODO(jhr): should this logic be merged with Run.save()
    files_dir = settings.files_dir
    relative_name = os.path.relpath(path, base_path)
    target = os.path.abspath(path)
    link_path = os.path.join(files_dir, relative_name)
    filesystem.mkdir_exists_ok(os.path.dirname(link_path))

    # We overwrite existing symlinks because namespaces can change in Tensorboard
    stale_link = os.path.islink(link_path) and target != os.readlink(link_path)
    if stale_link:
        os.remove(link_path)
    # A link is (re)created when the old one was stale or nothing exists yet;
    # an up-to-date link or a regular file at that location is left alone.
    if stale_link or not os.path.exists(link_path):
        os.symlink(target, link_path)

    # TODO(jhr): need to figure out policy, live/throttled?
    escaped_name = filesystem.GlobStr(glob.escape(relative_name))
    interface.publish_files({"files": [(escaped_name, "live")]})
def is_tfevents_file_created_by(
    path: str, hostname: str | None, start_time: float | None
) -> bool:
    """Tell whether ``path`` names a tfevents file.

    When ``hostname`` is given, the dot-separated filename components that
    follow the timestamp must match the hostname's components. When
    ``start_time`` is given, the integer timestamp component must not be
    older than ``int(start_time)``.

    tensorboard tfevents filename format:
    https://github.com/tensorflow/tensorboard/blob/f3f26b46981da5bd46a5bb93fcf02d9eb7608bc1/tensorboard/summary/writer/event_file_writer.py#L81
    tensorflow tfevents filename format:
    https://github.com/tensorflow/tensorflow/blob/8f597046dc30c14b5413813d02c0e0aed399c177/tensorflow/core/util/events_writer.cc#L68

    Raises:
        ValueError: If ``path`` is empty.
    """
    if not path:
        raise ValueError("Path must be a nonempty string")

    name = os.path.basename(path)
    if name.endswith(".profile-empty") or name.endswith(".sagemaker-uploaded"):
        return False

    pieces = name.split(".")
    if "tfevents" not in pieces:
        return False
    marker = pieces.index("tfevents")

    # The hostname may itself contain dots, so it can span several components
    # starting two positions after the "tfevents" marker. A slice that runs
    # off the end is shorter than the hostname and therefore never equal.
    if hostname is not None:
        host_parts = hostname.split(".")
        host_start = marker + 2
        if pieces[host_start : host_start + len(host_parts)] != host_parts:
            return False

    if start_time is not None:
        stamp_at = marker + 1
        if stamp_at >= len(pieces):
            return False
        try:
            created_time = int(pieces[stamp_at])
        except ValueError:
            return False
        # Reject files whose timestamp is older than our start time.
        # TODO: we should also check the PID (also contained in the tfevents
        # filename). Can we assume that our parent pid is the user process
        # that wrote these files?
        if created_time < int(start_time):
            return False

    return True
class TBWatcher:
    """Own the per-directory watchers and the single shared event consumer.

    Every directory passed to ``add`` gets its own ``TBDirWatcher``; all of
    them feed one priority queue that a lazily created ``TBEventConsumer``
    reads from.
    """

    _logdirs: dict[str, TBDirWatcher]
    _watcher_queue: PriorityQueue

    def __init__(
        self,
        settings: SettingsStatic,
        run_proto: RunRecord,
        interface: InterfaceQueue,
        force: bool = False,
    ) -> None:
        self._settings = settings
        self._run_proto = run_proto
        self._interface = interface
        self._force = force
        self._logdirs = {}
        # Created on the first call to add().
        self._consumer: TBEventConsumer | None = None
        # TODO(jhr): do we need locking in this queue?
        self._watcher_queue = queue.PriorityQueue()
        wandb.tensorboard.reset_state()  # type: ignore

    def _calculate_namespace(self, logdir: str, rootdir: str) -> str | None:
        """Derive the namespace used to nest values coming from ``logdir``.

        With an explicit ``rootdir`` the namespace is ``logdir`` with the
        root (and, if ``logdir`` is a file, its basename) removed. With an
        empty ``rootdir`` the root is inferred from the common prefix of all
        known log directories plus this one.
        """
        known_dirs = [*self._logdirs, logdir]
        filename = os.path.basename(logdir) if os.path.isfile(logdir) else ""

        root_was_inferred = rootdir == ""
        if root_was_inferred:
            rootdir = util.to_forward_slash_path(
                os.path.dirname(os.path.commonprefix(known_dirs))
            )

        # Tensorboard loads all tfevents files in a directory and prepends
        # their values with the path. Passing namespace to log allows us
        # to nest the values in wandb
        # Note that we strip '/' instead of os.sep, because elsewhere we've
        # converted paths to forward slash.
        namespace = logdir.replace(filename, "").replace(rootdir, "").strip("/")

        # TODO: revisit this heuristic, it exists because we don't know the
        # root log directory until more than one tfevents file is written to
        lone_dir = len(known_dirs) == 1
        if root_was_inferred and lone_dir and namespace not in ["train", "validation"]:
            return None
        return namespace

    def add(self, logdir: str, save: bool, root_dir: str) -> None:
        """Start watching ``logdir``; a directory already watched is ignored."""
        logdir = util.to_forward_slash_path(logdir)
        root_dir = util.to_forward_slash_path(root_dir)
        if logdir in self._logdirs:
            return

        namespace = self._calculate_namespace(logdir, root_dir)
        # TODO(jhr): implement the deferred tbdirwatcher to find namespace

        if not self._consumer:
            self._consumer = TBEventConsumer(
                self, self._watcher_queue, self._run_proto, self._settings
            )
            self._consumer.start()

        dir_watcher = TBDirWatcher(
            self, logdir, save, namespace, self._watcher_queue, self._force
        )
        self._logdirs[logdir] = dir_watcher
        dir_watcher.start()

    def finish(self) -> None:
        """Shut down every directory watcher, then finish the consumer."""
        dir_watchers = self._logdirs.values()
        # Signal all watchers first so they wind down together, then join
        # each one.
        for dir_watcher in dir_watchers:
            dir_watcher.shutdown()
        for dir_watcher in dir_watchers:
            dir_watcher.finish()
        if self._consumer:
            self._consumer.finish()
class TBDirWatcher:
    """Poll one tensorboard log directory from a background thread.

    Events are read through tensorboard's ``DirectoryWatcher``; those that
    carry a summary are wrapped in ``Event`` and put on the shared queue.
    """

    def __init__(
        self,
        tbwatcher: TBWatcher,
        logdir: str,
        save: bool,
        namespace: str | None,
        queue: PriorityQueue,
        force: bool = False,
    ) -> None:
        self.directory_watcher = util.get_module(
            "tensorboard.backend.event_processing.directory_watcher",
            required="Please install tensorboard package",
        )
        self.tf_compat = util.get_module(
            "tensorboard.compat", required="Please install tensorboard package"
        )
        # _loader() reads the interface and settings off the parent watcher,
        # so the parent must be stored before the generator is built.
        self._tbwatcher = tbwatcher
        self._generator = self.directory_watcher.DirectoryWatcher(
            logdir, self._loader(save, namespace), self._is_our_tfevents_file
        )
        self._thread = threading.Thread(target=self._thread_except_body)
        self._first_event_timestamp = None
        self._shutdown = threading.Event()
        self._queue = queue
        self._file_version = None
        self._namespace = namespace
        self._logdir = logdir
        self._hostname = socket.gethostname()
        self._force = force
        self._process_events_lock = threading.Lock()

    def start(self) -> None:
        """Start the polling thread."""
        self._thread.start()

    def _is_our_tfevents_file(self, path: str) -> bool:
        """Check if a path has been modified since launch and contains tfevents."""
        if not path:
            raise ValueError("Path must be a nonempty string")
        path = self.tf_compat.tf.compat.as_str_any(path)
        if self._force:
            # Forced mode accepts any tfevents file, whoever wrote it and
            # whenever it was written.
            hostname, start_time = None, None
        else:
            hostname = self._hostname
            start_time = self._tbwatcher._settings.x_start_time
        return is_tfevents_file_created_by(path, hostname, start_time)

    def _loader(
        self, save: bool = True, namespace: str | None = None
    ) -> EventFileLoader:
        """Incredibly hacky class generator to optionally save / prefix tfevent files."""
        _loader_interface = self._tbwatcher._interface
        _loader_settings = self._tbwatcher._settings

        try:
            from tensorboard.backend.event_processing import event_file_loader
        except ImportError:
            raise Exception("Please install tensorboard package")

        class EventFileLoader(event_file_loader.EventFileLoader):
            def __init__(self, file_path: str) -> None:
                super().__init__(file_path)
                if not save:
                    return
                if REMOTE_FILE_TOKEN in file_path:
                    logger.warning(
                        "Not persisting remote tfevent file: %s", file_path
                    )
                    return
                # TODO: save plugins?
                # When the file's directory is named after the namespace,
                # drop that last component so the namespace directory is
                # kept in the linked file's relative name.
                segments = list(os.path.split(os.path.dirname(file_path)))
                if namespace and segments[-1] == namespace:
                    del segments[-1]
                _link_and_save_file(
                    path=file_path,
                    base_path=os.path.join(*segments),
                    interface=_loader_interface,
                    settings=_loader_settings,
                )

        return EventFileLoader

    def _process_events(self, shutdown_call: bool = False) -> None:
        """Load every available event once, tolerating directory errors.

        After a recoverable error the caller is delayed by ``ERROR_DELAY``
        seconds, unless shutdown has been requested or this is the shutdown
        call itself.
        """
        try:
            with self._process_events_lock:
                for loaded in self._generator.Load():
                    self.process_event(loaded)
        except (
            self.directory_watcher.DirectoryDeletedError,
            StopIteration,
            RuntimeError,
            OSError,
        ) as err:
            # When listing s3 the directory may not yet exist, or could be empty
            logger.debug("Encountered tensorboard directory watcher error: %s", err)
            if not (self._shutdown.is_set() or shutdown_call):
                time.sleep(ERROR_DELAY)

    def _thread_except_body(self) -> None:
        """Thread entry point: log any exception from the loop, then re-raise."""
        try:
            self._thread_body()
        except Exception:
            logger.exception("generic exception in TBDirWatcher thread")
            raise

    def _thread_body(self) -> None:
        """Check for new events every second."""
        deadline: float | None = None
        while True:
            self._process_events()
            if self._shutdown.is_set():
                now = time.time()
                if deadline and now > deadline:
                    break
                if not deadline:
                    # First pass after shutdown: keep polling for a grace
                    # period so late writes are still picked up.
                    deadline = now + SHUTDOWN_DELAY
            time.sleep(1)

    def process_event(self, event: ProtoEvent) -> None:
        """Record bookkeeping fields and enqueue events carrying a summary."""
        if self._first_event_timestamp is None:
            self._first_event_timestamp = event.wall_time

        if event.HasField("file_version"):
            self._file_version = event.file_version

        if event.HasField("summary"):
            self._queue.put(Event(event, self._namespace))

    def shutdown(self) -> None:
        """Process pending events one more time, then request the thread to stop."""
        self._process_events(shutdown_call=True)
        self._shutdown.set()

    def finish(self) -> None:
        """Shut down and wait for the polling thread to exit."""
        self.shutdown()
        self._thread.join()
class Event:
    """An event wrapper to enable priority queueing.

    Instances compare by the wrapped event's ``wall_time``, so a priority
    queue hands out the earliest event first.
    """

    def __init__(self, event: ProtoEvent, namespace: str | None):
        self.event, self.namespace = event, namespace
        # Moment this wrapper was built, not the event's own wall time.
        self.created_at = time.time()

    def __lt__(self, other: Event) -> bool:
        mine = self.event.wall_time
        theirs = other.event.wall_time
        return mine < theirs
class TBEventConsumer:
    """Consume tfevents from a priority queue.

    There should always only be one of these per run_manager. Events are held
    back for ``delay`` seconds (10 by default) after the consumer starts, to
    reduce the chance of multiple tfevent files triggering out of order steps.
    """

    def __init__(
        self,
        tbwatcher: TBWatcher,
        queue: PriorityQueue,
        run_proto: RunRecord,
        settings: SettingsStatic,
        delay: int = 10,
    ) -> None:
        self._tbwatcher = tbwatcher
        self._queue = queue
        self._thread = threading.Thread(target=self._thread_except_body)
        self._shutdown = threading.Event()
        self.tb_history = TBHistory()
        self._delay = delay

        # This is a bit of a hack to get file saving to work as it does in the user
        # process. Since we don't have a real run object, we have to define the
        # datatypes callback ourselves.
        def datatypes_cb(fname: filesystem.GlobStr) -> None:
            files: FilesDict = {"files": [(fname, "now")]}
            self._tbwatcher._interface.publish_files(files)

        # this is only used for logging artifacts
        self._internal_run = internal_run.InternalRun(run_proto, settings, datatypes_cb)
        self._internal_run._set_internal_run_interface(self._tbwatcher._interface)

    def start(self) -> None:
        """Record the start time (the delay is measured from it) and start the thread."""
        self._start_time = time.time()
        self._thread.start()

    def finish(self) -> None:
        """Stop the thread, then handle whatever is still queued."""
        # Dropping the delay lets the thread stop holding events back.
        self._delay = 0
        self._shutdown.set()
        self._thread.join()
        while not self._queue.empty():
            leftover = self._queue.get(True, 1)
            if leftover:
                self._consume(leftover)

    def _consume(self, event: Event) -> None:
        """Handle one queued event and publish every row it completed."""
        self._handle_event(event, history=self.tb_history)
        self._publish_completed_rows()

    def _publish_completed_rows(self) -> None:
        """Publish and clear the rows the history has finished so far."""
        for row in self.tb_history._get_and_reset():
            self._save_row(row)

    def _thread_except_body(self) -> None:
        """Thread entry point: log any exception from the loop, then re-raise."""
        try:
            self._thread_body()
        except Exception:
            logger.exception("generic exception in TBEventConsumer thread")
            raise

    def _thread_body(self) -> None:
        """Consume queued events until shutdown is set and the queue is empty."""
        while True:
            try:
                event = self._queue.get(True, 1)
            except queue.Empty:
                # Only leave once the queue has run dry, so an event that was
                # already dequeued is never discarded.
                if self._shutdown.is_set():
                    break
                continue

            # Wait self._delay seconds from consumer start before logging events
            still_holding = time.time() < self._start_time + self._delay
            if still_holding and not self._shutdown.is_set():
                self._queue.put(event)
                time.sleep(0.1)
                continue

            if event:
                self._consume(event)

        # flush uncommitted data
        self.tb_history._flush()
        self._publish_completed_rows()

    def _handle_event(
        self, event: ProtoEvent, history: TBHistory | None = None
    ) -> None:
        """Forward one queued event to ``wandb.tensorboard._log``.

        NOTE: despite the annotation, the body reads ``event.event`` and
        ``event.namespace``, i.e. it is given the queue's wrapper objects
        rather than raw protos.
        """
        wandb.tensorboard._log(  # type: ignore
            event.event,
            step=event.event.step,
            namespace=event.namespace,
            history=history,
        )

    def _save_row(self, row: HistoryDict) -> None:
        """Publish one history row, expanding custom charts first.

        For every ``CustomChart`` value the chart's config entry is published
        and the chart is replaced in ``row`` by its table, stored under the
        chart's table key. ``row`` is modified in place.
        """
        interface = self._tbwatcher._interface

        chart_keys = set()
        for key, value in row.items():
            if not isinstance(value, CustomChart):
                continue
            chart_keys.add(key)
            value.set_key(key)
            interface.publish_config(
                key=value.spec.config_key,
                val=value.spec.config_value,
            )

        # Keys are swapped in a second pass because the dict cannot change
        # size while it is being iterated.
        for key in chart_keys:
            chart = row.pop(key)
            if isinstance(chart, CustomChart):
                row[chart.spec.table_key] = chart.table

        interface.publish_history(
            self._internal_run,
            row,
            publish_step=False,
        )
class TBHistory:
    """Accumulate the values of the current step and collect finished rows.

    ``add`` starts a new step (finishing the previous one), ``_row_update``
    extends the current step, and ``_get_and_reset`` hands out the rows
    finished so far.
    """

    _data: HistoryDict
    _added: list[HistoryDict]

    def __init__(self) -> None:
        self._step = 0
        # Running sum of sys.getsizeof() over the values of the current step.
        self._step_size = 0
        self._data = {}
        self._added = []

    def _flush(self) -> None:
        """Finish the current step and append it to the pending rows.

        Does nothing when the current step holds no data.
        """
        if not self._data:
            return

        # A single tensorboard step may have too much data
        # we just drop the largest keys in the step if it does.
        # TODO: we could flush the data across multiple steps
        if self._step_size > util.MAX_LINE_BYTES:
            by_size = sorted(
                ((key, sys.getsizeof(value)) for key, value in self._data.items()),
                key=lambda pair: pair[1],
                reverse=True,
            )
            freed = 0
            dropped_keys = []
            for key, size in by_size:
                # TODO: (cvp) Added a buffer of 100KiB, this feels rather brittle.
                if self._step_size - freed < util.MAX_LINE_BYTES - 100000:
                    break
                freed += size
                dropped_keys.append(key)
                del self._data[key]
            wandb.termwarn(
                f"Step {self._step} exceeds max data limit, dropping {len(dropped_keys)} of the largest keys:"
            )
            print("\t" + ("\n\t".join(dropped_keys)))  # noqa: T201

        self._data["_step"] = self._step
        self._added.append(self._data)
        self._step += 1
        self._step_size = 0

    def add(self, d: HistoryDict) -> None:
        """Finish the current step and begin a new one holding ``d``."""
        self._flush()
        self._data = {}
        self._row_update(d)

    def _track_history_dict(self, d: HistoryDict) -> HistoryDict:
        """Return a shallow copy of ``d``, adding its value sizes to the step size."""
        tracked = {}
        for key, value in d.items():
            tracked[key] = value
            self._step_size += sys.getsizeof(value)
        return tracked

    def _row_update(self, d: HistoryDict) -> None:
        """Merge ``d`` into the current step."""
        self._data.update(self._track_history_dict(d))

    def _get_and_reset(self) -> list[HistoryDict]:
        """Return the pending rows and clear the pending list."""
        pending = list(self._added)
        self._added = []
        return pending
|