| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033 |
- # mypy: allow-untyped-defs
- import json
- import logging
- import struct
- from typing import Any
- import torch
- import numpy as np
- from google.protobuf import struct_pb2
- from tensorboard.compat.proto.summary_pb2 import (
- HistogramProto,
- Summary,
- SummaryMetadata,
- )
- from tensorboard.compat.proto.tensor_pb2 import TensorProto
- from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
- from tensorboard.plugins.custom_scalar import layout_pb2
- from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData
- from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
- from ._convert_np import make_np
- from ._utils import _prepare_video, convert_to_HWC
# Public API of this module (order kept stable for readers and tooling).
__all__ = [
    # float <-> int bit reinterpretation used for half-precision tensors
    "half_to_int", "int_to_half",
    # hyperparameters and scalars
    "hparams", "scalar",
    # histograms
    "histogram_raw", "histogram", "make_histogram",
    # images and boxes
    "image", "image_boxes", "draw_boxes", "make_image",
    # video and audio
    "video", "make_video", "audio",
    # layouts, text and raw tensors
    "custom_scalars", "text", "tensor_proto",
    # precision/recall curves
    "pr_curve_raw", "pr_curve", "compute_curve",
    # 3D meshes / point clouds
    "mesh",
]

# Module-level logger (e.g. used by hparams to warn about bad arguments).
logger = logging.getLogger(__name__)
def half_to_int(f: float) -> int:
    """Reinterpret the bits of a float as a native signed 32-bit integer.

    Intended for values read from half-precision tensors (``torch.half`` /
    ``torch.bfloat16``) so they can be stored in the ``half_val`` field of a
    ``TensorProto``. The value is packed with struct format ``"f"`` (single
    precision) and those four bytes are read back with format ``"i"``.

    NOTE(review): the stored pattern is the float32 bit layout, not a 16-bit
    one; it is decoded by ``int_to_half``, which reverses this conversion.
    """
    (as_int,) = struct.unpack("i", struct.pack("f", f))
    return as_int
def int_to_half(i: int) -> float:
    """Reinterpret a native signed 32-bit integer as a single-precision float.

    This is the inverse of ``half_to_int``: the integer is packed with struct
    format ``"i"`` and the same four bytes are read back with format ``"f"``.
    """
    (as_float,) = struct.unpack("f", struct.pack("i", i))
    return as_float
def _tensor_to_half_val(t: torch.Tensor) -> list[int]:
    """Encode every element of ``t`` (flattened, row-major) with ``half_to_int``."""
    flat_values = t.flatten().tolist()
    return list(map(half_to_int, flat_values))
def _tensor_to_complex_val(t: torch.Tensor) -> list[float]:
    """Flatten a complex tensor into interleaved ``[re0, im0, re1, im1, ...]``."""
    # view_as_real adds a trailing dimension of size 2 holding (real, imag).
    real_pairs = torch.view_as_real(t)
    return real_pairs.flatten().tolist()
def _tensor_to_list(t: torch.Tensor) -> list[Any]:
    """Return the elements of ``t`` as a flat Python list in row-major order."""
    flattened = t.flatten()
    return flattened.tolist()
# type maps: torch.Tensor dtype -> (TensorProto dtype name, TensorProto value
# field, function converting the tensor into values for that field).
#
# Every dtype must appear exactly once: in a dict literal a repeated key is
# silently overwritten by the later entry (``torch.uint8`` used to be listed
# twice, and the second entry replaced the first). TensorFlow's tensor.proto
# stores DT_UINT8 values in ``int_val`` (``uint32_val`` is for DT_UINT32), so
# all DT_UINT8 entries use ``int_val``.
# NOTE(review): ``torch.qint8`` is signed but is mapped to DT_UINT8 - confirm.
_TENSOR_TYPE_MAP = {
    torch.half: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val),
    torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.double: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.int8: ("DT_INT8", "int_val", _tensor_to_list),
    torch.uint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.qint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.int16: ("DT_INT16", "int_val", _tensor_to_list),
    torch.short: ("DT_INT16", "int_val", _tensor_to_list),
    torch.int: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.qint32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int64: ("DT_INT64", "int64_val", _tensor_to_list),
    torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.chalf: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.cfloat: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list),
    torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    torch.cdouble: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    torch.quint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.quint4x2: ("DT_UINT8", "int_val", _tensor_to_list),
}
def _calc_scale_factor(tensor) -> int:
    """Return the multiplier that maps ``tensor`` values onto the 0-255 range.

    ``uint8`` data is taken to be in [0, 255] already (factor 1); any other
    dtype is treated as [0, 1] data (factor 255). Inputs that are not numpy
    arrays are converted with ``.numpy()`` before the dtype is inspected.
    """
    array = tensor if isinstance(tensor, np.ndarray) else tensor.numpy()
    if array.dtype == np.uint8:
        return 1
    return 255
