monkeypatch.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. """monkeypatch: patch code to add tensorboard hooks."""
  2. from __future__ import annotations
  3. import os
  4. import re
  5. import socket
  6. from typing import Any
  7. import wandb
  8. import wandb.util
# Dotted paths of the writer modules that `patch` attempts to load and hook.

# Patched through its `create_summary_file_writer` function
# (TensorFlow 2 style logging).
TENSORBOARD_C_MODULE = "tensorflow.python.ops.gen_summary_ops"
# The modules below are patched by replacing their `EventFileWriter` class.
TENSORBOARD_X_MODULE = "tensorboardX.writer"
# Used for tensorflow <= 1.15 (tf.compat.v1.summary.FileWriter).
TENSORFLOW_PY_MODULE = "tensorflow.python.summary.writer.writer"
TENSORBOARD_WRITER_MODULE = "tensorboard.summary.writer.event_file_writer"
TENSORBOARD_PYTORCH_MODULE = "torch.utils.tensorboard.writer"
  14. def unpatch() -> None:
  15. for module, method in wandb.patched["tensorboard"]:
  16. writer = wandb.util.get_module(module, lazy=False)
  17. setattr(writer, method, getattr(writer, f"orig_{method}"))
  18. wandb.patched["tensorboard"] = []
  19. def patch(
  20. save: bool = True,
  21. tensorboard_x: bool | None = None,
  22. pytorch: bool | None = None,
  23. root_logdir: str = "",
  24. ) -> None:
  25. if len(wandb.patched["tensorboard"]) > 0:
  26. raise ValueError(
  27. "Tensorboard already patched. Call `wandb.tensorboard.unpatch()` first; "
  28. "remove `sync_tensorboard=True` from `wandb.init`; "
  29. "or only call `wandb.tensorboard.patch` once."
  30. )
  31. # TODO: Some older versions of tensorflow don't require tensorboard to be present.
  32. # we may want to lift this requirement, but it's safer to have it for now
  33. wandb.util.get_module(
  34. "tensorboard", required="Please install tensorboard package", lazy=False
  35. )
  36. c_writer = wandb.util.get_module(TENSORBOARD_C_MODULE, lazy=False)
  37. py_writer = wandb.util.get_module(TENSORFLOW_PY_MODULE, lazy=False)
  38. tb_writer = wandb.util.get_module(TENSORBOARD_WRITER_MODULE, lazy=False)
  39. pt_writer = wandb.util.get_module(TENSORBOARD_PYTORCH_MODULE, lazy=False)
  40. tbx_writer = wandb.util.get_module(TENSORBOARD_X_MODULE, lazy=False)
  41. if not pytorch and not tensorboard_x and c_writer:
  42. _patch_tensorflow2(
  43. writer=c_writer,
  44. module=TENSORBOARD_C_MODULE,
  45. save=save,
  46. root_logdir=root_logdir,
  47. )
  48. # This is for tensorflow <= 1.15 (tf.compat.v1.summary.FileWriter)
  49. if py_writer:
  50. _patch_file_writer(
  51. writer=py_writer,
  52. module=TENSORFLOW_PY_MODULE,
  53. save=save,
  54. root_logdir=root_logdir,
  55. )
  56. if tb_writer:
  57. _patch_file_writer(
  58. writer=tb_writer,
  59. module=TENSORBOARD_WRITER_MODULE,
  60. save=save,
  61. root_logdir=root_logdir,
  62. )
  63. if pt_writer:
  64. _patch_file_writer(
  65. writer=pt_writer,
  66. module=TENSORBOARD_PYTORCH_MODULE,
  67. save=save,
  68. root_logdir=root_logdir,
  69. )
  70. if tbx_writer:
  71. _patch_file_writer(
  72. writer=tbx_writer,
  73. module=TENSORBOARD_X_MODULE,
  74. save=save,
  75. root_logdir=root_logdir,
  76. )
  77. if not c_writer and not tb_writer and not tb_writer:
  78. wandb.termerror("Unsupported tensorboard configuration")
  79. def _patch_tensorflow2(
  80. writer: Any,
  81. module: Any,
  82. save: bool = True,
  83. root_logdir: str = "",
  84. ) -> None:
  85. # This configures TensorFlow 2 style Tensorboard logging
  86. old_csfw_func = writer.create_summary_file_writer
  87. logdir_hist = []
  88. def new_csfw_func(*args: Any, **kwargs: Any) -> Any:
  89. logdir = (
  90. kwargs["logdir"].numpy().decode("utf8")
  91. if hasattr(kwargs["logdir"], "numpy")
  92. else kwargs["logdir"]
  93. )
  94. logdir_hist.append(logdir)
  95. root_logdir_arg = root_logdir
  96. if len(set(logdir_hist)) > 1 and root_logdir == "":
  97. wandb.termwarn(
  98. "When using several event log directories, "
  99. 'please call `wandb.tensorboard.patch(root_logdir="...")` before `wandb.init`'
  100. )
  101. # if the logdir contains the hostname, the writer was not given a logdir.
  102. # In this case, the generated logdir
  103. # is generated and ends with the hostname, update the root_logdir to match.
  104. hostname = socket.gethostname()
  105. search = re.search(rf"-\d+_{hostname}", logdir)
  106. if search:
  107. root_logdir_arg = logdir[: search.span()[1]]
  108. elif root_logdir is not None and not os.path.abspath(logdir).startswith(
  109. os.path.abspath(root_logdir)
  110. ):
  111. wandb.termwarn(
  112. "Found log directory outside of given root_logdir, "
  113. f"dropping given root_logdir for event file in {logdir}"
  114. )
  115. root_logdir_arg = ""
  116. _notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)
  117. return old_csfw_func(*args, **kwargs)
  118. writer.orig_create_summary_file_writer = old_csfw_func
  119. writer.create_summary_file_writer = new_csfw_func
  120. wandb.patched["tensorboard"].append([module, "create_summary_file_writer"])
  121. def _patch_file_writer(
  122. writer: Any,
  123. module: Any,
  124. save: bool = True,
  125. root_logdir: str = "",
  126. ) -> None:
  127. # This configures non-TensorFlow Tensorboard logging, or tensorflow <= 1.15
  128. logdir_hist = []
  129. class TBXEventFileWriter(writer.EventFileWriter):
  130. def __init__(self, logdir: str, *args: Any, **kwargs: Any) -> None:
  131. logdir_hist.append(logdir)
  132. root_logdir_arg = root_logdir
  133. if len(set(logdir_hist)) > 1 and root_logdir == "":
  134. wandb.termwarn(
  135. "When using several event log directories, "
  136. 'please call `wandb.tensorboard.patch(root_logdir="...")` before `wandb.init`'
  137. )
  138. # if the logdir contains the hostname, the writer was not given a logdir.
  139. # In this case, the logdir is generated and ends with the hostname,
  140. # update the root_logdir to match.
  141. hostname = socket.gethostname()
  142. search = re.search(rf"-\d+_{hostname}", logdir)
  143. if search:
  144. root_logdir_arg = logdir[: search.span()[1]]
  145. elif root_logdir is not None and not os.path.abspath(logdir).startswith(
  146. os.path.abspath(root_logdir)
  147. ):
  148. wandb.termwarn(
  149. "Found log directory outside of given root_logdir, "
  150. f"dropping given root_logdir for event file in {logdir}"
  151. )
  152. root_logdir_arg = ""
  153. _notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)
  154. super().__init__(logdir, *args, **kwargs)
  155. writer.orig_EventFileWriter = writer.EventFileWriter
  156. writer.EventFileWriter = TBXEventFileWriter
  157. wandb.patched["tensorboard"].append([module, "EventFileWriter"])
  158. def _notify_tensorboard_logdir(
  159. logdir: str, save: bool = True, root_logdir: str = ""
  160. ) -> None:
  161. if wandb.run is not None:
  162. wandb.run._tensorboard_callback(logdir, save=save, root_logdir=root_logdir)