| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188 |
- """monkeypatch: patch code to add tensorboard hooks."""
- from __future__ import annotations
- import os
- import re
- import socket
- from typing import Any
- import wandb
- import wandb.util
# Dotted module paths whose TensorBoard writer entry points get monkeypatched.
# TensorFlow 2 generated summary ops (exposes `create_summary_file_writer`).
TENSORBOARD_C_MODULE = "tensorflow.python.ops.gen_summary_ops"
# TensorFlow <= 1.15 file writer (tf.compat.v1.summary.FileWriter).
TENSORFLOW_PY_MODULE = "tensorflow.python.summary.writer.writer"
# Event-file writer of the standalone `tensorboard` package.
TENSORBOARD_WRITER_MODULE = "tensorboard.summary.writer.event_file_writer"
# PyTorch's bundled `torch.utils.tensorboard` writer module.
TENSORBOARD_PYTORCH_MODULE = "torch.utils.tensorboard.writer"
# tensorboardX writer module.
TENSORBOARD_X_MODULE = "tensorboardX.writer"
def unpatch() -> None:
    """Undo every monkeypatch recorded in ``wandb.patched["tensorboard"]``.

    Each recorded ``(module, method)`` pair gets the implementation that was
    stashed on the module as ``orig_<method>`` put back in place. The registry
    is then reset to an empty list so ``patch`` may be called again.
    """
    for module_name, attr_name in wandb.patched["tensorboard"]:
        target = wandb.util.get_module(module_name, lazy=False)
        original = getattr(target, f"orig_{attr_name}")
        setattr(target, attr_name, original)
    wandb.patched["tensorboard"] = []
def patch(
    save: bool = True,
    tensorboard_x: bool | None = None,
    pytorch: bool | None = None,
    root_logdir: str = "",
) -> None:
    """Monkeypatch every importable TensorBoard writer to report its logdir.

    Args:
        save: Forwarded to each writer patcher.
        tensorboard_x: When truthy, the TensorFlow 2 summary-op writer is not
            patched.
        pytorch: When truthy, the TensorFlow 2 summary-op writer is not
            patched.
        root_logdir: Common root of the event-file directories; forwarded to
            each writer patcher (``""`` when unknown).

    Raises:
        ValueError: If TensorBoard writers are already patched.
    """
    if wandb.patched["tensorboard"]:
        raise ValueError(
            "Tensorboard already patched. Call `wandb.tensorboard.unpatch()` first; "
            "remove `sync_tensorboard=True` from `wandb.init`; "
            "or only call `wandb.tensorboard.patch` once."
        )

    # TODO: Some older versions of tensorflow don't require tensorboard to be present.
    # we may want to lift this requirement, but it's safer to have it for now
    wandb.util.get_module(
        "tensorboard", required="Please install tensorboard package", lazy=False
    )

    c_writer = wandb.util.get_module(TENSORBOARD_C_MODULE, lazy=False)
    py_writer = wandb.util.get_module(TENSORFLOW_PY_MODULE, lazy=False)
    tb_writer = wandb.util.get_module(TENSORBOARD_WRITER_MODULE, lazy=False)
    pt_writer = wandb.util.get_module(TENSORBOARD_PYTORCH_MODULE, lazy=False)
    tbx_writer = wandb.util.get_module(TENSORBOARD_X_MODULE, lazy=False)

    patched_any = False

    if not pytorch and not tensorboard_x and c_writer:
        _patch_tensorflow2(
            writer=c_writer,
            module=TENSORBOARD_C_MODULE,
            save=save,
            root_logdir=root_logdir,
        )
        patched_any = True

    # EventFileWriter-based writers, patched in this fixed order. The first
    # entry is for tensorflow <= 1.15 (tf.compat.v1.summary.FileWriter).
    file_writers = (
        (TENSORFLOW_PY_MODULE, py_writer),
        (TENSORBOARD_WRITER_MODULE, tb_writer),
        (TENSORBOARD_PYTORCH_MODULE, pt_writer),
        (TENSORBOARD_X_MODULE, tbx_writer),
    )
    for module_name, file_writer in file_writers:
        if file_writer:
            _patch_file_writer(
                writer=file_writer,
                module=module_name,
                save=save,
                root_logdir=root_logdir,
            )
            patched_any = True

    # Previously this checked `tb_writer` twice and ignored the other writers;
    # report an error only when nothing at all could be patched.
    if not patched_any:
        wandb.termerror("Unsupported tensorboard configuration")
def _is_within(path: str, root: str) -> bool:
    """Return True if *path* is *root* itself or lies underneath it.

    Both paths are made absolute first. Whole path components are compared,
    unlike a plain ``str.startswith`` check, so ``/runs/a10`` is not treated
    as being inside ``/runs/a1``.
    """
    abs_path = os.path.abspath(path)
    abs_root = os.path.abspath(root)
    try:
        return os.path.commonpath([abs_path, abs_root]) == abs_root
    except ValueError:
        # Both paths are absolute, so this means e.g. different Windows
        # drives: they can never be nested.
        return False