def _draw_single_box(
    image,
    xmin,
    ymin,
    xmax,
    ymax,
    display_str,
    color="black",
    color_text="black",
    thickness=2,
):
    """Draw one box outline, plus an optional caption, onto a PIL image.

    The outline is a closed polyline through the four corners of
    ``(xmin, ymin, xmax, ymax)``. When ``display_str`` is truthy, a filled
    rectangle in ``color`` is drawn with its bottom edge on ``ymax``, and the
    caption is written on it in ``color_text`` with PIL's default font.
    Drawing happens through ``ImageDraw.Draw(image)``; the image is returned.
    """
    from PIL import ImageDraw, ImageFont

    font = ImageFont.load_default()
    draw = ImageDraw.Draw(image)
    corners = [
        (xmin, ymin),
        (xmin, ymax),
        (xmax, ymax),
        (xmax, ymin),
        (xmin, ymin),
    ]
    draw.line(corners, width=thickness, fill=color)

    if not display_str:
        return image

    # Size the caption background from the rendered text's bounding box.
    bbox_left, bbox_top, bbox_right, bbox_bottom = font.getbbox(display_str)
    label_width = bbox_right - bbox_left
    label_height = bbox_bottom - bbox_top
    pad = np.ceil(0.05 * label_height)
    draw.rectangle(
        [
            (xmin, ymax - label_height - 2 * pad),
            (xmin + label_width, ymax),
        ],
        fill=color,
    )
    draw.text(
        (xmin + pad, ymax - label_height - pad),
        display_str,
        fill=color_text,
        font=font,
    )
    return image
def _hparam_domain_discrete(values, field):
    """Build the ``domain_discrete`` list for one hyperparameter.

    Args:
      values: The list of allowed values, or ``None`` when no discrete domain
        was given for this hyperparameter.
      field: Name of the ``struct_pb2.Value`` field each entry is stored in,
        e.g. ``"number_value"``.

    Returns:
      A ``struct_pb2.ListValue`` with one ``Value`` per entry, or ``None``.
    """
    if values is None:
        return None
    return struct_pb2.ListValue(
        values=[struct_pb2.Value(**{field: d}) for d in values]
    )


def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
    """Output three `Summary` protocol buffers needed by hparams plugin.

    `Experiment` keeps the metadata of an experiment, such as the name of the
    hyperparameters and the name of the metrics.
    `SessionStartInfo` keeps key-value pairs of the hyperparameters
    `SessionEndInfo` describes status of the experiment e.g. STATUS_SUCCESS

    Args:
      hparam_dict: A dictionary that contains names of the hyperparameters
        and their values. Supported value types are ``bool``, ``int``,
        ``float``, ``str`` and ``torch.Tensor`` (its first element is used).
        Entries whose value is ``None`` are skipped.
      metric_dict: A dictionary that contains names of the metrics
        and their values. Only the keys (metric tags) are used here.
      hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
        contains names of the hyperparameters and all discrete values they can
        hold. Each key must be in ``hparam_dict`` and each list must contain
        values of the same type as that hyperparameter.

    Returns:
      The `Summary` protobufs for Experiment, SessionStartInfo and
      SessionEndInfo, in that order.

    Raises:
      TypeError: If ``hparam_dict``, ``metric_dict`` or
        ``hparam_domain_discrete`` is not a dict, or a discrete domain is
        invalid.
      ValueError: If a hyperparameter value has an unsupported type.
    """
    from tensorboard.plugins.hparams.api_pb2 import (
        DataType,
        Experiment,
        HParamInfo,
        MetricInfo,
        MetricName,
        Status,
    )
    from tensorboard.plugins.hparams.metadata import (
        EXPERIMENT_TAG,
        PLUGIN_DATA_VERSION,
        PLUGIN_NAME,
        SESSION_END_INFO_TAG,
        SESSION_START_INFO_TAG,
    )
    from tensorboard.plugins.hparams.plugin_data_pb2 import (
        HParamsPluginData,
        SessionEndInfo,
        SessionStartInfo,
    )

    # TODO: expose other parameters in the future.
    # hp = HParamInfo(name='lr',display_name='learning rate',
    # type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10,
    # max_value=100))
    # mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy',
    # description='', dataset_type=DatasetType.DATASET_VALIDATION)
    # exp = Experiment(name='123', description='456', time_created_secs=100.0,
    # hparam_infos=[hp], metric_infos=[mt], user='tw')

    if not isinstance(hparam_dict, dict):
        logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: hparam_dict should be a dictionary, nothing logged."
        )
    if not isinstance(metric_dict, dict):
        logger.warning("parameter: metric_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: metric_dict should be a dictionary, nothing logged."
        )

    hparam_domain_discrete = hparam_domain_discrete or {}
    if not isinstance(hparam_domain_discrete, dict):
        raise TypeError(
            "parameter: hparam_domain_discrete should be a dictionary, nothing logged."
        )
    for k, v in hparam_domain_discrete.items():
        if (
            k not in hparam_dict
            or not isinstance(v, list)
            or not all(isinstance(d, type(hparam_dict[k])) for d in v)
        ):
            raise TypeError(
                f"parameter: hparam_domain_discrete[{k}] should be a list of same type as hparam_dict[{k}]."
            )

    hps = []
    ssi = SessionStartInfo()
    for k, v in hparam_dict.items():
        if v is None:
            continue

        # ``bool`` must be tested before ``int``: ``isinstance(True, int)`` is
        # True, so checking ``(int, float)`` first would record booleans as
        # FLOAT64 numbers and leave this branch unreachable.
        if isinstance(v, bool):
            ssi.hparams[k].bool_value = v
            hps.append(
                HParamInfo(
                    name=k,
                    # pyrefly: ignore [missing-attribute]
                    type=DataType.Value("DATA_TYPE_BOOL"),
                    domain_discrete=_hparam_domain_discrete(
                        hparam_domain_discrete.get(k), "bool_value"
                    ),
                )
            )
            continue

        if isinstance(v, (int, float)):
            ssi.hparams[k].number_value = v
            hps.append(
                HParamInfo(
                    name=k,
                    # pyrefly: ignore [missing-attribute]
                    type=DataType.Value("DATA_TYPE_FLOAT64"),
                    domain_discrete=_hparam_domain_discrete(
                        hparam_domain_discrete.get(k), "number_value"
                    ),
                )
            )
            continue

        if isinstance(v, str):
            ssi.hparams[k].string_value = v
            hps.append(
                HParamInfo(
                    name=k,
                    # pyrefly: ignore [missing-attribute]
                    type=DataType.Value("DATA_TYPE_STRING"),
                    domain_discrete=_hparam_domain_discrete(
                        hparam_domain_discrete.get(k), "string_value"
                    ),
                )
            )
            continue

        if isinstance(v, torch.Tensor):
            # reshape(-1) makes 0-d tensors indexable too (plain ``[0]`` on a
            # 0-d array raises IndexError); the first element is the value.
            ssi.hparams[k].number_value = float(make_np(v).reshape(-1)[0])
            # pyrefly: ignore [missing-attribute]
            hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
            continue

        raise ValueError(
            "value should be one of int, float, str, bool, or torch.Tensor"
        )

    def _plugin_summary(tag, plugin_content):
        # Wrap one HParamsPluginData message in a Summary value under ``tag``.
        smd = SummaryMetadata(
            # pyrefly: ignore [missing-attribute]
            plugin_data=SummaryMetadata.PluginData(
                plugin_name=PLUGIN_NAME, content=plugin_content.SerializeToString()
            )
        )
        # pyrefly: ignore [missing-attribute]
        return Summary(value=[Summary.Value(tag=tag, metadata=smd)])

    ssi_summary = _plugin_summary(
        SESSION_START_INFO_TAG,
        HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION),
    )

    mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict]
    experiment = Experiment(hparam_infos=hps, metric_infos=mts)
    exp_summary = _plugin_summary(
        EXPERIMENT_TAG,
        HParamsPluginData(experiment=experiment, version=PLUGIN_DATA_VERSION),
    )

    # pyrefly: ignore [missing-attribute]
    sei = SessionEndInfo(status=Status.Value("STATUS_SUCCESS"))
    sei_summary = _plugin_summary(
        SESSION_END_INFO_TAG,
        HParamsPluginData(session_end_info=sei, version=PLUGIN_DATA_VERSION),
    )

    return exp_summary, ssi_summary, sei_summary
