line_series.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. from __future__ import annotations
  2. from collections.abc import Iterable
  3. from typing import TYPE_CHECKING, Any
  4. import wandb
  5. from wandb.plot.custom_chart import plot_table
  6. if TYPE_CHECKING:
  7. from wandb.plot.custom_chart import CustomChart
  8. def line_series(
  9. xs: Iterable[Iterable[Any]] | Iterable[Any],
  10. ys: Iterable[Iterable[Any]],
  11. keys: Iterable[str] | None = None,
  12. title: str = "",
  13. xname: str = "x",
  14. split_table: bool = False,
  15. ) -> CustomChart:
  16. """Constructs a line series chart.
  17. Args:
  18. xs: Sequence of x values. If a singular
  19. array is provided, all y values are plotted against that x array. If
  20. an array of arrays is provided, each y value is plotted against the
  21. corresponding x array.
  22. ys: Sequence of y values, where each iterable represents
  23. a separate line series.
  24. keys: Sequence of keys for labeling each line series. If
  25. not provided, keys will be automatically generated as "line_1",
  26. "line_2", etc.
  27. title: Title of the chart.
  28. xname: Label for the x-axis.
  29. split_table: Whether the table should be split into a separate section
  30. in the W&B UI. If `True`, the table will be displayed in a section named
  31. "Custom Chart Tables". Default is `False`.
  32. Returns:
  33. CustomChart: A custom chart object that can be logged to W&B. To log the
  34. chart, pass it to `wandb.log()`.
  35. Examples:
  36. Logging a single x array where all y series are plotted against the same x values:
  37. ```python
  38. import wandb
  39. # Initialize W&B run
  40. with wandb.init(project="line_series_example") as run:
  41. # x values shared across all y series
  42. xs = list(range(10))
  43. # Multiple y series to plot
  44. ys = [
  45. [i for i in range(10)], # y = x
  46. [i**2 for i in range(10)], # y = x^2
  47. [i**3 for i in range(10)], # y = x^3
  48. ]
  49. # Generate and log the line series chart
  50. line_series_chart = wandb.plot.line_series(
  51. xs,
  52. ys,
  53. title="title",
  54. xname="step",
  55. )
  56. run.log({"line-series-single-x": line_series_chart})
  57. ```
  58. In this example, a single `xs` series (shared x-values) is used for all
  59. `ys` series. This results in each y-series being plotted against the
  60. same x-values (0-9).
  61. Logging multiple x arrays where each y series is plotted against its corresponding x array:
  62. ```python
  63. import wandb
  64. # Initialize W&B run
  65. with wandb.init(project="line_series_example") as run:
  66. # Separate x values for each y series
  67. xs = [
  68. [i for i in range(10)], # x for first series
  69. [2 * i for i in range(10)], # x for second series (stretched)
  70. [3 * i for i in range(10)], # x for third series (stretched more)
  71. ]
  72. # Corresponding y series
  73. ys = [
  74. [i for i in range(10)], # y = x
  75. [i**2 for i in range(10)], # y = x^2
  76. [i**3 for i in range(10)], # y = x^3
  77. ]
  78. # Generate and log the line series chart
  79. line_series_chart = wandb.plot.line_series(
  80. xs, ys, title="Multiple X Arrays Example", xname="Step"
  81. )
  82. run.log({"line-series-multiple-x": line_series_chart})
  83. ```
  84. In this example, each y series is plotted against its own unique x series.
  85. This allows for more flexibility when the x values are not uniform across
  86. the data series.
  87. Customizing line labels using `keys`:
  88. ```python
  89. import wandb
  90. # Initialize W&B run
  91. with wandb.init(project="line_series_example") as run:
  92. xs = list(range(10)) # Single x array
  93. ys = [
  94. [i for i in range(10)], # y = x
  95. [i**2 for i in range(10)], # y = x^2
  96. [i**3 for i in range(10)], # y = x^3
  97. ]
  98. # Custom labels for each line
  99. keys = ["Linear", "Quadratic", "Cubic"]
  100. # Generate and log the line series chart
  101. line_series_chart = wandb.plot.line_series(
  102. xs,
  103. ys,
  104. keys=keys, # Custom keys (line labels)
  105. title="Custom Line Labels Example",
  106. xname="Step",
  107. )
  108. run.log({"line-series-custom-keys": line_series_chart})
  109. ```
  110. This example shows how to provide custom labels for the lines using
  111. the `keys` argument. The keys will appear in the legend as "Linear",
  112. "Quadratic", and "Cubic".
  113. """
  114. # If xs is a single array, repeat it for each y in ys
  115. if not isinstance(xs[0], Iterable) or isinstance(xs[0], (str, bytes)):
  116. xs = [xs] * len(ys)
  117. if len(xs) != len(ys):
  118. msg = f"Number of x-series ({len(xs)}) must match y-series ({len(ys)})."
  119. raise ValueError(msg)
  120. if keys is None:
  121. keys = [f"line_{i}" for i in range(len(ys))]
  122. if len(keys) != len(ys):
  123. msg = f"Number of keys ({len(keys)}) must match y-series ({len(ys)})."
  124. raise ValueError(msg)
  125. data = [
  126. [x, keys[i], y]
  127. for i, (xx, yy) in enumerate(zip(xs, ys))
  128. for x, y in zip(xx, yy)
  129. ]
  130. table = wandb.Table(
  131. data=data,
  132. columns=["step", "lineKey", "lineVal"],
  133. )
  134. return plot_table(
  135. data_table=table,
  136. vega_spec_name="wandb/lineseries/v0",
  137. fields={
  138. "step": "step",
  139. "lineKey": "lineKey",
  140. "lineVal": "lineVal",
  141. },
  142. string_fields={"title": title, "xname": xname},
  143. split_table=split_table,
  144. )