dspy.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. """DSPy ↔ Weights & Biases integration."""
  2. from __future__ import annotations
  3. import logging
  4. import os
  5. from collections.abc import Mapping, Sequence
  6. from typing import Any, Literal
  7. import wandb
  8. import wandb.util
  9. from wandb.sdk.lib import telemetry
  10. from wandb.sdk.wandb_run import Run
  11. dspy = wandb.util.get_module(
  12. name="dspy",
  13. required=(
  14. "To use the W&B DSPy integration you need to have the `dspy` "
  15. "python package installed. Install it with `uv pip install dspy`."
  16. ),
  17. lazy=False,
  18. )
  19. if dspy is not None:
  20. assert dspy.__version__ >= "3.0.0", (
  21. "DSPy 3.0.0 or higher is required. You have " + dspy.__version__
  22. )
  23. logger = logging.getLogger(__name__)
  24. def _flatten_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
  25. """Flatten a list of nested row dicts into flat key/value dicts.
  26. Args:
  27. rows (list[dict[str, Any]]): List of nested dictionaries to flatten.
  28. Returns:
  29. list[dict[str, Any]]: List of flattened dictionaries.
  30. """
  31. def _flatten(
  32. d: dict[str, Any], parent_key: str = "", sep: str = "."
  33. ) -> dict[str, Any]:
  34. items = []
  35. for k, v in d.items():
  36. new_key = f"{parent_key}{sep}{k}" if parent_key else k
  37. if isinstance(v, dict):
  38. items.extend(_flatten(v, new_key, sep=sep).items())
  39. else:
  40. items.append((new_key, v))
  41. return dict(items)
  42. return [_flatten(row) for row in rows]
  43. class WandbDSPyCallback(dspy.utils.BaseCallback):
  44. """W&B callback for tracking DSPy evaluation and optimization.
  45. This callback logs evaluation scores, per-step predictions (optional), and
  46. a table capturing the DSPy program signature over time. It can also save
  47. the best program as a W&B Artifact for reproducibility.
  48. Examples:
  49. Basic usage within DSPy settings:
  50. ```python
  51. import dspy
  52. import wandb
  53. from wandb.integration.dspy import WandbDSPyCallback
  54. with wandb.init(project="dspy-optimization") as run:
  55. dspy.settings.callbacks.append(WandbDSPyCallback(run=run))
  56. # Run your DSPy optimization/evaluation
  57. ```
  58. """
  59. def __init__(self, log_results: bool = True, run: Run | None = None) -> None:
  60. """Initialize the callback.
  61. Args:
  62. log_results (bool): Whether to log per-evaluation prediction tables.
  63. run (Run | None): Optional W&B run to use. Defaults to the
  64. current global run if available.
  65. Raises:
  66. wandb.Error: If no active run is provided or found.
  67. """
  68. # If no run is provided, use the current global run if available.
  69. if run is None:
  70. if wandb.run is None:
  71. raise wandb.Error(
  72. "You must call `wandb.init()` before instantiating WandbDSPyCallback()."
  73. )
  74. run = wandb.run
  75. self.log_results = log_results
  76. with telemetry.context(run=run) as tel:
  77. tel.feature.dspy_callback = True
  78. self._run = run
  79. self._did_log_config: bool = False
  80. self._program_info: dict[str, Any] = {}
  81. self._program_table: wandb.Table | None = None
  82. self._row_idx: int = 0
  83. def _flatten_dict(
  84. self, nested: Any, parent_key: str = "", sep: str = "."
  85. ) -> dict[str, Any]:
  86. """Recursively flatten arbitrarily nested mappings and sequences.
  87. Args:
  88. nested (Any): Nested structure of mappings/lists to flatten.
  89. parent_key (str): Prefix to prepend to keys in the flattened output.
  90. sep (str): Key separator for nested fields.
  91. Returns:
  92. dict[str, Any]: Flattened dictionary representation.
  93. """
  94. flat: dict[str, Any] = {}
  95. def _walk(obj: Any, base: str) -> None:
  96. if isinstance(obj, Mapping):
  97. for k, v in obj.items():
  98. new_key = f"{base}{sep}{k}" if base else str(k)
  99. _walk(v, new_key)
  100. elif isinstance(obj, Sequence) and not isinstance(
  101. obj, (str, bytes, bytearray)
  102. ):
  103. for idx, v in enumerate(obj):
  104. new_key = f"{base}{sep}{idx}" if base else str(idx)
  105. _walk(v, new_key)
  106. else:
  107. # Base can be empty only if the top-level is a scalar; guard against that.
  108. key = base if base else ""
  109. if key:
  110. flat[key] = obj
  111. _walk(nested, parent_key)
  112. return flat
  113. def _extract_fields(self, fields: list[dict[str, Any]]) -> dict[str, str]:
  114. """Convert signature fields to a flat mapping of strings.
  115. Note:
  116. The input is expected to be a dict-like mapping from field names to
  117. field metadata. Values are stringified for logging.
  118. Args:
  119. fields (list[dict[str, Any]]): Mapping of field name to metadata.
  120. Returns:
  121. dict[str, str]: Mapping of field name to string value.
  122. """
  123. return {k: str(v) for k, v in fields.items()}
  124. def _extract_program_info(self, program_obj: Any) -> dict[str, Any]:
  125. """Extract signature-related info from a DSPy program.
  126. Attempts to read the program signature, instructions, input and output
  127. fields from a DSPy `Predict` parameter if available.
  128. Args:
  129. program_obj (Any): DSPy program/module instance.
  130. Returns:
  131. dict[str, Any]: Flattened dictionary of signature metadata.
  132. """
  133. info_dict = {}
  134. if program_obj is None:
  135. return info_dict
  136. try:
  137. sig = next(
  138. param.signature
  139. for _, param in program_obj.named_parameters()
  140. if isinstance(param, dspy.Predict)
  141. )
  142. if getattr(sig, "signature", None):
  143. info_dict["signature"] = sig.signature
  144. if getattr(sig, "instructions", None):
  145. info_dict["instructions"] = sig.instructions
  146. if getattr(sig, "input_fields", None):
  147. input_fields = sig.input_fields
  148. info_dict["input_fields"] = self._extract_fields(input_fields)
  149. if getattr(sig, "output_fields", None):
  150. output_fields = sig.output_fields
  151. info_dict["output_fields"] = self._extract_fields(output_fields)
  152. return self._flatten_dict(info_dict)
  153. except Exception as e:
  154. logger.warning(
  155. "Failed to extract program info from Evaluate instance: %s", e
  156. )
  157. return info_dict
  158. def on_evaluate_start(
  159. self,
  160. call_id: str,
  161. instance: Any,
  162. inputs: dict[str, Any],
  163. ) -> None:
  164. """Handle start of a DSPy evaluation call.
  165. Logs non-private fields from the evaluator instance to W&B config and
  166. captures program signature info for later logging.
  167. Args:
  168. call_id (str): Unique identifier for the evaluation call.
  169. instance (Any): The evaluation instance (e.g., `dspy.Evaluate`).
  170. inputs (dict[str, Any]): Inputs passed to the evaluation (may
  171. include a `program` key with the DSPy program).
  172. """
  173. if not self._did_log_config:
  174. instance_vars = vars(instance) if hasattr(instance, "__dict__") else {}
  175. serializable = {
  176. k: v for k, v in instance_vars.items() if not k.startswith("_")
  177. }
  178. if "devset" in serializable:
  179. # we don't want to log the devset in the config
  180. del serializable["devset"]
  181. self._run.config.update(serializable)
  182. self._did_log_config = True
  183. # 2) Build/append program signature tables from the 'program' inputs
  184. if program_obj := inputs.get("program"):
  185. self._program_info = self._extract_program_info(program_obj)
  186. def on_evaluate_end(
  187. self,
  188. call_id: str,
  189. outputs: Any | None,
  190. exception: Exception | None = None,
  191. ) -> None:
  192. """Handle end of a DSPy evaluation call.
  193. If available, logs a numeric `score` metric and (optionally) per-step
  194. prediction tables. Always appends a row to the program-signature table.
  195. Args:
  196. call_id (str): Unique identifier for the evaluation call.
  197. outputs (Any | None): Evaluation outputs; supports
  198. `dspy.evaluate.evaluate.EvaluationResult`.
  199. exception (Exception | None): Exception raised during evaluation, if any.
  200. """
  201. # The `BaseCallback` does not define the interface for the `outputs` parameter,
  202. # Currently, we know of `EvaluationResult` which is a subclass of `dspy.Prediction`.
  203. # We currently support this type and will warn the user if a different type is passed.
  204. score: float | None = None
  205. if exception is None:
  206. if isinstance(outputs, dspy.evaluate.evaluate.EvaluationResult):
  207. # log the float score as a wandb metric
  208. score = outputs.score
  209. wandb.log({"score": float(score)}, step=self._row_idx)
  210. # Log the predictions as a separate table for each eval end.
  211. # We know that results if of type `list[tuple["dspy.Example", "dspy.Example", Any]]`
  212. results = outputs.results
  213. if self.log_results:
  214. rows = self._parse_results(results)
  215. if rows:
  216. self._log_predictions_table(rows)
  217. else:
  218. wandb.termwarn(
  219. f"on_evaluate_end received unexpected outputs type: {type(outputs)}. "
  220. "Expected dspy.evaluate.evaluate.EvaluationResult; skipping logging score and `log_results`."
  221. )
  222. else:
  223. wandb.termwarn(
  224. f"on_evaluate_end received exception: {exception}. "
  225. "Skipping logging score and `log_results`."
  226. )
  227. # Log the program signature iteratively
  228. if self._program_table is None:
  229. columns = ["step", *self._program_info.keys()]
  230. if isinstance(score, float):
  231. columns.append("score")
  232. self._program_table = wandb.Table(columns=columns, log_mode="INCREMENTAL")
  233. if self._program_table is not None:
  234. values = list(self._program_info.values())
  235. if isinstance(score, float):
  236. values.append(score)
  237. self._program_table.add_data(
  238. self._row_idx,
  239. *values,
  240. )
  241. self._run.log(
  242. {"program_signature": self._program_table}, step=self._row_idx
  243. )
  244. self._row_idx += 1
  245. def _parse_results(
  246. self,
  247. results: list[tuple[dspy.Example, dspy.Prediction | dspy.Completions, bool]],
  248. ) -> list[dict[str, Any]]:
  249. """Normalize evaluation results into serializable row dicts.
  250. Args:
  251. results (list[tuple]): Sequence of `(example, prediction, is_correct)`
  252. tuples from DSPy evaluation.
  253. Returns:
  254. list[dict[str, Any]]: Rows with `example`, `prediction`, `is_correct`.
  255. """
  256. _rows: list[dict[str, Any]] = []
  257. for example, prediction, is_correct in results:
  258. if isinstance(prediction, dspy.Prediction):
  259. prediction_dict = prediction.toDict()
  260. if isinstance(prediction, dspy.Completions):
  261. prediction_dict = prediction.items()
  262. row: dict[str, Any] = {
  263. "example": example.toDict(),
  264. "prediction": prediction_dict,
  265. "is_correct": is_correct,
  266. }
  267. _rows.append(row)
  268. return _rows
  269. def _log_predictions_table(self, rows: list[dict[str, Any]]) -> None:
  270. """Log a W&B Table of predictions for the current evaluation step.
  271. Args:
  272. rows (list[dict[str, Any]]): Prediction rows to log.
  273. """
  274. rows = _flatten_rows(rows)
  275. columns = list(rows[0].keys())
  276. data: list[list[Any]] = [list(row.values()) for row in rows]
  277. preds_table = wandb.Table(columns=columns, data=data, log_mode="IMMUTABLE")
  278. self._run.log({f"predictions_{self._row_idx}": preds_table}, step=self._row_idx)
  279. def log_best_model(
  280. self,
  281. model: dspy.Module,
  282. *,
  283. save_program: bool = True,
  284. save_dir: str | None = None,
  285. filetype: Literal["json", "pkl"] = "json",
  286. aliases: Sequence[str] = ("best", "latest"),
  287. artifact_name: str = "dspy-program",
  288. ) -> None:
  289. """Save and log the best DSPy program as a W&B Artifact.
  290. You can choose to save the full program (architecture + state) or only
  291. the state to a single file (JSON or pickle).
  292. Args:
  293. model (dspy.Module): DSPy module to save.
  294. save_program (bool): Save full program directory if True; otherwise
  295. save only the state file. Defaults to `True`.
  296. save_dir (str): Directory to store program files before logging. Defaults to a
  297. subdirectory `dspy_program` within the active run's files directory
  298. (i.e., `wandb.run.dir`).
  299. filetype (Literal["json", "pkl"]): State file format when
  300. `save_program` is False. Defaults to `json`.
  301. aliases (Sequence[str]): Aliases for the logged Artifact version. Defaults to `("best", "latest")`.
  302. artifact_name (str): Base name for the Artifact. Defaults to `dspy-program`.
  303. Examples:
  304. Save the complete program and add aliases:
  305. ```python
  306. callback.log_best_model(
  307. optimized_program, save_program=True, aliases=("best", "production")
  308. )
  309. ```
  310. Save only the state as JSON:
  311. ```python
  312. callback.log_best_model(
  313. optimized_program, save_program=False, filetype="json"
  314. )
  315. ```
  316. """
  317. # Derive metadata to help discoverability in the UI
  318. info_dict = self._extract_program_info(model)
  319. metadata = {
  320. "dspy_version": getattr(dspy, "__version__", "unknown"),
  321. "module_class": model.__class__.__name__,
  322. **info_dict,
  323. }
  324. artifact = wandb.Artifact(
  325. name=f"{artifact_name}-{self._run.id}",
  326. type="model",
  327. metadata=metadata,
  328. )
  329. # Resolve and normalize the save directory in a cross-platform way
  330. if save_dir is None:
  331. save_dir = os.path.join(self._run.dir, "dspy_program")
  332. save_dir = os.path.normpath(save_dir)
  333. try:
  334. os.makedirs(save_dir, exist_ok=True)
  335. except Exception as exc:
  336. wandb.termwarn(
  337. f"Could not create or access directory '{save_dir}': {exc}. Skipping artifact logging."
  338. )
  339. return
  340. # Save per requested mode
  341. if save_program:
  342. model.save(save_dir, save_program=True)
  343. artifact.add_dir(save_dir)
  344. else:
  345. filename = f"program.{filetype}"
  346. file_path = os.path.join(save_dir, filename)
  347. model.save(file_path, save_program=False)
  348. artifact.add_file(file_path)
  349. self._run.log_artifact(artifact, aliases=list(aliases))