def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
    """Output a `Summary` protocol buffer containing a single scalar value.

    Args:
      name: A name for the generated node. Will also serve as the series name in
        TensorBoard.
      tensor: A real numeric value holding exactly one element (anything
        accepted by ``make_np``, e.g. a numpy array or torch.Tensor).
        Singleton dimensions are squeezed away.
      collections: Not used by this function; kept for API compatibility.
      new_style: If True, store the value in the ``tensor`` field with
        ``scalars`` plugin metadata instead of the legacy ``simple_value``
        field. New style could lead to faster data loading.
      double_precision: Only honored when ``new_style`` is True: store the
        value as ``DT_DOUBLE`` instead of ``DT_FLOAT``.

    Returns:
      A `Summary` protobuf with one value tagged ``name``.

    Raises:
      AssertionError: If the squeezed input does not have 0 dimensions.
    """
    tensor = make_np(tensor).squeeze()
    if tensor.ndim != 0:
        # Built from adjacent literals so no source indentation leaks into
        # the message (a backslash continuation inside the string did that).
        raise AssertionError(
            "Tensor should contain one element (0 dimensions). "
            f"Was given size: {tensor.size} and {tensor.ndim} dimensions."
        )
    # python float is double precision in numpy
    value = float(tensor)

    if not new_style:
        # pyrefly: ignore [missing-attribute]
        return Summary(value=[Summary.Value(tag=name, simple_value=value)])

    if double_precision:
        tensor_proto = TensorProto(double_val=[value], dtype="DT_DOUBLE")
    else:
        tensor_proto = TensorProto(float_val=[value], dtype="DT_FLOAT")
    # pyrefly: ignore [missing-attribute]
    plugin_data = SummaryMetadata.PluginData(plugin_name="scalars")
    smd = SummaryMetadata(plugin_data=plugin_data)
    return Summary(
        value=[
            # pyrefly: ignore [missing-attribute]
            Summary.Value(
                tag=name,
                tensor=tensor_proto,
                metadata=smd,
            )
        ]
    )