def _resolve_root_logdir(
    logdir: str, seen_logdirs: set[str], root_logdir: str
) -> str:
    """Choose the root_logdir to report for a newly opened event-file directory.

    Args:
        logdir: Directory the writer is about to write event files into.
        seen_logdirs: Logdirs seen so far by one patched writer; updated in place.
        root_logdir: The value passed to ``patch`` (``""`` when none was given).

    Returns:
        The prefix of *logdir* ending in ``-<digits>_<hostname>`` if *logdir*
        looks auto-generated. Otherwise ``""`` if *logdir* falls outside an
        explicitly given *root_logdir*. Otherwise *root_logdir* unchanged.
    """
    seen_logdirs.add(logdir)
    if len(seen_logdirs) > 1 and root_logdir == "":
        wandb.termwarn(
            "When using several event log directories, "
            'please call `wandb.tensorboard.patch(root_logdir="...")` before `wandb.init`'
        )
    # If the logdir contains the hostname, the writer was not given a logdir.
    # In this case, the logdir is generated and ends with the hostname, so
    # update the root_logdir to match. The hostname is escaped so characters
    # such as "." match literally instead of acting as regex wildcards.
    match = re.search(rf"-\d+_{re.escape(socket.gethostname())}", logdir)
    if match:
        return logdir[: match.end()]
    # Only warn when a root_logdir was actually supplied. The default "" would
    # otherwise resolve to the CWD and warn for every logdir outside it.
    if root_logdir and not _is_within(logdir, root_logdir):
        wandb.termwarn(
            "Found log directory outside of given root_logdir, "
            f"dropping given root_logdir for event file in {logdir}"
        )
        return ""
    return root_logdir


def _patch_tensorflow2(
    writer: Any,
    module: Any,
    save: bool = True,
    root_logdir: str = "",
) -> None:
    """Wrap ``writer.create_summary_file_writer`` so W&B learns of each logdir.

    This configures TensorFlow 2 style TensorBoard logging.

    Args:
        writer: Module object exposing ``create_summary_file_writer``; ``patch``
            passes the TF2 ``gen_summary_ops`` module.
        module: Module path recorded in ``wandb.patched`` so ``unpatch`` can
            restore the original function.
        save: Forwarded to ``_notify_tensorboard_logdir``.
        root_logdir: Root of the event-file directories, or ``""`` if unknown.
    """
    old_csfw_func = writer.create_summary_file_writer
    seen_logdirs: set[str] = set()

    def new_csfw_func(*args: Any, **kwargs: Any) -> Any:
        # The logdir normally arrives as the `logdir` keyword. Fall back to the
        # second positional argument. NOTE(review): this assumes the generated
        # op signature is (writer, logdir, ...); confirm for new TF versions.
        raw_logdir = kwargs["logdir"] if "logdir" in kwargs else args[1]
        # A string tensor exposes .numpy(); its bytes are decoded as UTF-8.
        logdir = (
            raw_logdir.numpy().decode("utf8")
            if hasattr(raw_logdir, "numpy")
            else raw_logdir
        )
        root_logdir_arg = _resolve_root_logdir(logdir, seen_logdirs, root_logdir)
        _notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)
        return old_csfw_func(*args, **kwargs)

    writer.orig_create_summary_file_writer = old_csfw_func
    writer.create_summary_file_writer = new_csfw_func
    wandb.patched["tensorboard"].append([module, "create_summary_file_writer"])


def _patch_file_writer(
    writer: Any,
    module: Any,
    save: bool = True,
    root_logdir: str = "",
) -> None:
    """Replace ``writer.EventFileWriter`` with a subclass that reports its logdir.

    This configures non-TensorFlow TensorBoard logging, or tensorflow <= 1.15.

    Args:
        writer: Module object exposing an ``EventFileWriter`` class.
        module: Module path recorded in ``wandb.patched`` so ``unpatch`` can
            restore the original class.
        save: Forwarded to ``_notify_tensorboard_logdir``.
        root_logdir: Root of the event-file directories, or ``""`` if unknown.
    """
    seen_logdirs: set[str] = set()

    class TBXEventFileWriter(writer.EventFileWriter):
        def __init__(self, logdir: str, *args: Any, **kwargs: Any) -> None:
            root_logdir_arg = _resolve_root_logdir(logdir, seen_logdirs, root_logdir)
            _notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)
            super().__init__(logdir, *args, **kwargs)

    writer.orig_EventFileWriter = writer.EventFileWriter
    writer.EventFileWriter = TBXEventFileWriter
    wandb.patched["tensorboard"].append([module, "EventFileWriter"])
def _notify_tensorboard_logdir(
    logdir: str, save: bool = True, root_logdir: str = ""
) -> None:
    """Pass a newly seen event-file directory to the active run's callback.

    Does nothing when ``wandb.run`` is None (no active run).
    """
    run = wandb.run
    if run is None:
        return
    run._tensorboard_callback(logdir, save=save, root_logdir=root_logdir)
|