"""DSPy ↔ Weights & Biases integration."""

from __future__ import annotations

import logging
import os
import re
from collections.abc import Mapping, Sequence
from typing import Any, Literal

import wandb
import wandb.util
from wandb.sdk.lib import telemetry
from wandb.sdk.wandb_run import Run

dspy = wandb.util.get_module(
    name="dspy",
    required=(
        "To use the W&B DSPy integration you need to have the `dspy` "
        "python package installed.  Install it with `uv pip install dspy`."
    ),
    lazy=False,
)

# Minimum supported DSPy release, as a numeric tuple so that comparisons are
# done component-wise rather than lexicographically on strings.
_MIN_DSPY_VERSION = (3, 0, 0)


def _meets_min_version(version: str, minimum: tuple[int, ...]) -> bool:
    """Return whether a dotted version string is at least `minimum`.

    Only the leading integer of each dot-separated component is considered,
    so suffixes such as `b1` or `.dev0` are ignored.

    Args:
        version (str): Version string such as `"3.0.3"` or `"10.1.0b2"`.
        minimum (tuple[int, ...]): Minimum release as a tuple of ints.

    Returns:
        bool: True if the parsed version is `>= minimum`, or if no numeric
            component could be parsed (the version cannot be judged, so it
            is not rejected).
    """
    parts: list[int] = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    if not parts:
        return True
    # Pad so that e.g. "3.0" compares equal to (3, 0, 0).
    parts.extend([0] * (len(minimum) - len(parts)))
    return tuple(parts) >= minimum


if dspy is not None:
    # A plain string comparison would reject e.g. "10.0.0" because "1" < "3".
    assert _meets_min_version(dspy.__version__, _MIN_DSPY_VERSION), (
        "DSPy 3.0.0 or higher is required. You have " + dspy.__version__
    )


logger = logging.getLogger(__name__)