def tensor_proto(tag, tensor):
    """Wrap a complete ``torch.Tensor`` in a `Summary` protocol buffer.

    The tensor's dtype selects, via ``_TENSOR_TYPE_MAP``, the TensorProto
    dtype, the value field to fill and the conversion applied to the data.
    The shape is recorded dimension by dimension, and the summary carries
    ``tensor`` plugin metadata.

    Args:
      tag: A name for the generated node. Will also serve as the series name in
        TensorBoard.
      tensor: Tensor to be converted to protobuf.

    Returns:
      A `Summary` protobuf whose single value holds the tensor.

    Raises:
      ValueError: If ``numel * itemsize`` is 2**31 bytes or more (the
        protocol buffer size limit), or the dtype is not in the type map.
    """
    byte_size = tensor.numel() * tensor.itemsize
    if byte_size >= (1 << 31):
        raise ValueError(
            "tensor is bigger than protocol buffer's hard limit of 2GB in size"
        )
    if tensor.dtype not in _TENSOR_TYPE_MAP:
        raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}")

    proto_dtype, value_field, convert = _TENSOR_TYPE_MAP[tensor.dtype]
    shape = TensorShapeProto(
        # pyrefly: ignore [missing-attribute]
        dim=[TensorShapeProto.Dim(size=d) for d in tensor.shape]
    )
    proto_kwargs = {
        "dtype": proto_dtype,
        "tensor_shape": shape,
        value_field: convert(tensor),
    }
    proto = TensorProto(**proto_kwargs)

    # pyrefly: ignore [missing-attribute]
    metadata = SummaryMetadata(plugin_data=SummaryMetadata.PluginData(plugin_name="tensor"))
    # pyrefly: ignore [missing-attribute]
    return Summary(value=[Summary.Value(tag=tag, metadata=metadata, tensor=proto)])
def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a precomputed histogram.

    No statistics are computed here; the arguments are copied into a
    ``HistogramProto`` field by field.

    Args:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      min: A float or int min value
      max: A float or int max value
      num: Int number of values
      sum: Float or int sum of all values
      sum_squares: Float or int sum of squares for all values
      bucket_limits: A numeric `Tensor` with upper value per bucket
      bucket_counts: A numeric `Tensor` with number of values per bucket

    Returns:
      A `Summary` protobuf with one value tagged ``name`` whose ``histo``
      field holds the histogram.
    """
    histo_fields = dict(
        min=min,
        max=max,
        num=num,
        sum=sum,
        sum_squares=sum_squares,
        bucket_limit=bucket_limits,
        bucket=bucket_counts,
    )
    # pyrefly: ignore [missing-attribute]
    value = Summary.Value(tag=name, histo=HistogramProto(**histo_fields))
    return Summary(value=[value])
def histogram(name, values, bins, max_bins=None):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a histogram.

    The generated
    [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
    has one summary value containing a histogram for `values`.

    Args:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      values: A real numeric array or tensor of any shape. It is converted
        with ``make_np``, cast to float and flattened before binning.
      bins: Forwarded to ``numpy.histogram`` as its ``bins`` argument (an
        int, a sequence of bin edges, or a string binning strategy).
      max_bins: Optional upper bound on the number of buckets; adjacent
        buckets are merged when ``numpy.histogram`` produces more.

    Returns:
      A `Summary` protobuf whose single value holds the ``HistogramProto``.

    Raises:
      ValueError: If ``values`` is empty, or propagated from
        ``numpy.histogram`` (e.g. non-finite values with an auto-detected
        bin range).
    """
    values = make_np(values)
    hist = make_histogram(values.astype(float), bins, max_bins)
    # pyrefly: ignore [missing-attribute]
    return Summary(value=[Summary.Value(tag=name, histo=hist)])
