custom_chart.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import Any
  4. import wandb
  5. @dataclass
  6. class CustomChartSpec:
  7. spec_name: str
  8. fields: dict[str, Any]
  9. string_fields: dict[str, Any]
  10. key: str = ""
  11. panel_type: str = "Vega2"
  12. split_table: bool = False
  13. @property
  14. def table_key(self) -> str:
  15. if not self.key:
  16. raise wandb.Error("Key for the custom chart spec is not set.")
  17. if self.split_table:
  18. return f"Custom Chart Tables/{self.key}_table"
  19. return f"{self.key}_table"
  20. @property
  21. def config_value(self) -> dict[str, Any]:
  22. return {
  23. "panel_type": self.panel_type,
  24. "panel_config": {
  25. "panelDefId": self.spec_name,
  26. "fieldSettings": self.fields,
  27. "stringSettings": self.string_fields,
  28. "transform": {"name": "tableWithLeafColNames"},
  29. "userQuery": {
  30. "queryFields": [
  31. {
  32. "name": "runSets",
  33. "args": [{"name": "runSets", "value": "${runSets}"}],
  34. "fields": [
  35. {"name": "id", "fields": []},
  36. {"name": "name", "fields": []},
  37. {"name": "_defaultColorIndex", "fields": []},
  38. {
  39. "name": "summaryTable",
  40. "args": [
  41. {
  42. "name": "tableKey",
  43. "value": self.table_key,
  44. }
  45. ],
  46. "fields": [],
  47. },
  48. ],
  49. }
  50. ],
  51. },
  52. },
  53. }
  54. @property
  55. def config_key(self) -> tuple[str, str, str]:
  56. return ("_wandb", "visualize", self.key)
  57. @dataclass
  58. class CustomChart:
  59. table: wandb.Table
  60. spec: CustomChartSpec
  61. def set_key(self, key: str):
  62. """Sets the key for the spec and updates dependent configurations."""
  63. self.spec.key = key
  64. def plot_table(
  65. vega_spec_name: str,
  66. data_table: wandb.Table,
  67. fields: dict[str, Any],
  68. string_fields: dict[str, Any] | None = None,
  69. split_table: bool = False,
  70. ) -> CustomChart:
  71. """Creates a custom charts using a Vega-Lite specification and a `wandb.Table`.
  72. This function creates a custom chart based on a Vega-Lite specification and
  73. a data table represented by a `wandb.Table` object. The specification needs
  74. to be predefined and stored in the W&B backend. The function returns a custom
  75. chart object that can be logged to W&B using `wandb.Run.log()`.
  76. Args:
  77. vega_spec_name: The name or identifier of the Vega-Lite spec
  78. that defines the visualization structure.
  79. data_table: A `wandb.Table` object containing the data to be
  80. visualized.
  81. fields: A mapping between the fields in the Vega-Lite spec and the
  82. corresponding columns in the data table to be visualized.
  83. string_fields: A dictionary for providing values for any string constants
  84. required by the custom visualization.
  85. split_table: Whether the table should be split into a separate section
  86. in the W&B UI. If `True`, the table will be displayed in a section named
  87. "Custom Chart Tables". Default is `False`.
  88. Returns:
  89. CustomChart: A custom chart object that can be logged to W&B. To log the
  90. chart, pass the chart object as argument to `wandb.Run.log()`.
  91. Raises:
  92. wandb.Error: If `data_table` is not a `wandb.Table` object.
  93. Example:
  94. ```python
  95. # Create a custom chart using a Vega-Lite spec and the data table.
  96. import wandb
  97. data = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
  98. table = wandb.Table(data=data, columns=["x", "y"])
  99. fields = {"x": "x", "y": "y", "title": "MY TITLE"}
  100. with wandb.init() as run:
  101. # Training code goes here
  102. # Create a custom title with `string_fields`.
  103. my_custom_chart = wandb.plot_table(
  104. vega_spec_name="wandb/line/v0",
  105. data_table=table,
  106. fields=fields,
  107. string_fields={"title": "Title"},
  108. )
  109. run.log({"custom_chart": my_custom_chart})
  110. ```
  111. """
  112. if not isinstance(data_table, wandb.Table):
  113. raise wandb.Error(
  114. f"Expected `data_table` to be `wandb.Table` type, instead got {type(data_table).__name__}"
  115. )
  116. return CustomChart(
  117. table=data_table,
  118. spec=CustomChartSpec(
  119. spec_name=vega_spec_name,
  120. fields=fields,
  121. string_fields=string_fields or {},
  122. split_table=split_table,
  123. ),
  124. )