def _flatten_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Flatten a list of nested row dicts into flat key/value dicts.

    Only nested `dict` values are expanded (keys joined with `.`); any other
    value, including lists, is kept as-is.

    Args:
        rows (list[dict[str, Any]]): List of nested dictionaries to flatten.

    Returns:
        list[dict[str, Any]]: List of flattened dictionaries.
    """

    def _flatten(
        d: dict[str, Any], parent_key: str = "", sep: str = "."
    ) -> dict[str, Any]:
        flat: dict[str, Any] = {}
        for k, v in d.items():
            new_key = f"{parent_key}{sep}{k}" if parent_key else k
            if isinstance(v, dict):
                flat.update(_flatten(v, new_key, sep=sep))
            else:
                flat[new_key] = v
        return flat

    return [_flatten(row) for row in rows]


class WandbDSPyCallback(dspy.utils.BaseCallback):
    """W&B callback for tracking DSPy evaluation and optimization.

    This callback logs evaluation scores, per-step predictions (optional), and
    a table capturing the DSPy program signature over time. It can also save
    the best program as a W&B Artifact for reproducibility.

    Examples:
        Basic usage within DSPy settings:

        ```python
        import dspy
        import wandb
        from wandb.integration.dspy import WandbDSPyCallback

        with wandb.init(project="dspy-optimization") as run:
            dspy.settings.callbacks.append(WandbDSPyCallback(run=run))
            # Run your DSPy optimization/evaluation
        ```
    """

    def __init__(self, log_results: bool = True, run: Run | None = None) -> None:
        """Initialize the callback.

        Args:
            log_results (bool): Whether to log per-evaluation prediction tables.
            run (Run | None): Optional W&B run to use. Defaults to the
                current global run if available.

        Raises:
            wandb.Error: If no active run is provided or found.
        """
        # If no run is provided, use the current global run if available.
        if run is None:
            if wandb.run is None:
                raise wandb.Error(
                    "You must call `wandb.init()` before instantiating WandbDSPyCallback()."
                )
            run = wandb.run

        self.log_results = log_results

        with telemetry.context(run=run) as tel:
            tel.feature.dspy_callback = True

        self._run = run
        # Evaluator config is pushed to the run config only once.
        self._did_log_config: bool = False
        # Flattened signature info captured at the latest evaluate start.
        self._program_info: dict[str, Any] = {}
        # Created lazily on the first evaluate end.
        self._program_table: wandb.Table | None = None
        # Evaluation counter; also used as the W&B step.
        self._row_idx: int = 0

    def _flatten_dict(
        self, nested: Any, parent_key: str = "", sep: str = "."
    ) -> dict[str, Any]:
        """Recursively flatten arbitrarily nested mappings and sequences.

        Sequence items are keyed by their index. Strings and byte strings are
        treated as scalars. Empty containers contribute no keys, and a
        top-level scalar with an empty `parent_key` is dropped because there
        is no key to store it under.

        Args:
            nested (Any): Nested structure of mappings/lists to flatten.
            parent_key (str): Prefix to prepend to keys in the flattened output.
            sep (str): Key separator for nested fields.

        Returns:
            dict[str, Any]: Flattened dictionary representation.
        """
        flat: dict[str, Any] = {}

        def _walk(obj: Any, base: str) -> None:
            if isinstance(obj, Mapping):
                for k, v in obj.items():
                    _walk(v, f"{base}{sep}{k}" if base else str(k))
            elif isinstance(obj, Sequence) and not isinstance(
                obj, (str, bytes, bytearray)
            ):
                for idx, v in enumerate(obj):
                    _walk(v, f"{base}{sep}{idx}" if base else str(idx))
            elif base:
                # `base` is empty only for a top-level scalar; skip that case.
                flat[base] = obj

        _walk(nested, parent_key)
        return flat

    def _extract_fields(self, fields: Mapping[str, Any]) -> dict[str, str]:
        """Convert signature fields to a flat mapping of strings.

        Args:
            fields (Mapping[str, Any]): Mapping of field name to metadata.

        Returns:
            dict[str, str]: Mapping of field name to the stringified metadata.
        """
        return {k: str(v) for k, v in fields.items()}

    def _extract_program_info(self, program_obj: Any) -> dict[str, Any]:
        """Extract signature-related info from a DSPy program.

        Reads the signature, instructions, input and output fields from the
        first `dspy.Predict` parameter of the program, if there is one.
        Extraction is best-effort: failures are logged as a warning and
        whatever was collected up to that point is returned.

        Args:
            program_obj (Any): DSPy program/module instance.

        Returns:
            dict[str, Any]: Flattened dictionary of signature metadata. Empty
                if `program_obj` is None or has no `dspy.Predict` parameter.
        """
        info_dict: dict[str, Any] = {}
        if program_obj is None:
            return info_dict

        try:
            # Default of None avoids a StopIteration (and an empty warning)
            # for programs that do not contain any `dspy.Predict`.
            sig = next(
                (
                    param.signature
                    for _, param in program_obj.named_parameters()
                    if isinstance(param, dspy.Predict)
                ),
                None,
            )
            if sig is None:
                return info_dict

            if signature := getattr(sig, "signature", None):
                info_dict["signature"] = signature
            if instructions := getattr(sig, "instructions", None):
                info_dict["instructions"] = instructions
            if input_fields := getattr(sig, "input_fields", None):
                info_dict["input_fields"] = self._extract_fields(input_fields)
            if output_fields := getattr(sig, "output_fields", None):
                info_dict["output_fields"] = self._extract_fields(output_fields)
        except Exception as e:
            logger.warning(
                "Failed to extract program info from Evaluate instance: %s", e
            )

        # Flatten on every path so callers never receive nested values.
        return self._flatten_dict(info_dict)

    def on_evaluate_start(
        self,
        call_id: str,
        instance: Any,
        inputs: dict[str, Any],
    ) -> None:
        """Handle start of a DSPy evaluation call.

        On the first call, logs the non-private attributes of the evaluator
        instance (except `devset`) to the W&B run config. On every call,
        captures program signature info for later logging.

        Args:
            call_id (str): Unique identifier for the evaluation call.
            instance (Any): The evaluation instance (e.g., `dspy.Evaluate`).
            inputs (dict[str, Any]): Inputs passed to the evaluation (may
                include a `program` key with the DSPy program).
        """
        if not self._did_log_config:
            instance_vars = vars(instance) if hasattr(instance, "__dict__") else {}
            # Skip private attributes, and the devset which we don't want in
            # the config.
            serializable = {
                k: v
                for k, v in instance_vars.items()
                if not k.startswith("_") and k != "devset"
            }
            self._run.config.update(serializable)
            self._did_log_config = True

        # Capture the program signature for the table row written at
        # evaluate end. If no program is passed, the previous info is kept.
        program_obj = inputs.get("program")
        if program_obj is not None:
            self._program_info = self._extract_program_info(program_obj)

    def on_evaluate_end(
        self,
        call_id: str,
        outputs: Any | None,
        exception: Exception | None = None,
    ) -> None:
        """Handle end of a DSPy evaluation call.

        If available, logs a numeric `score` metric and (optionally) a
        predictions table. Always appends a row to the program-signature
        table and advances the step counter.

        Args:
            call_id (str): Unique identifier for the evaluation call.
            outputs (Any | None): Evaluation outputs; supports
                `dspy.evaluate.evaluate.EvaluationResult`.
            exception (Exception | None): Exception raised during evaluation, if any.
        """
        # The `BaseCallback` does not define the interface for the `outputs` parameter,
        # Currently, we know of `EvaluationResult` which is a subclass of `dspy.Prediction`.
        # We currently support this type and will warn the user if a different type is passed.
        score: float | None = None
        if exception is not None:
            wandb.termwarn(
                f"on_evaluate_end received exception: {exception}. "
                "Skipping logging score and `log_results`."
            )
        elif isinstance(outputs, dspy.evaluate.evaluate.EvaluationResult):
            # Log the score as a wandb metric on the run this callback was
            # created for (not necessarily the global run).
            score = float(outputs.score)
            self._run.log({"score": score}, step=self._row_idx)

            # Log the predictions as a separate table for each eval end.
            # We know that results is of type `list[tuple["dspy.Example", "dspy.Example", Any]]`
            if self.log_results:
                rows = self._parse_results(outputs.results)
                if rows:
                    self._log_predictions_table(rows)
        else:
            wandb.termwarn(
                f"on_evaluate_end received unexpected outputs type: {type(outputs)}. "
                "Expected dspy.evaluate.evaluate.EvaluationResult; skipping logging score and `log_results`."
            )

        # Log the program signature iteratively. The column set is fixed when
        # the table is created, so `score` is always included; otherwise a
        # failed first evaluation would leave no room for later scores.
        if self._program_table is None:
            columns = ["step", *self._program_info.keys(), "score"]
            self._program_table = wandb.Table(columns=columns, log_mode="INCREMENTAL")

        # Align the row with the table's columns: missing values become None
        # and keys the table does not know about are dropped, so `add_data`
        # always receives exactly one value per column.
        row_values: dict[str, Any] = {
            "step": self._row_idx,
            **self._program_info,
            "score": score,
        }
        self._program_table.add_data(
            *(row_values.get(column) for column in self._program_table.columns)
        )
        self._run.log({"program_signature": self._program_table}, step=self._row_idx)

        self._row_idx += 1

    def _parse_results(
        self,
        results: list[tuple[dspy.Example, dspy.Prediction | dspy.Completions, bool]],
    ) -> list[dict[str, Any]]:
        """Normalize evaluation results into serializable row dicts.

        Args:
            results (list[tuple]): Sequence of `(example, prediction, is_correct)`
                tuples from DSPy evaluation.

        Returns:
            list[dict[str, Any]]: Rows with `example`, `prediction`, `is_correct`.
                `prediction` is a dict for `dspy.Prediction` and
                `dspy.Completions`, None for a None prediction, and the
                string form of the object for any other type.
        """
        _rows: list[dict[str, Any]] = []
        for example, prediction, is_correct in results:
            prediction_value: Any
            if isinstance(prediction, dspy.Prediction):
                prediction_value = prediction.toDict()
            elif isinstance(prediction, dspy.Completions):
                # Materialize the items view so it can be flattened into
                # columns like any other dict.
                prediction_value = dict(prediction.items())
            elif prediction is None:
                prediction_value = None
            else:
                # Unknown type: keep a readable representation instead of
                # failing or reusing the previous iteration's value.
                prediction_value = str(prediction)

            _rows.append(
                {
                    "example": example.toDict(),
                    "prediction": prediction_value,
                    "is_correct": is_correct,
                }
            )

        return _rows

    def _log_predictions_table(self, rows: list[dict[str, Any]]) -> None:
        """Log a W&B Table of predictions for the current evaluation step.

        The table is logged under the key `predictions_<step>`. Does nothing
        if `rows` is empty.

        Args:
            rows (list[dict[str, Any]]): Prediction rows to log.
        """
        flat_rows = _flatten_rows(rows)
        if not flat_rows:
            return

        # Rows may not share the same keys; use the union of keys (in
        # first-seen order) and fill gaps with None so every data row lines
        # up with the column list.
        columns = list(dict.fromkeys(key for row in flat_rows for key in row))
        data: list[list[Any]] = [
            [row.get(column) for column in columns] for row in flat_rows
        ]

        preds_table = wandb.Table(columns=columns, data=data, log_mode="IMMUTABLE")
        self._run.log({f"predictions_{self._row_idx}": preds_table}, step=self._row_idx)

    def log_best_model(
        self,
        model: dspy.Module,
        *,
        save_program: bool = True,
        save_dir: str | None = None,
        filetype: Literal["json", "pkl"] = "json",
        aliases: Sequence[str] = ("best", "latest"),
        artifact_name: str = "dspy-program",
    ) -> None:
        """Save and log the best DSPy program as a W&B Artifact.

        You can choose to save the full program (architecture + state) or only
        the state to a single file (JSON or pickle). The artifact is named
        `<artifact_name>-<run id>` and has type `model`. If the save directory
        cannot be created, a warning is printed and nothing is logged.

        Args:
            model (dspy.Module): DSPy module to save.
            save_program (bool): Save full program directory if True; otherwise
                save only the state file. Defaults to `True`.
            save_dir (str | None): Directory to store program files before
                logging. Defaults to a subdirectory `dspy_program` within the
                run's files directory (i.e., `run.dir`).
            filetype (Literal["json", "pkl"]): State file format when
                `save_program` is False. Defaults to `json`.
            aliases (Sequence[str]): Aliases for the logged Artifact version.
                Defaults to `("best", "latest")`.
            artifact_name (str): Base name for the Artifact. Defaults to
                `dspy-program`.

        Examples:
            Save the complete program and add aliases:

            ```python
            callback.log_best_model(
                optimized_program, save_program=True, aliases=("best", "production")
            )
            ```

            Save only the state as JSON:

            ```python
            callback.log_best_model(
                optimized_program, save_program=False, filetype="json"
            )
            ```
        """
        # Derive metadata to help discoverability in the UI
        info_dict = self._extract_program_info(model)
        metadata = {
            "dspy_version": getattr(dspy, "__version__", "unknown"),
            "module_class": model.__class__.__name__,
            **info_dict,
        }
        artifact = wandb.Artifact(
            name=f"{artifact_name}-{self._run.id}",
            type="model",
            metadata=metadata,
        )

        # Resolve and normalize the save directory in a cross-platform way
        if save_dir is None:
            save_dir = os.path.join(self._run.dir, "dspy_program")
        save_dir = os.path.normpath(save_dir)

        try:
            os.makedirs(save_dir, exist_ok=True)
        except (OSError, ValueError) as exc:
            # Best-effort: an unusable directory should not abort the run.
            wandb.termwarn(
                f"Could not create or access directory '{save_dir}': {exc}. Skipping artifact logging."
            )
            return

        # Save per requested mode
        if save_program:
            model.save(save_dir, save_program=True)
            artifact.add_dir(save_dir)
        else:
            file_path = os.path.join(save_dir, f"program.{filetype}")
            model.save(file_path, save_program=False)
            artifact.add_file(file_path)

        self._run.log_artifact(artifact, aliases=list(aliases))