def make_histogram(values, bins, max_bins=None):
    """Convert values into a histogram proto using logic from histogram.cc.

    Args:
      values: A non-empty numpy array of any shape; it is flattened.
      bins: Forwarded to ``numpy.histogram`` as ``bins``.
      max_bins: Optional maximum number of buckets. When ``numpy.histogram``
        returns more, groups of adjacent buckets are summed so that at most
        ``max_bins`` buckets remain.

    Returns:
      A ``HistogramProto``. Leading and trailing empty buckets are trimmed,
      except that one empty bucket is kept (or added) on the left: TensorBoard
      only stores each bucket's right edge, so this preserves the leftmost edge.

    Raises:
      ValueError: If ``values`` is empty or the trimmed histogram is empty.
    """
    if values.size == 0:
        raise ValueError("The input has no element.")
    values = values.reshape(-1)
    counts, limits = np.histogram(values, bins=bins)
    num_bins = len(counts)
    if max_bins is not None and num_bins > max_bins:
        # Merge groups of ``subsampling`` adjacent bins. The group size is
        # rounded UP so that ceil(num_bins / subsampling) <= max_bins; floor
        # division (e.g. 5 bins, max_bins=3 -> groups of 1) could leave more
        # than ``max_bins`` buckets.
        subsampling = -(-num_bins // max_bins)
        subsampling_remainder = num_bins % subsampling
        if subsampling_remainder != 0:
            # Zero-pad so the counts reshape evenly into groups.
            # pyrefly: ignore [no-matching-overload]
            counts = np.pad(
                counts,
                pad_width=[[0, subsampling - subsampling_remainder]],
                mode="constant",
                constant_values=0,
            )
        counts = counts.reshape(-1, subsampling).sum(axis=-1)
        # Keep the left edge of every group plus the overall right edge.
        new_limits = np.empty((counts.size + 1,), limits.dtype)
        new_limits[:-1] = limits[:-1:subsampling]
        new_limits[-1] = limits[-1]
        limits = new_limits

    # Find the first and the last bin defining the support of the histogram:
    cum_counts = np.cumsum(np.greater(counts, 0))
    start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
    start = int(start)
    end = int(end) + 1
    del cum_counts

    # TensorBoard only includes the right bin limits. To still have the leftmost limit
    # included, we include an empty bin left.
    # If start == 0, we need to add an empty one left, otherwise we can just include the bin left to the
    # first nonzero-count bin:
    counts = (
        counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
    )
    limits = limits[start : end + 1]

    if counts.size == 0 or limits.size == 0:
        raise ValueError("The histogram is empty, please file a bug report.")

    sum_sq = values.dot(values)
    return HistogramProto(
        min=values.min(),
        max=values.max(),
        num=len(values),
        sum=values.sum(),
        sum_squares=sum_sq,
        bucket_limit=limits.tolist(),
        bucket=counts.tolist(),
    )
def image(tag, tensor, rescale=1, dataformats="NCHW"):
    """Output a `Summary` protocol buffer with an image.

    The input is converted with ``make_np`` and rearranged into
    ``[height, width, channels]`` layout by ``convert_to_HWC`` according to
    ``dataformats``. The dtype decides the scaling: ``uint8`` data is used
    as-is (factor 1), any other dtype is multiplied by 255, so float input
    is expected in [0, 1]. Values are then clipped to [0, 255], cast to
    ``uint8`` and PNG-encoded by ``make_image``.

    Args:
      tag: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      tensor: Image data accepted by ``make_np`` (e.g. a torch.Tensor or
        numpy array), laid out as described by ``dataformats``.
      rescale: Factor applied to the image height and width before encoding.
      dataformats: Layout string of ``tensor`` understood by
        ``convert_to_HWC`` (default ``"NCHW"``).

    Returns:
      A `Summary` protobuf with one value tagged ``tag`` holding the image.
    """
    tensor = make_np(tensor)
    tensor = convert_to_HWC(tensor, dataformats)
    # Do not assume that user passes in values in [0, 255], use data type to detect
    scale_factor = _calc_scale_factor(tensor)
    tensor = tensor.astype(np.float32)
    tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
    image = make_image(tensor, rescale=rescale)
    # pyrefly: ignore [missing-attribute]
    return Summary(value=[Summary.Value(tag=tag, image=image)])
def image_boxes(
    tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None
):
    """Output a `Summary` protocol buffer with an image annotated by boxes.

    ``tensor_image`` is converted to HWC layout, multiplied by 1 (``uint8``
    input) or 255 (any other dtype), clipped to [0, 255] and passed to
    ``make_image`` together with ``tensor_boxes`` (drawn as regions of
    interest) and the optional ``labels``.
    """
    hwc = convert_to_HWC(make_np(tensor_image), dataformats)
    boxes = make_np(tensor_boxes)
    factor = _calc_scale_factor(hwc)
    pixels = (hwc.astype(np.float32) * factor).clip(0, 255).astype(np.uint8)
    encoded = make_image(pixels, rescale=rescale, rois=boxes, labels=labels)
    # pyrefly: ignore [missing-attribute]
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
def draw_boxes(disp_image, boxes, labels=None):
    """Draw every box in ``boxes`` onto ``disp_image`` in red.

    Args:
      disp_image: Image to draw on (a PIL image when called from make_image).
      boxes: Array with one row per box, read as
        ``(xmin, ymin, xmax, ymax)`` from columns 0-3 (xyxy format).
      labels: Optional sequence of captions; ``labels[i]`` is drawn for box i.

    Returns:
      The image returned by the last ``_draw_single_box`` call, or
      ``disp_image`` itself when ``boxes`` has no rows.
    """
    for idx in range(boxes.shape[0]):
        caption = labels[idx] if labels is not None else None
        disp_image = _draw_single_box(
            disp_image,
            boxes[idx, 0],
            boxes[idx, 1],
            boxes[idx, 2],
            boxes[idx, 3],
            display_str=caption,
            color="Red",
        )
    return disp_image
def make_image(tensor, rescale=1, rois=None, labels=None):
    """Convert a numpy representation of an image to an Image protobuf.

    Args:
      tensor: Array of shape ``[height, width, channels]`` that
        ``PIL.Image.fromarray`` can read (the callers in this module pass
        ``uint8`` data).
      rescale: Factor applied to height and width (truncated to int) before
        PNG encoding.
      rois: Optional boxes forwarded to ``draw_boxes``; drawn before resizing.
      labels: Optional captions for ``rois``.

    Returns:
      A ``Summary.Image`` holding the PNG bytes. Its ``height`` and ``width``
      describe the encoded (rescaled) image.
    """
    import io

    from PIL import Image

    height, width, channel = tensor.shape
    scaled_height = int(height * rescale)
    scaled_width = int(width * rescale)
    image = Image.fromarray(tensor)
    if rois is not None:
        image = draw_boxes(image, rois, labels=labels)
    ANTIALIAS = Image.Resampling.LANCZOS
    image = image.resize((scaled_width, scaled_height), ANTIALIAS)
    with io.BytesIO() as output:
        image.save(output, format="PNG")
        image_string = output.getvalue()
    # The proto describes the encoded PNG, which has the rescaled size;
    # reporting the pre-resize dimensions is wrong whenever rescale != 1.
    # pyrefly: ignore [missing-attribute]
    return Summary.Image(
        height=scaled_height,
        width=scaled_width,
        colorspace=channel,
        encoded_image_string=image_string,
    )
def video(tag, tensor, fps=4):
    """Output a `Summary` protocol buffer with a video encoded as a GIF.

    ``tensor`` is converted with ``make_np`` and laid out by
    ``_prepare_video``. The frames are multiplied by 1 (``uint8`` input) or
    255 (any other dtype), clipped to [0, 255] and encoded by ``make_video``
    at ``fps`` frames per second.
    """
    frames = _prepare_video(make_np(tensor))
    # uint8 input is already in [0, 255]; anything else is assumed in [0, 1].
    factor = _calc_scale_factor(frames)
    frames = (frames.astype(np.float32) * factor).clip(0, 255).astype(np.uint8)
    encoded = make_video(frames, fps)
    # pyrefly: ignore [missing-attribute]
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
def make_video(tensor, fps):
    """Encode a sequence of frames as an animated GIF ``Summary.Image``.

    Args:
      tensor: Frame array unpacked as ``(t, h, w, c) = tensor.shape``; each
        element of ``list(tensor)`` is one frame handed to moviepy.
      fps: Frames per second passed to ``moviepy``'s ``ImageSequenceClip``.

    Returns:
      A ``Summary.Image`` with the GIF bytes, or ``None`` when ``moviepy``
      (or ``moviepy.editor``) cannot be imported; a warning is logged then.
    """
    try:
        import moviepy  # noqa: F401
    except ImportError:
        logger.warning("add_video needs package moviepy")
        return None
    try:
        from moviepy import editor as mpy
    except ImportError:
        logger.warning(
            "moviepy is installed, but can't import moviepy.editor. "
            "Some packages could be missing [imageio, requests]"
        )
        return None

    import os
    import tempfile

    _t, h, w, c = tensor.shape

    # encode sequence of images into gif string
    clip = mpy.ImageSequenceClip(list(tensor), fps=fps)

    # Create the temp file and close our handle before moviepy writes to it by
    # name: a still-open NamedTemporaryFile cannot be reopened by name on
    # Windows. The file is removed explicitly once its bytes have been read.
    fd, filename = tempfile.mkstemp(suffix=".gif")
    os.close(fd)
    try:
        try:  # newer version of moviepy use logger instead of progress_bar argument.
            clip.write_gif(filename, verbose=False, logger=None)
        except TypeError:
            try:  # older version of moviepy does not support progress_bar argument.
                clip.write_gif(filename, verbose=False, progress_bar=False)
            except TypeError:
                clip.write_gif(filename, verbose=False)
        with open(filename, "rb") as f:
            tensor_string = f.read()
    finally:
        os.remove(filename)

    # pyrefly: ignore [missing-attribute]
    return Summary.Image(
        height=h, width=w, colorspace=c, encoded_image_string=tensor_string
    )
def audio(tag, tensor, sample_rate=44100):
    """Output a `Summary` protocol buffer with mono audio encoded as WAV.

    Args:
      tag: Summary tag.
      tensor: Audio samples accepted by ``make_np``; after ``squeeze()`` the
        array must be 1-D. Samples are expected in [-1, 1]; if any magnitude
        exceeds 1 a warning is logged and the samples are clipped.
      sample_rate: Sample rate written to the WAV header.

    Returns:
      A `Summary` whose value holds a ``Summary.Audio`` with 1 channel of
      16-bit little-endian PCM (``content_type="audio/wav"``).

    Raises:
      AssertionError: If the squeezed input is not 1-dimensional.
    """
    import io
    import wave

    array = make_np(tensor).squeeze()
    # Validate the rank first, and skip the amplitude check for empty input:
    # ``max()`` of a zero-size array raises ValueError.
    if array.ndim != 1:
        raise AssertionError("input tensor should be 1 dimensional.")
    if array.size and np.abs(array).max() > 1:
        logger.warning("audio amplitude out of range, auto clipped.")
        array = array.clip(-1, 1)

    # Scale [-1, 1] floats to the int16 range and store as little-endian.
    # pyrefly: ignore [no-matching-overload]
    pcm = (array * np.iinfo(np.int16).max).astype("<i2")

    with io.BytesIO() as fio:
        # wave does not close a file object it did not open, so ``fio`` is
        # still readable after the inner ``with`` finalizes the header.
        with wave.open(fio, "wb") as wave_write:
            wave_write.setnchannels(1)
            wave_write.setsampwidth(2)
            wave_write.setframerate(sample_rate)
            wave_write.writeframes(pcm.data)
        audio_string = fio.getvalue()

    # pyrefly: ignore [missing-attribute]
    audio = Summary.Audio(
        sample_rate=sample_rate,
        num_channels=1,
        length_frames=pcm.shape[-1],
        encoded_audio_string=audio_string,
        content_type="audio/wav",
    )
    # pyrefly: ignore [missing-attribute]
    return Summary(value=[Summary.Value(tag=tag, audio=audio)])
def custom_scalars(layout):
    """Build the `Summary` that configures TensorBoard's custom-scalars layout.

    Args:
      layout: Mapping ``category title -> {chart title -> chart_metadata}``
        where ``chart_metadata[0]`` is the chart type and
        ``chart_metadata[1]`` the tags. Type ``"Margin"`` needs exactly three
        tags ``(value, lower, upper)``; any other type yields a multiline
        chart over the tags.

    Returns:
      A `Summary` with the serialized ``layout_pb2.Layout`` stored as a
      ``DT_STRING`` tensor (empty shape) under ``custom_scalars__config__``.

    Raises:
      AssertionError: If a ``"Margin"`` chart does not have exactly 3 tags.
    """

    def _make_chart(title, chart_metadata):
        chart_kind, tags = chart_metadata[0], chart_metadata[1]
        if chart_kind != "Margin":
            multiline = layout_pb2.MultilineChartContent(tag=tags)
            return layout_pb2.Chart(title=title, multiline=multiline)
        if len(tags) != 3:
            raise AssertionError("len(tags) != 3")
        # pyrefly: ignore [missing-attribute]
        series = layout_pb2.MarginChartContent.Series(
            value=tags[0], lower=tags[1], upper=tags[2]
        )
        margin = layout_pb2.MarginChartContent(series=[series])
        return layout_pb2.Chart(title=title, margin=margin)

    categories = [
        layout_pb2.Category(
            title=category_title,
            chart=[_make_chart(name, meta) for name, meta in charts.items()],
        )
        for category_title, charts in layout.items()
    ]
    layout_proto = layout_pb2.Layout(category=categories)

    config_tensor = TensorProto(
        dtype="DT_STRING",
        string_val=[layout_proto.SerializeToString()],
        tensor_shape=TensorShapeProto(),
    )
    # pyrefly: ignore [missing-attribute]
    metadata = SummaryMetadata(plugin_data=SummaryMetadata.PluginData(plugin_name="custom_scalars"))
    # pyrefly: ignore [missing-attribute]
    config_value = Summary.Value(
        tag="custom_scalars__config__", tensor=config_tensor, metadata=metadata
    )
    return Summary(value=[config_value])
def text(tag, text):
    """Output a `Summary` protocol buffer containing a text string.

    The string is UTF-8 encoded and stored as a one-element ``DT_STRING``
    tensor with ``text`` plugin metadata, under the tag
    ``tag + "/text_summary"``.
    """
    encoded = text.encode(encoding="utf_8")
    # pyrefly: ignore [missing-attribute]
    shape = TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)])
    string_tensor = TensorProto(
        dtype="DT_STRING", string_val=[encoded], tensor_shape=shape
    )
    plugin_content = TextPluginData(version=0).SerializeToString()
    # pyrefly: ignore [missing-attribute]
    plugin_data = SummaryMetadata.PluginData(plugin_name="text", content=plugin_content)
    metadata = SummaryMetadata(plugin_data=plugin_data)
    # pyrefly: ignore [missing-attribute]
    summary_value = Summary.Value(
        tag=tag + "/text_summary", metadata=metadata, tensor=string_tensor
    )
    return Summary(value=[summary_value])
