summary.py 36 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033
  1. # mypy: allow-untyped-defs
  2. import json
  3. import logging
  4. import struct
  5. from typing import Any
  6. import torch
  7. import numpy as np
  8. from google.protobuf import struct_pb2
  9. from tensorboard.compat.proto.summary_pb2 import (
  10. HistogramProto,
  11. Summary,
  12. SummaryMetadata,
  13. )
  14. from tensorboard.compat.proto.tensor_pb2 import TensorProto
  15. from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
  16. from tensorboard.plugins.custom_scalar import layout_pb2
  17. from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData
  18. from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
  19. from ._convert_np import make_np
  20. from ._utils import _prepare_video, convert_to_HWC
  21. __all__ = [
  22. "half_to_int",
  23. "int_to_half",
  24. "hparams",
  25. "scalar",
  26. "histogram_raw",
  27. "histogram",
  28. "make_histogram",
  29. "image",
  30. "image_boxes",
  31. "draw_boxes",
  32. "make_image",
  33. "video",
  34. "make_video",
  35. "audio",
  36. "custom_scalars",
  37. "text",
  38. "tensor_proto",
  39. "pr_curve_raw",
  40. "pr_curve",
  41. "compute_curve",
  42. "mesh",
  43. ]
  44. logger = logging.getLogger(__name__)
  45. def half_to_int(f: float) -> int:
  46. """Casts a half-precision float value into an integer.
  47. Converts a half precision floating point value, such as `torch.half` or
  48. `torch.bfloat16`, into an integer value which can be written into the
  49. half_val field of a TensorProto for storage.
  50. To undo the effects of this conversion, use int_to_half().
  51. """
  52. buf = struct.pack("f", f)
  53. return struct.unpack("i", buf)[0]
  54. def int_to_half(i: int) -> float:
  55. """Casts an integer value to a half-precision float.
  56. Converts an integer value obtained from half_to_int back into a floating
  57. point value.
  58. """
  59. buf = struct.pack("i", i)
  60. return struct.unpack("f", buf)[0]
  61. def _tensor_to_half_val(t: torch.Tensor) -> list[int]:
  62. return [half_to_int(x) for x in t.flatten().tolist()]
  63. def _tensor_to_complex_val(t: torch.Tensor) -> list[float]:
  64. return torch.view_as_real(t).flatten().tolist()
  65. def _tensor_to_list(t: torch.Tensor) -> list[Any]:
  66. return t.flatten().tolist()
  67. # type maps: torch.Tensor type -> (protobuf type, protobuf val field)
  68. _TENSOR_TYPE_MAP = {
  69. torch.half: ("DT_HALF", "half_val", _tensor_to_half_val),
  70. torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val),
  71. torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val),
  72. torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list),
  73. torch.float: ("DT_FLOAT", "float_val", _tensor_to_list),
  74. torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list),
  75. torch.double: ("DT_DOUBLE", "double_val", _tensor_to_list),
  76. torch.int8: ("DT_INT8", "int_val", _tensor_to_list),
  77. torch.uint8: ("DT_UINT8", "int_val", _tensor_to_list),
  78. torch.qint8: ("DT_UINT8", "int_val", _tensor_to_list),
  79. torch.int16: ("DT_INT16", "int_val", _tensor_to_list),
  80. torch.short: ("DT_INT16", "int_val", _tensor_to_list),
  81. torch.int: ("DT_INT32", "int_val", _tensor_to_list),
  82. torch.int32: ("DT_INT32", "int_val", _tensor_to_list),
  83. torch.qint32: ("DT_INT32", "int_val", _tensor_to_list),
  84. torch.int64: ("DT_INT64", "int64_val", _tensor_to_list),
  85. torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
  86. torch.chalf: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
  87. torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
  88. torch.cfloat: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
  89. torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list),
  90. torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
  91. torch.cdouble: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
  92. torch.uint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
  93. torch.quint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
  94. torch.quint4x2: ("DT_UINT8", "uint32_val", _tensor_to_list),
  95. }
  96. def _calc_scale_factor(tensor) -> int:
  97. converted = tensor.numpy() if not isinstance(tensor, np.ndarray) else tensor
  98. return 1 if converted.dtype == np.uint8 else 255
  99. def _draw_single_box(
  100. image,
  101. xmin,
  102. ymin,
  103. xmax,
  104. ymax,
  105. display_str,
  106. color="black",
  107. color_text="black",
  108. thickness=2,
  109. ):
  110. from PIL import ImageDraw, ImageFont
  111. font = ImageFont.load_default()
  112. draw = ImageDraw.Draw(image)
  113. (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
  114. draw.line(
  115. [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
  116. width=thickness,
  117. fill=color,
  118. )
  119. if display_str:
  120. text_bottom = bottom
  121. # Reverse list and print from bottom to top.
  122. _left, _top, _right, _bottom = font.getbbox(display_str)
  123. text_width, text_height = _right - _left, _bottom - _top
  124. margin = np.ceil(0.05 * text_height)
  125. draw.rectangle(
  126. [
  127. (left, text_bottom - text_height - 2 * margin),
  128. (left + text_width, text_bottom),
  129. ],
  130. fill=color,
  131. )
  132. draw.text(
  133. (left + margin, text_bottom - text_height - margin),
  134. display_str,
  135. fill=color_text,
  136. font=font,
  137. )
  138. return image
  139. def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
  140. """Output three `Summary` protocol buffers needed by hparams plugin.
  141. `Experiment` keeps the metadata of an experiment, such as the name of the
  142. hyperparameters and the name of the metrics.
  143. `SessionStartInfo` keeps key-value pairs of the hyperparameters
  144. `SessionEndInfo` describes status of the experiment e.g. STATUS_SUCCESS
  145. Args:
  146. hparam_dict: A dictionary that contains names of the hyperparameters
  147. and their values.
  148. metric_dict: A dictionary that contains names of the metrics
  149. and their values.
  150. hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
  151. contains names of the hyperparameters and all discrete values they can hold
  152. Returns:
  153. The `Summary` protobufs for Experiment, SessionStartInfo and
  154. SessionEndInfo
  155. """
  156. import torch
  157. from tensorboard.plugins.hparams.api_pb2 import (
  158. DataType,
  159. Experiment,
  160. HParamInfo,
  161. MetricInfo,
  162. MetricName,
  163. Status,
  164. )
  165. from tensorboard.plugins.hparams.metadata import (
  166. EXPERIMENT_TAG,
  167. PLUGIN_DATA_VERSION,
  168. PLUGIN_NAME,
  169. SESSION_END_INFO_TAG,
  170. SESSION_START_INFO_TAG,
  171. )
  172. from tensorboard.plugins.hparams.plugin_data_pb2 import (
  173. HParamsPluginData,
  174. SessionEndInfo,
  175. SessionStartInfo,
  176. )
  177. # TODO: expose other parameters in the future.
  178. # hp = HParamInfo(name='lr',display_name='learning rate',
  179. # type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10,
  180. # max_value=100))
  181. # mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy',
  182. # description='', dataset_type=DatasetType.DATASET_VALIDATION)
  183. # exp = Experiment(name='123', description='456', time_created_secs=100.0,
  184. # hparam_infos=[hp], metric_infos=[mt], user='tw')
  185. if not isinstance(hparam_dict, dict):
  186. logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.")
  187. raise TypeError(
  188. "parameter: hparam_dict should be a dictionary, nothing logged."
  189. )
  190. if not isinstance(metric_dict, dict):
  191. logger.warning("parameter: metric_dict should be a dictionary, nothing logged.")
  192. raise TypeError(
  193. "parameter: metric_dict should be a dictionary, nothing logged."
  194. )
  195. hparam_domain_discrete = hparam_domain_discrete or {}
  196. if not isinstance(hparam_domain_discrete, dict):
  197. raise TypeError(
  198. "parameter: hparam_domain_discrete should be a dictionary, nothing logged."
  199. )
  200. for k, v in hparam_domain_discrete.items():
  201. if (
  202. k not in hparam_dict
  203. or not isinstance(v, list)
  204. or not all(isinstance(d, type(hparam_dict[k])) for d in v)
  205. ):
  206. raise TypeError(
  207. f"parameter: hparam_domain_discrete[{k}] should be a list of same type as hparam_dict[{k}]."
  208. )
  209. hps = []
  210. ssi = SessionStartInfo()
  211. for k, v in hparam_dict.items():
  212. if v is None:
  213. continue
  214. if isinstance(v, (int, float)):
  215. ssi.hparams[k].number_value = v
  216. if k in hparam_domain_discrete:
  217. domain_discrete: struct_pb2.ListValue | None = struct_pb2.ListValue(
  218. values=[
  219. struct_pb2.Value(number_value=d)
  220. for d in hparam_domain_discrete[k]
  221. ]
  222. )
  223. else:
  224. domain_discrete = None
  225. hps.append(
  226. HParamInfo(
  227. name=k,
  228. # pyrefly: ignore [missing-attribute]
  229. type=DataType.Value("DATA_TYPE_FLOAT64"),
  230. domain_discrete=domain_discrete,
  231. )
  232. )
  233. continue
  234. if isinstance(v, str):
  235. ssi.hparams[k].string_value = v
  236. if k in hparam_domain_discrete:
  237. domain_discrete = struct_pb2.ListValue(
  238. values=[
  239. struct_pb2.Value(string_value=d)
  240. for d in hparam_domain_discrete[k]
  241. ]
  242. )
  243. else:
  244. domain_discrete = None
  245. hps.append(
  246. HParamInfo(
  247. name=k,
  248. # pyrefly: ignore [missing-attribute]
  249. type=DataType.Value("DATA_TYPE_STRING"),
  250. domain_discrete=domain_discrete,
  251. )
  252. )
  253. continue
  254. if isinstance(v, bool):
  255. ssi.hparams[k].bool_value = v
  256. if k in hparam_domain_discrete:
  257. domain_discrete = struct_pb2.ListValue(
  258. values=[
  259. struct_pb2.Value(bool_value=d)
  260. for d in hparam_domain_discrete[k]
  261. ]
  262. )
  263. else:
  264. domain_discrete = None
  265. hps.append(
  266. HParamInfo(
  267. name=k,
  268. # pyrefly: ignore [missing-attribute]
  269. type=DataType.Value("DATA_TYPE_BOOL"),
  270. domain_discrete=domain_discrete,
  271. )
  272. )
  273. continue
  274. if isinstance(v, torch.Tensor):
  275. v = make_np(v)[0]
  276. ssi.hparams[k].number_value = v
  277. # pyrefly: ignore [missing-attribute]
  278. hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
  279. continue
  280. raise ValueError(
  281. "value should be one of int, float, str, bool, or torch.Tensor"
  282. )
  283. content = HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION)
  284. smd = SummaryMetadata(
  285. # pyrefly: ignore [missing-attribute]
  286. plugin_data=SummaryMetadata.PluginData(
  287. plugin_name=PLUGIN_NAME, content=content.SerializeToString()
  288. )
  289. )
  290. # pyrefly: ignore [missing-attribute]
  291. ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)])
  292. mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict]
  293. exp = Experiment(hparam_infos=hps, metric_infos=mts)
  294. content = HParamsPluginData(experiment=exp, version=PLUGIN_DATA_VERSION)
  295. smd = SummaryMetadata(
  296. # pyrefly: ignore [missing-attribute]
  297. plugin_data=SummaryMetadata.PluginData(
  298. plugin_name=PLUGIN_NAME, content=content.SerializeToString()
  299. )
  300. )
  301. # pyrefly: ignore [missing-attribute]
  302. exp = Summary(value=[Summary.Value(tag=EXPERIMENT_TAG, metadata=smd)])
  303. # pyrefly: ignore [missing-attribute]
  304. sei = SessionEndInfo(status=Status.Value("STATUS_SUCCESS"))
  305. content = HParamsPluginData(session_end_info=sei, version=PLUGIN_DATA_VERSION)
  306. smd = SummaryMetadata(
  307. # pyrefly: ignore [missing-attribute]
  308. plugin_data=SummaryMetadata.PluginData(
  309. plugin_name=PLUGIN_NAME, content=content.SerializeToString()
  310. )
  311. )
  312. # pyrefly: ignore [missing-attribute]
  313. sei = Summary(value=[Summary.Value(tag=SESSION_END_INFO_TAG, metadata=smd)])
  314. return exp, ssi, sei
  315. def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
  316. """Output a `Summary` protocol buffer containing a single scalar value.
  317. The generated Summary has a Tensor.proto containing the input Tensor.
  318. Args:
  319. name: A name for the generated node. Will also serve as the series name in
  320. TensorBoard.
  321. tensor: A real numeric Tensor containing a single value.
  322. collections: Optional list of graph collections keys. The new summary op is
  323. added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
  324. new_style: Whether to use new style (tensor field) or old style (simple_value
  325. field). New style could lead to faster data loading.
  326. Returns:
  327. A scalar `Tensor` of type `string`. Which contains a `Summary` protobuf.
  328. Raises:
  329. ValueError: If tensor has the wrong shape or type.
  330. """
  331. tensor = make_np(tensor).squeeze()
  332. if tensor.ndim != 0:
  333. raise AssertionError(f"Tensor should contain one element (0 dimensions). \
  334. Was given size: {tensor.size} and {tensor.ndim} dimensions.")
  335. # python float is double precision in numpy
  336. scalar = float(tensor)
  337. if new_style:
  338. tensor_proto = TensorProto(float_val=[scalar], dtype="DT_FLOAT")
  339. if double_precision:
  340. tensor_proto = TensorProto(double_val=[scalar], dtype="DT_DOUBLE")
  341. # pyrefly: ignore [missing-attribute]
  342. plugin_data = SummaryMetadata.PluginData(plugin_name="scalars")
  343. smd = SummaryMetadata(plugin_data=plugin_data)
  344. return Summary(
  345. value=[
  346. # pyrefly: ignore [missing-attribute]
  347. Summary.Value(
  348. tag=name,
  349. tensor=tensor_proto,
  350. metadata=smd,
  351. )
  352. ]
  353. )
  354. else:
  355. # pyrefly: ignore [missing-attribute]
  356. return Summary(value=[Summary.Value(tag=name, simple_value=scalar)])
  357. def tensor_proto(tag, tensor):
  358. """Outputs a `Summary` protocol buffer containing the full tensor.
  359. The generated Summary has a Tensor.proto containing the input Tensor.
  360. Args:
  361. tag: A name for the generated node. Will also serve as the series name in
  362. TensorBoard.
  363. tensor: Tensor to be converted to protobuf
  364. Returns:
  365. A tensor protobuf in a `Summary` protobuf.
  366. Raises:
  367. ValueError: If tensor is too big to be converted to protobuf, or
  368. tensor data type is not supported
  369. """
  370. if tensor.numel() * tensor.itemsize >= (1 << 31):
  371. raise ValueError(
  372. "tensor is bigger than protocol buffer's hard limit of 2GB in size"
  373. )
  374. if tensor.dtype in _TENSOR_TYPE_MAP:
  375. dtype, field_name, conversion_fn = _TENSOR_TYPE_MAP[tensor.dtype]
  376. tensor_proto = TensorProto(
  377. **{
  378. "dtype": dtype,
  379. "tensor_shape": TensorShapeProto(
  380. # pyrefly: ignore [missing-attribute]
  381. dim=[TensorShapeProto.Dim(size=x) for x in tensor.shape]
  382. ),
  383. field_name: conversion_fn(tensor),
  384. },
  385. )
  386. else:
  387. raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}")
  388. # pyrefly: ignore [missing-attribute]
  389. plugin_data = SummaryMetadata.PluginData(plugin_name="tensor")
  390. smd = SummaryMetadata(plugin_data=plugin_data)
  391. # pyrefly: ignore [missing-attribute]
  392. return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor_proto)])
  393. def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts):
  394. # pylint: disable=line-too-long
  395. """Output a `Summary` protocol buffer with a histogram.
  396. The generated
  397. [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
  398. has one summary value containing a histogram for `values`.
  399. Args:
  400. name: A name for the generated node. Will also serve as a series name in
  401. TensorBoard.
  402. min: A float or int min value
  403. max: A float or int max value
  404. num: Int number of values
  405. sum: Float or int sum of all values
  406. sum_squares: Float or int sum of squares for all values
  407. bucket_limits: A numeric `Tensor` with upper value per bucket
  408. bucket_counts: A numeric `Tensor` with number of values per bucket
  409. Returns:
  410. A scalar `Tensor` of type `string`. The serialized `Summary` protocol
  411. buffer.
  412. """
  413. hist = HistogramProto(
  414. min=min,
  415. max=max,
  416. num=num,
  417. sum=sum,
  418. sum_squares=sum_squares,
  419. bucket_limit=bucket_limits,
  420. bucket=bucket_counts,
  421. )
  422. # pyrefly: ignore [missing-attribute]
  423. return Summary(value=[Summary.Value(tag=name, histo=hist)])
  424. def histogram(name, values, bins, max_bins=None):
  425. # pylint: disable=line-too-long
  426. """Output a `Summary` protocol buffer with a histogram.
  427. The generated
  428. [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
  429. has one summary value containing a histogram for `values`.
  430. This op reports an `InvalidArgument` error if any value is not finite.
  431. Args:
  432. name: A name for the generated node. Will also serve as a series name in
  433. TensorBoard.
  434. values: A real numeric `Tensor`. Any shape. Values to use to
  435. build the histogram.
  436. Returns:
  437. A scalar `Tensor` of type `string`. The serialized `Summary` protocol
  438. buffer.
  439. """
  440. values = make_np(values)
  441. hist = make_histogram(values.astype(float), bins, max_bins)
  442. # pyrefly: ignore [missing-attribute]
  443. return Summary(value=[Summary.Value(tag=name, histo=hist)])
  444. def make_histogram(values, bins, max_bins=None):
  445. """Convert values into a histogram proto using logic from histogram.cc."""
  446. if values.size == 0:
  447. raise ValueError("The input has no element.")
  448. values = values.reshape(-1)
  449. counts, limits = np.histogram(values, bins=bins)
  450. num_bins = len(counts)
  451. if max_bins is not None and num_bins > max_bins:
  452. subsampling = num_bins // max_bins
  453. subsampling_remainder = num_bins % subsampling
  454. if subsampling_remainder != 0:
  455. # pyrefly: ignore [no-matching-overload]
  456. counts = np.pad(
  457. counts,
  458. pad_width=[[0, subsampling - subsampling_remainder]],
  459. mode="constant",
  460. constant_values=0,
  461. )
  462. counts = counts.reshape(-1, subsampling).sum(axis=-1)
  463. new_limits = np.empty((counts.size + 1,), limits.dtype)
  464. new_limits[:-1] = limits[:-1:subsampling]
  465. new_limits[-1] = limits[-1]
  466. limits = new_limits
  467. # Find the first and the last bin defining the support of the histogram:
  468. cum_counts = np.cumsum(np.greater(counts, 0))
  469. start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
  470. start = int(start)
  471. end = int(end) + 1
  472. del cum_counts
  473. # TensorBoard only includes the right bin limits. To still have the leftmost limit
  474. # included, we include an empty bin left.
  475. # If start == 0, we need to add an empty one left, otherwise we can just include the bin left to the
  476. # first nonzero-count bin:
  477. counts = (
  478. counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
  479. )
  480. limits = limits[start : end + 1]
  481. if counts.size == 0 or limits.size == 0:
  482. raise ValueError("The histogram is empty, please file a bug report.")
  483. sum_sq = values.dot(values)
  484. return HistogramProto(
  485. min=values.min(),
  486. max=values.max(),
  487. num=len(values),
  488. sum=values.sum(),
  489. sum_squares=sum_sq,
  490. bucket_limit=limits.tolist(),
  491. bucket=counts.tolist(),
  492. )
  493. def image(tag, tensor, rescale=1, dataformats="NCHW"):
  494. """Output a `Summary` protocol buffer with images.
  495. The summary has up to `max_images` summary values containing images. The
  496. images are built from `tensor` which must be 3-D with shape `[height, width,
  497. channels]` and where `channels` can be:
  498. * 1: `tensor` is interpreted as Grayscale.
  499. * 3: `tensor` is interpreted as RGB.
  500. * 4: `tensor` is interpreted as RGBA.
  501. The `name` in the outputted Summary.Value protobufs is generated based on the
  502. name, with a suffix depending on the max_outputs setting:
  503. * If `max_outputs` is 1, the summary value tag is '*name*/image'.
  504. * If `max_outputs` is greater than 1, the summary value tags are
  505. generated sequentially as '*name*/image/0', '*name*/image/1', etc.
  506. Args:
  507. tag: A name for the generated node. Will also serve as a series name in
  508. TensorBoard.
  509. tensor: A 3-D `uint8` or `float32` `Tensor` of shape `[height, width,
  510. channels]` where `channels` is 1, 3, or 4.
  511. 'tensor' can either have values in [0, 1] (float32) or [0, 255] (uint8).
  512. The image() function will scale the image values to [0, 255] by applying
  513. a scale factor of either 1 (uint8) or 255 (float32). Out-of-range values
  514. will be clipped.
  515. Returns:
  516. A scalar `Tensor` of type `string`. The serialized `Summary` protocol
  517. buffer.
  518. """
  519. tensor = make_np(tensor)
  520. tensor = convert_to_HWC(tensor, dataformats)
  521. # Do not assume that user passes in values in [0, 255], use data type to detect
  522. scale_factor = _calc_scale_factor(tensor)
  523. tensor = tensor.astype(np.float32)
  524. tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
  525. image = make_image(tensor, rescale=rescale)
  526. # pyrefly: ignore [missing-attribute]
  527. return Summary(value=[Summary.Value(tag=tag, image=image)])
  528. def image_boxes(
  529. tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None
  530. ):
  531. """Output a `Summary` protocol buffer with images."""
  532. tensor_image = make_np(tensor_image)
  533. tensor_image = convert_to_HWC(tensor_image, dataformats)
  534. tensor_boxes = make_np(tensor_boxes)
  535. tensor_image = tensor_image.astype(np.float32) * _calc_scale_factor(tensor_image)
  536. image = make_image(
  537. tensor_image.clip(0, 255).astype(np.uint8),
  538. rescale=rescale,
  539. rois=tensor_boxes,
  540. labels=labels,
  541. )
  542. # pyrefly: ignore [missing-attribute]
  543. return Summary(value=[Summary.Value(tag=tag, image=image)])
  544. def draw_boxes(disp_image, boxes, labels=None):
  545. # xyxy format
  546. num_boxes = boxes.shape[0]
  547. list_gt = range(num_boxes)
  548. for i in list_gt:
  549. disp_image = _draw_single_box(
  550. disp_image,
  551. boxes[i, 0],
  552. boxes[i, 1],
  553. boxes[i, 2],
  554. boxes[i, 3],
  555. display_str=None if labels is None else labels[i],
  556. color="Red",
  557. )
  558. return disp_image
  559. def make_image(tensor, rescale=1, rois=None, labels=None):
  560. """Convert a numpy representation of an image to Image protobuf."""
  561. from PIL import Image
  562. height, width, channel = tensor.shape
  563. scaled_height = int(height * rescale)
  564. scaled_width = int(width * rescale)
  565. image = Image.fromarray(tensor)
  566. if rois is not None:
  567. image = draw_boxes(image, rois, labels=labels)
  568. ANTIALIAS = Image.Resampling.LANCZOS
  569. image = image.resize((scaled_width, scaled_height), ANTIALIAS)
  570. import io
  571. output = io.BytesIO()
  572. image.save(output, format="PNG")
  573. image_string = output.getvalue()
  574. output.close()
  575. # pyrefly: ignore [missing-attribute]
  576. return Summary.Image(
  577. height=height,
  578. width=width,
  579. colorspace=channel,
  580. encoded_image_string=image_string,
  581. )
  582. def video(tag, tensor, fps=4):
  583. tensor = make_np(tensor)
  584. tensor = _prepare_video(tensor)
  585. # If user passes in uint8, then we don't need to rescale by 255
  586. scale_factor = _calc_scale_factor(tensor)
  587. tensor = tensor.astype(np.float32)
  588. tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
  589. video = make_video(tensor, fps)
  590. # pyrefly: ignore [missing-attribute]
  591. return Summary(value=[Summary.Value(tag=tag, image=video)])
  592. def make_video(tensor, fps):
  593. try:
  594. import moviepy # noqa: F401
  595. except ImportError:
  596. print("add_video needs package moviepy")
  597. return
  598. try:
  599. from moviepy import editor as mpy
  600. except ImportError:
  601. print(
  602. "moviepy is installed, but can't import moviepy.editor.",
  603. "Some packages could be missing [imageio, requests]",
  604. )
  605. return
  606. import tempfile
  607. _t, h, w, c = tensor.shape
  608. # encode sequence of images into gif string
  609. clip = mpy.ImageSequenceClip(list(tensor), fps=fps)
  610. with tempfile.NamedTemporaryFile(suffix=".gif") as f:
  611. filename = f.name
  612. try: # newer version of moviepy use logger instead of progress_bar argument.
  613. clip.write_gif(filename, verbose=False, logger=None)
  614. except TypeError:
  615. try: # older version of moviepy does not support progress_bar argument.
  616. clip.write_gif(filename, verbose=False, progress_bar=False)
  617. except TypeError:
  618. clip.write_gif(filename, verbose=False)
  619. f.seek(0)
  620. tensor_string = f.read()
  621. # pyrefly: ignore [missing-attribute]
  622. return Summary.Image(
  623. height=h, width=w, colorspace=c, encoded_image_string=tensor_string
  624. )
  625. def audio(tag, tensor, sample_rate=44100):
  626. array = make_np(tensor)
  627. array = array.squeeze()
  628. if abs(array).max() > 1:
  629. print("warning: audio amplitude out of range, auto clipped.")
  630. array = array.clip(-1, 1)
  631. if array.ndim != 1:
  632. raise AssertionError("input tensor should be 1 dimensional.")
  633. # pyrefly: ignore [no-matching-overload]
  634. array = (array * np.iinfo(np.int16).max).astype("<i2")
  635. import io
  636. import wave
  637. fio = io.BytesIO()
  638. with wave.open(fio, "wb") as wave_write:
  639. wave_write.setnchannels(1)
  640. wave_write.setsampwidth(2)
  641. wave_write.setframerate(sample_rate)
  642. wave_write.writeframes(array.data)
  643. audio_string = fio.getvalue()
  644. fio.close()
  645. # pyrefly: ignore [missing-attribute]
  646. audio = Summary.Audio(
  647. sample_rate=sample_rate,
  648. num_channels=1,
  649. length_frames=array.shape[-1],
  650. encoded_audio_string=audio_string,
  651. content_type="audio/wav",
  652. )
  653. # pyrefly: ignore [missing-attribute]
  654. return Summary(value=[Summary.Value(tag=tag, audio=audio)])
  655. def custom_scalars(layout):
  656. categories = []
  657. for k, v in layout.items():
  658. charts = []
  659. for chart_name, chart_metadata in v.items():
  660. tags = chart_metadata[1]
  661. if chart_metadata[0] == "Margin":
  662. if len(tags) != 3:
  663. raise AssertionError("len(tags) != 3")
  664. mgcc = layout_pb2.MarginChartContent(
  665. series=[
  666. # pyrefly: ignore [missing-attribute]
  667. layout_pb2.MarginChartContent.Series(
  668. value=tags[0], lower=tags[1], upper=tags[2]
  669. )
  670. ]
  671. )
  672. chart = layout_pb2.Chart(title=chart_name, margin=mgcc)
  673. else:
  674. mlcc = layout_pb2.MultilineChartContent(tag=tags)
  675. chart = layout_pb2.Chart(title=chart_name, multiline=mlcc)
  676. charts.append(chart)
  677. categories.append(layout_pb2.Category(title=k, chart=charts))
  678. layout = layout_pb2.Layout(category=categories)
  679. # pyrefly: ignore [missing-attribute]
  680. plugin_data = SummaryMetadata.PluginData(plugin_name="custom_scalars")
  681. smd = SummaryMetadata(plugin_data=plugin_data)
  682. tensor = TensorProto(
  683. dtype="DT_STRING",
  684. string_val=[layout.SerializeToString()],
  685. tensor_shape=TensorShapeProto(),
  686. )
  687. return Summary(
  688. value=[
  689. # pyrefly: ignore [missing-attribute]
  690. Summary.Value(tag="custom_scalars__config__", tensor=tensor, metadata=smd)
  691. ]
  692. )
  693. def text(tag, text):
  694. # pyrefly: ignore [missing-attribute]
  695. plugin_data = SummaryMetadata.PluginData(
  696. plugin_name="text", content=TextPluginData(version=0).SerializeToString()
  697. )
  698. smd = SummaryMetadata(plugin_data=plugin_data)
  699. tensor = TensorProto(
  700. dtype="DT_STRING",
  701. string_val=[text.encode(encoding="utf_8")],
  702. # pyrefly: ignore [missing-attribute]
  703. tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
  704. )
  705. return Summary(
  706. # pyrefly: ignore [missing-attribute]
  707. value=[Summary.Value(tag=tag + "/text_summary", metadata=smd, tensor=tensor)]
  708. )
  709. def pr_curve_raw(
  710. tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None
  711. ):
  712. if num_thresholds > 127: # weird, value > 127 breaks protobuf
  713. num_thresholds = 127
  714. data = np.stack((tp, fp, tn, fn, precision, recall))
  715. pr_curve_plugin_data = PrCurvePluginData(
  716. version=0, num_thresholds=num_thresholds
  717. ).SerializeToString()
  718. # pyrefly: ignore [missing-attribute]
  719. plugin_data = SummaryMetadata.PluginData(
  720. plugin_name="pr_curves", content=pr_curve_plugin_data
  721. )
  722. smd = SummaryMetadata(plugin_data=plugin_data)
  723. tensor = TensorProto(
  724. dtype="DT_FLOAT",
  725. float_val=data.reshape(-1).tolist(),
  726. tensor_shape=TensorShapeProto(
  727. dim=[
  728. # pyrefly: ignore [missing-attribute]
  729. TensorShapeProto.Dim(size=data.shape[0]),
  730. # pyrefly: ignore [missing-attribute]
  731. TensorShapeProto.Dim(size=data.shape[1]),
  732. ]
  733. ),
  734. )
  735. # pyrefly: ignore [missing-attribute]
  736. return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
  737. def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
  738. # weird, value > 127 breaks protobuf
  739. num_thresholds = min(num_thresholds, 127)
  740. data = compute_curve(
  741. labels, predictions, num_thresholds=num_thresholds, weights=weights
  742. )
  743. pr_curve_plugin_data = PrCurvePluginData(
  744. version=0, num_thresholds=num_thresholds
  745. ).SerializeToString()
  746. # pyrefly: ignore [missing-attribute]
  747. plugin_data = SummaryMetadata.PluginData(
  748. plugin_name="pr_curves", content=pr_curve_plugin_data
  749. )
  750. smd = SummaryMetadata(plugin_data=plugin_data)
  751. tensor = TensorProto(
  752. dtype="DT_FLOAT",
  753. float_val=data.reshape(-1).tolist(),
  754. tensor_shape=TensorShapeProto(
  755. dim=[
  756. # pyrefly: ignore [missing-attribute]
  757. TensorShapeProto.Dim(size=data.shape[0]),
  758. # pyrefly: ignore [missing-attribute]
  759. TensorShapeProto.Dim(size=data.shape[1]),
  760. ]
  761. ),
  762. )
  763. # pyrefly: ignore [missing-attribute]
  764. return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
  765. # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
  766. def compute_curve(labels, predictions, num_thresholds=None, weights=None):
  767. _MINIMUM_COUNT = 1e-7
  768. if weights is None:
  769. weights = 1.0
  770. # Compute bins of true positives and false positives.
  771. # pyrefly: ignore [unsupported-operation]
  772. bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
  773. float_labels = labels.astype(np.float64)
  774. # pyrefly: ignore [unsupported-operation]
  775. histogram_range = (0, num_thresholds - 1)
  776. tp_buckets, _ = np.histogram(
  777. bucket_indices,
  778. # pyrefly: ignore [bad-argument-type]
  779. bins=num_thresholds,
  780. range=histogram_range,
  781. weights=float_labels * weights,
  782. )
  783. fp_buckets, _ = np.histogram(
  784. bucket_indices,
  785. # pyrefly: ignore [bad-argument-type]
  786. bins=num_thresholds,
  787. range=histogram_range,
  788. weights=(1.0 - float_labels) * weights,
  789. )
  790. # Obtain the reverse cumulative sum.
  791. tp = np.cumsum(tp_buckets[::-1])[::-1]
  792. fp = np.cumsum(fp_buckets[::-1])[::-1]
  793. tn = fp[0] - fp
  794. fn = tp[0] - tp
  795. precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
  796. recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
  797. return np.stack((tp, fp, tn, fn, precision, recall))
  798. def _get_tensor_summary(
  799. name, display_name, description, tensor, content_type, components, json_config
  800. ):
  801. """Create a tensor summary with summary metadata.
  802. Args:
  803. name: Uniquely identifiable name of the summary op. Could be replaced by
  804. combination of name and type to make it unique even outside of this
  805. summary.
  806. display_name: Will be used as the display name in TensorBoard.
  807. Defaults to `name`.
  808. description: A longform readable description of the summary data. Markdown
  809. is supported.
  810. tensor: Tensor to display in summary.
  811. content_type: Type of content inside the Tensor.
  812. components: Bitmask representing present parts (vertices, colors, etc.) that
  813. belong to the summary.
  814. json_config: A string, JSON-serialized dictionary of ThreeJS classes
  815. configuration.
  816. Returns:
  817. Tensor summary with metadata.
  818. """
  819. import torch
  820. from tensorboard.plugins.mesh import metadata
  821. tensor = torch.as_tensor(tensor)
  822. tensor_metadata = metadata.create_summary_metadata(
  823. name,
  824. display_name,
  825. content_type,
  826. components,
  827. tensor.shape,
  828. description,
  829. json_config=json_config,
  830. )
  831. tensor = TensorProto(
  832. dtype="DT_FLOAT",
  833. float_val=tensor.reshape(-1).tolist(),
  834. tensor_shape=TensorShapeProto(
  835. dim=[
  836. # pyrefly: ignore [missing-attribute]
  837. TensorShapeProto.Dim(size=tensor.shape[0]),
  838. # pyrefly: ignore [missing-attribute]
  839. TensorShapeProto.Dim(size=tensor.shape[1]),
  840. # pyrefly: ignore [missing-attribute]
  841. TensorShapeProto.Dim(size=tensor.shape[2]),
  842. ]
  843. ),
  844. )
  845. # pyrefly: ignore [missing-attribute]
  846. tensor_summary = Summary.Value(
  847. tag=metadata.get_instance_name(name, content_type),
  848. tensor=tensor,
  849. metadata=tensor_metadata,
  850. )
  851. return tensor_summary
  852. def _get_json_config(config_dict):
  853. """Parse and returns JSON string from python dictionary."""
  854. json_config = "{}"
  855. if config_dict is not None:
  856. json_config = json.dumps(config_dict, sort_keys=True)
  857. return json_config
  858. # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py
  859. def mesh(
  860. tag, vertices, colors, faces, config_dict, display_name=None, description=None
  861. ):
  862. """Output a merged `Summary` protocol buffer with a mesh/point cloud.
  863. Args:
  864. tag: A name for this summary operation.
  865. vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D
  866. coordinates of vertices.
  867. faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of
  868. vertices within each triangle.
  869. colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each
  870. vertex.
  871. display_name: If set, will be used as the display name in TensorBoard.
  872. Defaults to `name`.
  873. description: A longform readable description of the summary data. Markdown
  874. is supported.
  875. config_dict: Dictionary with ThreeJS classes names and configuration.
  876. Returns:
  877. Merged summary for mesh/point cloud representation.
  878. """
  879. from tensorboard.plugins.mesh import metadata
  880. from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData
  881. json_config = _get_json_config(config_dict)
  882. summaries = []
  883. tensors = [
  884. # pyrefly: ignore [missing-attribute]
  885. (vertices, MeshPluginData.VERTEX),
  886. # pyrefly: ignore [missing-attribute]
  887. (faces, MeshPluginData.FACE),
  888. # pyrefly: ignore [missing-attribute]
  889. (colors, MeshPluginData.COLOR),
  890. ]
  891. tensors = [tensor for tensor in tensors if tensor[0] is not None]
  892. components = metadata.get_components_bitmask(
  893. [content_type for (tensor, content_type) in tensors]
  894. )
  895. for tensor, content_type in tensors:
  896. summaries.append(
  897. _get_tensor_summary(
  898. tag,
  899. display_name,
  900. description,
  901. tensor,
  902. content_type,
  903. components,
  904. json_config,
  905. )
  906. )
  907. return Summary(value=summaries)