log.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. from __future__ import annotations
  2. import io
  3. import re
  4. import time
  5. from typing import TYPE_CHECKING, Any
  6. import wandb
  7. import wandb.util
  8. from wandb.sdk.lib import telemetry
  9. if TYPE_CHECKING:
  10. import numpy as np
  11. from wandb.sdk.internal.tb_watcher import TBHistory
# Per-namespace step bookkeeping used by `_log`.
# We have at least the default namespace ("") and a "global" entry to track;
# the "global" entry additionally stores a timestamp under "last_log" that
# `_log` consults for rate limiting.
# `reset_state()` rebinds this structure to its initial contents.
STEPS: dict[str, dict[str, Any]] = {
    "": {"step": 0},
    "global": {"step": 0, "last_log": None},
}
# TODO(cling): Set these when tensorboard behavior is configured.
# We support rate limited logging by setting this to a number of seconds;
# it can be a floating point value. `None` disables rate limiting.
RATE_LIMIT_SECONDS: float | int | None = None
# Summary value kinds that `tf_summary_to_dict` skips entirely.
IGNORE_KINDS = ["graphs"]
# Optional dependency, looked up via wandb.util.get_module; `make_ndarray`
# warns and returns None when this is falsy.
tensor_util = wandb.util.get_module("tensorboard.util.tensor_util")
# prefer tensorboard, fallback to protobuf in tensorflow when tboard isn't available
pb = wandb.util.get_module(
    "tensorboard.compat.proto.summary_pb2"
) or wandb.util.get_module("tensorflow.core.framework.summary_pb2")
# Summary protobuf class, or None when neither module above could be loaded.
Summary = pb.Summary if pb else None
  29. def make_ndarray(tensor: Any) -> np.ndarray | None:
  30. if tensor_util:
  31. res = tensor_util.make_ndarray(tensor)
  32. # Tensorboard can log generic objects, and we don't want to save them
  33. if res.dtype == "object":
  34. return None
  35. else:
  36. return res # type: ignore
  37. else:
  38. wandb.termwarn(
  39. "Can't convert tensor summary, upgrade tensorboard with `pip"
  40. " install tensorboard --upgrade`"
  41. )
  42. return None
  43. def namespaced_tag(tag: str, namespace: str = "") -> str:
  44. if not namespace:
  45. return tag
  46. else:
  47. return namespace + "/" + tag
  48. def history_image_key(key: str, namespace: str = "") -> str:
  49. """Convert invalid filesystem characters to _ for use in History keys.
  50. Unfortunately this means currently certain image keys will collide silently. We
  51. implement this mapping up here in the TensorFlow stuff rather than in the History
  52. stuff so that we don't have to store a mapping anywhere from the original keys to
  53. the safe ones.
  54. """
  55. return namespaced_tag(re.sub(r"[/\\]", "_", key), namespace)
  56. def tf_summary_to_dict( # noqa: C901
  57. tf_summary_str_or_pb: Any, namespace: str = ""
  58. ) -> dict[str, Any] | None:
  59. """Convert a Tensorboard Summary to a dictionary.
  60. Accepts a tensorflow.summary.Summary, one encoded as a string,
  61. or a list of such encoded as strings.
  62. """
  63. values = {}
  64. if hasattr(tf_summary_str_or_pb, "summary"):
  65. summary_pb = tf_summary_str_or_pb.summary
  66. values[namespaced_tag("global_step", namespace)] = tf_summary_str_or_pb.step
  67. values["_timestamp"] = tf_summary_str_or_pb.wall_time
  68. elif isinstance(tf_summary_str_or_pb, (str, bytes, bytearray)):
  69. summary_pb = Summary()
  70. summary_pb.ParseFromString(tf_summary_str_or_pb)
  71. elif hasattr(tf_summary_str_or_pb, "__iter__"):
  72. summary_pb = [Summary() for _ in range(len(tf_summary_str_or_pb))]
  73. for i, summary in enumerate(tf_summary_str_or_pb):
  74. summary_pb[i].ParseFromString(summary)
  75. if i > 0:
  76. summary_pb[0].MergeFrom(summary_pb[i])
  77. summary_pb = summary_pb[0]
  78. else:
  79. summary_pb = tf_summary_str_or_pb
  80. if not hasattr(summary_pb, "value") or len(summary_pb.value) == 0:
  81. # Ignore these, caller is responsible for handling None
  82. return None
  83. def encode_images(_img_strs: list[bytes], _value: Any) -> None:
  84. try:
  85. from PIL import Image
  86. except ImportError:
  87. wandb.termwarn(
  88. "Install pillow if you are logging images with Tensorboard. "
  89. "To install, run `pip install pillow`.",
  90. repeat=False,
  91. )
  92. return None
  93. if len(_img_strs) == 0:
  94. return None
  95. images: list[wandb.Video | wandb.Image] = []
  96. for _img_str in _img_strs:
  97. # Supports gifs from TensorboardX
  98. if _img_str.startswith(b"GIF"):
  99. images.append(wandb.Video(io.BytesIO(_img_str), format="gif"))
  100. else:
  101. images.append(wandb.Image(Image.open(io.BytesIO(_img_str))))
  102. tag_idx = _value.tag.rsplit("/", 1)
  103. if len(tag_idx) > 1 and tag_idx[1].isdigit():
  104. tag, idx = tag_idx
  105. values.setdefault(history_image_key(tag, namespace), []).extend(images)
  106. else:
  107. values[history_image_key(_value.tag, namespace)] = images
  108. return None
  109. for value in summary_pb.value:
  110. kind = value.WhichOneof("value")
  111. if kind in IGNORE_KINDS:
  112. continue
  113. if kind == "simple_value":
  114. values[namespaced_tag(value.tag, namespace)] = value.simple_value
  115. elif kind == "tensor":
  116. plugin_name = value.metadata.plugin_data.plugin_name
  117. if plugin_name == "scalars" or plugin_name == "":
  118. values[namespaced_tag(value.tag, namespace)] = make_ndarray(
  119. value.tensor
  120. )
  121. elif plugin_name == "images":
  122. img_strs = value.tensor.string_val[2:] # First two items are dims.
  123. encode_images(img_strs, value)
  124. elif plugin_name == "histograms":
  125. # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/summary_v2.py#L15-L26
  126. ndarray = make_ndarray(value.tensor)
  127. if ndarray is None:
  128. continue
  129. shape = ndarray.shape
  130. counts = []
  131. bins = []
  132. if shape[0] > 1:
  133. bins.append(ndarray[0][0]) # Add the left most edge
  134. for v in ndarray:
  135. counts.append(v[2])
  136. bins.append(v[1]) # Add the right most edges
  137. elif shape[0] == 1:
  138. counts = [ndarray[0][2]]
  139. bins = ndarray[0][:2]
  140. if len(counts) > 0:
  141. try:
  142. # TODO: we should just re-bin if there are too many buckets
  143. values[namespaced_tag(value.tag, namespace)] = wandb.Histogram(
  144. np_histogram=(counts, bins) # type: ignore
  145. )
  146. except ValueError:
  147. wandb.termwarn(
  148. f'Not logging key "{namespaced_tag(value.tag, namespace)}". '
  149. f"Histograms must have fewer than {wandb.Histogram.MAX_LENGTH} bins",
  150. repeat=False,
  151. )
  152. elif plugin_name == "pr_curves":
  153. pr_curve_data = make_ndarray(value.tensor)
  154. if pr_curve_data is None:
  155. continue
  156. precision = pr_curve_data[-2, :].tolist()
  157. recall = pr_curve_data[-1, :].tolist()
  158. # TODO: (kdg) implement spec for showing additional info in tool tips
  159. # true_pos = pr_curve_data[1,:]
  160. # false_pos = pr_curve_data[2,:]
  161. # true_neg = pr_curve_data[1,:]
  162. # false_neg = pr_curve_data[1,:]
  163. # threshold = [1.0 / n for n in range(len(true_pos), 0, -1)]
  164. # min of each in case tensorboard ever changes their pr_curve
  165. # to allow for different length outputs
  166. data = []
  167. for i in range(min(len(precision), len(recall))):
  168. # drop additional threshold values if they exist
  169. if precision[i] != 0 or recall[i] != 0:
  170. data.append((recall[i], precision[i]))
  171. # sort data so custom chart looks the same as tb generated pr curve
  172. # ascending recall, descending precision for the same recall values
  173. data = sorted(data, key=lambda x: (x[0], -x[1]))
  174. data_table = wandb.Table(data=data, columns=["recall", "precision"])
  175. name = namespaced_tag(value.tag, namespace)
  176. values[name] = wandb.plot_table(
  177. "wandb/line/v0",
  178. data_table,
  179. {"x": "recall", "y": "precision"},
  180. {"title": f"{name} Precision v. Recall"},
  181. )
  182. elif kind == "image":
  183. img_str = value.image.encoded_image_string
  184. encode_images([img_str], value)
  185. # Coming soon...
  186. # elif kind == "audio":
  187. # audio = wandb.Audio(
  188. # six.BytesIO(value.audio.encoded_audio_string),
  189. # sample_rate=value.audio.sample_rate,
  190. # content_type=value.audio.content_type,
  191. # )
  192. elif kind == "histo":
  193. tag = namespaced_tag(value.tag, namespace)
  194. if len(value.histo.bucket_limit) >= 3:
  195. first = (
  196. value.histo.bucket_limit[0]
  197. + value.histo.bucket_limit[0]
  198. - value.histo.bucket_limit[1]
  199. )
  200. last = (
  201. value.histo.bucket_limit[-2]
  202. + value.histo.bucket_limit[-2]
  203. - value.histo.bucket_limit[-3]
  204. )
  205. np_histogram = (
  206. list(value.histo.bucket),
  207. [first] + value.histo.bucket_limit[:-1] + [last],
  208. )
  209. try:
  210. # TODO: we should just re-bin if there are too many buckets
  211. values[tag] = wandb.Histogram(np_histogram=np_histogram) # type: ignore
  212. except ValueError:
  213. wandb.termwarn(
  214. f"Not logging key {tag!r}. "
  215. f"Histograms must have fewer than {wandb.Histogram.MAX_LENGTH} bins",
  216. repeat=False,
  217. )
  218. else:
  219. # TODO: is there a case where we can render this?
  220. wandb.termwarn(
  221. f"Not logging key {tag!r}. Found a histogram with only 2 bins.",
  222. repeat=False,
  223. )
  224. # TODO(jhr): figure out how to share this between userspace and internal process or dont
  225. # elif value.tag == "_hparams_/session_start_info":
  226. # if wandb.util.get_module("tensorboard.plugins.hparams"):
  227. # from tensorboard.plugins.hparams import plugin_data_pb2
  228. #
  229. # plugin_data = plugin_data_pb2.HParamsPluginData() #
  230. # plugin_data.ParseFromString(value.metadata.plugin_data.content)
  231. # for key, param in six.iteritems(plugin_data.session_start_info.hparams):
  232. # if not wandb.run.config.get(key):
  233. # wandb.run.config[key] = (
  234. # param.number_value or param.string_value or param.bool_value
  235. # )
  236. # else:
  237. # wandb.termerror(
  238. # "Received hparams tf.summary, but could not import "
  239. # "the hparams plugin from tensorboard"
  240. # )
  241. return values
  242. def reset_state() -> None:
  243. """Internal method for resetting state, called by wandb.finish()."""
  244. global STEPS
  245. STEPS = {"": {"step": 0}, "global": {"step": 0, "last_log": None}}
  246. def _log(
  247. tf_summary_str_or_pb: Any,
  248. history: TBHistory | None = None,
  249. step: int = 0,
  250. namespace: str = "",
  251. **kwargs: Any,
  252. ) -> None:
  253. """Logs a tfsummary to wandb.
  254. Can accept a tf summary string or parsed event. Will use wandb.run.history unless a
  255. history object is passed. Can optionally namespace events. Results are committed
  256. when step increases for this namespace.
  257. NOTE: This assumes that events being passed in are in chronological order
  258. """
  259. global STEPS
  260. global RATE_LIMIT_SECONDS
  261. # To handle multiple global_steps, we keep track of them here instead
  262. # of the global log
  263. last_step = STEPS.get(namespace, {"step": 0})
  264. # Commit our existing data if this namespace increased its step
  265. commit = False
  266. if last_step["step"] < step:
  267. commit = True
  268. log_dict = tf_summary_to_dict(tf_summary_str_or_pb, namespace)
  269. if log_dict is None:
  270. # not an event, just return
  271. return
  272. # Pass timestamp to history for loading historic data
  273. timestamp = log_dict.get("_timestamp", time.time())
  274. # Store our initial timestamp
  275. if STEPS["global"]["last_log"] is None:
  276. STEPS["global"]["last_log"] = timestamp
  277. # Rollup events that share the same step across namespaces
  278. if commit and step == STEPS["global"]["step"]:
  279. commit = False
  280. # Always add the biggest global_step key for non-default namespaces
  281. if step > STEPS["global"]["step"]:
  282. STEPS["global"]["step"] = step
  283. if namespace != "":
  284. log_dict["global_step"] = STEPS["global"]["step"]
  285. # Keep internal step counter
  286. STEPS[namespace] = {"step": step}
  287. if commit:
  288. # Only commit our data if we're below the rate limit or don't have one
  289. if (
  290. RATE_LIMIT_SECONDS is None
  291. or timestamp - STEPS["global"]["last_log"] >= RATE_LIMIT_SECONDS
  292. ):
  293. if history is None:
  294. if wandb.run is not None:
  295. wandb.run._log({})
  296. else:
  297. history.add({})
  298. STEPS["global"]["last_log"] = timestamp
  299. if history is None:
  300. if wandb.run is not None:
  301. wandb.run._log(log_dict, commit=False)
  302. else:
  303. history._row_update(log_dict)
  304. def log(tf_summary_str_or_pb: Any, step: int = 0, namespace: str = "") -> None:
  305. if wandb.run is None:
  306. raise wandb.Error(
  307. "You must call `wandb.init()` before calling `wandb.tensorflow.log`"
  308. )
  309. with telemetry.context() as tel:
  310. tel.feature.tensorboard_log = True
  311. _log(tf_summary_str_or_pb, namespace=namespace, step=step)