def pr_curve_raw(
    tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None
):
    """Output a PR-curve `Summary` from precomputed per-threshold statistics.

    ``tp, fp, tn, fn, precision, recall`` are stacked row-wise with
    ``np.stack`` into a 2-D ``DT_FLOAT`` tensor (presumably ``[6, N]`` for
    1-D inputs of equal length). ``num_thresholds`` is capped at 127 and
    stored in the ``pr_curves`` plugin data. ``weights`` is not used.
    """
    # weird, value > 127 breaks protobuf
    num_thresholds = min(num_thresholds, 127)
    stacked = np.stack((tp, fp, tn, fn, precision, recall))
    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds
    ).SerializeToString()
    # pyrefly: ignore [missing-attribute]
    plugin_data = SummaryMetadata.PluginData(plugin_name="pr_curves", content=plugin_content)
    metadata = SummaryMetadata(plugin_data=plugin_data)
    shape = TensorShapeProto(
        # pyrefly: ignore [missing-attribute]
        dim=[TensorShapeProto.Dim(size=stacked.shape[axis]) for axis in (0, 1)]
    )
    curve_tensor = TensorProto(
        dtype="DT_FLOAT",
        float_val=stacked.reshape(-1).tolist(),
        tensor_shape=shape,
    )
    # pyrefly: ignore [missing-attribute]
    return Summary(value=[Summary.Value(tag=tag, metadata=metadata, tensor=curve_tensor)])
def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
    """Output a PR-curve `Summary` computed from labels and predictions.

    Args:
      tag: Summary tag.
      labels: Ground-truth labels forwarded to ``compute_curve``.
      predictions: Predicted probabilities forwarded to ``compute_curve``.
      num_thresholds: Number of thresholds; values above 127 are capped.
      weights: Optional weights forwarded to ``compute_curve``.

    Returns:
      A `Summary` whose value holds the 2-D array from ``compute_curve`` as a
      ``DT_FLOAT`` tensor, with ``pr_curves`` plugin metadata.
    """
    # weird, value > 127 breaks protobuf
    num_thresholds = min(num_thresholds, 127)
    curve = compute_curve(
        labels, predictions, num_thresholds=num_thresholds, weights=weights
    )
    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds
    ).SerializeToString()
    # pyrefly: ignore [missing-attribute]
    plugin_data = SummaryMetadata.PluginData(plugin_name="pr_curves", content=plugin_content)
    metadata = SummaryMetadata(plugin_data=plugin_data)
    shape = TensorShapeProto(
        # pyrefly: ignore [missing-attribute]
        dim=[TensorShapeProto.Dim(size=curve.shape[axis]) for axis in (0, 1)]
    )
    curve_tensor = TensorProto(
        dtype="DT_FLOAT",
        float_val=curve.reshape(-1).tolist(),
        tensor_shape=shape,
    )
    # pyrefly: ignore [missing-attribute]
    return Summary(value=[Summary.Value(tag=tag, metadata=metadata, tensor=curve_tensor)])
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
def compute_curve(labels, predictions, num_thresholds=None, weights=None):
    """Compute precision/recall statistics for ``num_thresholds`` thresholds.

    Each prediction is mapped to bucket ``floor(p * (num_thresholds - 1))``.
    Weighted positive and negative counts per bucket are reverse-cumulated,
    so column ``i`` counts predictions falling in bucket ``i`` or higher.

    Args:
      labels: Numpy array of ground-truth labels, cast to float64; 1 counts
        as positive and ``1 - label`` as negative weight.
      predictions: Numpy array of predicted probabilities, same shape as
        ``labels``, expected in [0, 1] (buckets outside
        ``[0, num_thresholds - 1]`` are dropped by ``np.histogram``).
      num_thresholds: Number of thresholds. ``None`` (the default) uses 127,
        the same default and cap as ``pr_curve``.
      weights: Optional weights multiplied into the counts; defaults to 1.0.

    Returns:
      A float array of shape ``[6, num_thresholds]`` with rows
      ``tp, fp, tn, fn, precision, recall``.
    """
    _MINIMUM_COUNT = 1e-7

    if num_thresholds is None:
        # ``None - 1`` below would raise TypeError; use pr_curve's default.
        num_thresholds = 127
    if weights is None:
        weights = 1.0

    # Compute bins of true positives and false positives.
    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    float_labels = labels.astype(np.float64)
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
    )

    # Obtain the reverse cumulative sum.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    tn = fp[0] - fp
    fn = tp[0] - tp
    # _MINIMUM_COUNT avoids division by zero for empty thresholds.
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
    return np.stack((tp, fp, tn, fn, precision, recall))
def _get_tensor_summary(
    name, display_name, description, tensor, content_type, components, json_config
):
    """Create a tensor summary value with mesh-plugin summary metadata.

    Args:
      name: Uniquely identifiable name of the summary op. Could be replaced by
        combination of name and type to make it unique even outside of this
        summary.
      display_name: Will be used as the display name in TensorBoard.
        Defaults to `name`.
      description: A longform readable description of the summary data. Markdown
        is supported.
      tensor: Data convertible by ``torch.as_tensor``. Only the sizes of its
        first three dimensions are written to the shape, so it is assumed
        to be 3-D.
      content_type: Type of content inside the Tensor.
      components: Bitmask representing present parts (vertices, colors, etc.) that
        belong to the summary.
      json_config: A string, JSON-serialized dictionary of ThreeJS classes
        configuration.

    Returns:
      A ``Summary.Value`` holding the flattened data as a ``DT_FLOAT`` tensor.
    """
    from tensorboard.plugins.mesh import metadata

    data = torch.as_tensor(tensor)
    summary_metadata = metadata.create_summary_metadata(
        name,
        display_name,
        content_type,
        components,
        data.shape,
        description,
        json_config=json_config,
    )
    flat_values = data.reshape(-1).tolist()
    dims = [
        # pyrefly: ignore [missing-attribute]
        TensorShapeProto.Dim(size=data.shape[axis])
        for axis in range(3)
    ]
    proto = TensorProto(
        dtype="DT_FLOAT",
        float_val=flat_values,
        tensor_shape=TensorShapeProto(dim=dims),
    )
    # pyrefly: ignore [missing-attribute]
    return Summary.Value(
        tag=metadata.get_instance_name(name, content_type),
        tensor=proto,
        metadata=summary_metadata,
    )
def _get_json_config(config_dict):
    """Serialize ``config_dict`` to JSON with sorted keys; ``None`` gives ``"{}"``."""
    if config_dict is None:
        return "{}"
    return json.dumps(config_dict, sort_keys=True)
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py
def mesh(
    tag, vertices, colors, faces, config_dict, display_name=None, description=None
):
    """Output a merged `Summary` protocol buffer with a mesh/point cloud.

    Every part that is not ``None`` becomes one summary value, in the order
    vertices, faces, colors, and all values share one components bitmask.

    Args:
      tag: A name for this summary operation.
      vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D
        coordinates of vertices.
      colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each
        vertex.
      faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of
        vertices within each triangle.
      config_dict: Dictionary with ThreeJS classes names and configuration.
      display_name: If set, will be used as the display name in TensorBoard.
        Defaults to `name`.
      description: A longform readable description of the summary data. Markdown
        is supported.

    Returns:
      Merged summary for mesh/point cloud representation.
    """
    from tensorboard.plugins.mesh import metadata
    from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData

    json_config = _get_json_config(config_dict)
    candidates = (
        # pyrefly: ignore [missing-attribute]
        (vertices, MeshPluginData.VERTEX),
        # pyrefly: ignore [missing-attribute]
        (faces, MeshPluginData.FACE),
        # pyrefly: ignore [missing-attribute]
        (colors, MeshPluginData.COLOR),
    )
    present = [(data, kind) for data, kind in candidates if data is not None]
    components = metadata.get_components_bitmask([kind for _, kind in present])
    summaries = [
        _get_tensor_summary(
            tag,
            display_name,
            description,
            data,
            kind,
            components,
            json_config,
        )
        for data, kind in present
    ]
    return Summary(value=summaries)
|