| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423 |
- """DSPy ↔ Weights & Biases integration."""
- from __future__ import annotations
- import logging
- import os
- from collections.abc import Mapping, Sequence
- from typing import Any, Literal
- import wandb
- import wandb.util
- from wandb.sdk.lib import telemetry
- from wandb.sdk.wandb_run import Run
dspy = wandb.util.get_module(
    name="dspy",
    required=(
        "To use the W&B DSPy integration you need to have the `dspy` "
        "python package installed. Install it with `uv pip install dspy`."
    ),
    lazy=False,
)


def _dspy_version_is_supported(version: str, minimum_major: int = 3) -> bool:
    """Return whether ``version`` has a major release of at least ``minimum_major``.

    Version strings must not be compared lexically: as strings,
    ``"10.0.0" < "3.0.0"``. Only the leading integer (the major version) is
    parsed, which is sufficient for a "3.0.0 or higher" requirement. A version
    string with no leading digits (e.g. a custom build label) is accepted
    rather than rejected, since its release cannot be determined.

    Args:
        version (str): Version string such as ``"3.0.1"`` or ``"3.1.0rc1"``.
        minimum_major (int): Lowest accepted major version. Defaults to 3.

    Returns:
        bool: True if the version is supported or cannot be parsed.
    """
    major = ""
    for char in version.strip():
        if char not in "0123456789":
            break
        major += char
    return not major or int(major) >= minimum_major


if dspy is not None and not _dspy_version_is_supported(dspy.__version__):
    # Raise rather than `assert`: assertions are stripped under `python -O`.
    raise ImportError(
        "DSPy 3.0.0 or higher is required. You have " + dspy.__version__
    )

logger = logging.getLogger(__name__)
- def _flatten_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
- """Flatten a list of nested row dicts into flat key/value dicts.
- Args:
- rows (list[dict[str, Any]]): List of nested dictionaries to flatten.
- Returns:
- list[dict[str, Any]]: List of flattened dictionaries.
- """
- def _flatten(
- d: dict[str, Any], parent_key: str = "", sep: str = "."
- ) -> dict[str, Any]:
- items = []
- for k, v in d.items():
- new_key = f"{parent_key}{sep}{k}" if parent_key else k
- if isinstance(v, dict):
- items.extend(_flatten(v, new_key, sep=sep).items())
- else:
- items.append((new_key, v))
- return dict(items)
- return [_flatten(row) for row in rows]
class WandbDSPyCallback(dspy.utils.BaseCallback):
    """W&B callback for tracking DSPy evaluation and optimization.

    This callback logs evaluation scores, per-step predictions (optional), and
    a table capturing the DSPy program signature over time. It can also save
    the best program as a W&B Artifact for reproducibility.

    Examples:
        Basic usage within DSPy settings:

        ```python
        import dspy
        import wandb
        from wandb.integration.dspy import WandbDSPyCallback

        with wandb.init(project="dspy-optimization") as run:
            dspy.settings.callbacks.append(WandbDSPyCallback(run=run))
            # Run your DSPy optimization/evaluation
        ```
    """

    def __init__(self, log_results: bool = True, run: Run | None = None) -> None:
        """Initialize the callback.

        Args:
            log_results (bool): Whether to log per-evaluation prediction tables.
            run (Run | None): Optional W&B run to use. Defaults to the
                current global run if available.

        Raises:
            wandb.Error: If no active run is provided or found.
        """
        super().__init__()
        # If no run is provided, use the current global run if available.
        if run is None:
            if wandb.run is None:
                raise wandb.Error(
                    "You must call `wandb.init()` before instantiating WandbDSPyCallback()."
                )
            run = wandb.run

        self.log_results = log_results

        with telemetry.context(run=run) as tel:
            tel.feature.dspy_callback = True

        self._run = run
        self._did_log_config: bool = False
        self._program_info: dict[str, Any] = {}
        self._program_table: wandb.Table | None = None
        # Program-info column names of `_program_table`, fixed when the table
        # is created (an INCREMENTAL table keeps its schema across logs).
        self._program_info_columns: list[str] = []
        self._row_idx: int = 0

    def _flatten_dict(
        self, nested: Any, parent_key: str = "", sep: str = "."
    ) -> dict[str, Any]:
        """Recursively flatten arbitrarily nested mappings and sequences.

        Mapping keys and sequence indices are joined with ``sep``. Strings and
        bytes are treated as scalars, not sequences.

        Args:
            nested (Any): Nested structure of mappings/lists to flatten.
            parent_key (str): Prefix to prepend to keys in the flattened output.
            sep (str): Key separator for nested fields.

        Returns:
            dict[str, Any]: Flattened dictionary representation.
        """
        flat: dict[str, Any] = {}

        def _walk(obj: Any, base: str) -> None:
            if isinstance(obj, Mapping):
                for k, v in obj.items():
                    _walk(v, f"{base}{sep}{k}" if base else str(k))
            elif isinstance(obj, Sequence) and not isinstance(
                obj, (str, bytes, bytearray)
            ):
                for idx, v in enumerate(obj):
                    _walk(v, f"{base}{sep}{idx}" if base else str(idx))
            elif base:
                # A scalar without a key (top-level scalar input) is dropped.
                flat[base] = obj

        _walk(nested, parent_key)
        return flat

    def _extract_fields(self, fields: Mapping[str, Any]) -> dict[str, str]:
        """Convert signature fields to a flat mapping of strings.

        Args:
            fields (Mapping[str, Any]): Mapping of field name to field
                metadata. Values are stringified for logging.

        Returns:
            dict[str, str]: Mapping of field name to string value.
        """
        return {k: str(v) for k, v in fields.items()}

    def _extract_program_info(self, program_obj: Any) -> dict[str, Any]:
        """Extract signature-related info from a DSPy program.

        Reads the signature, instructions, and input and output fields from
        the first `dspy.Predict` parameter of the program. This is
        best-effort: if the program has no `Predict` parameter, or extraction
        fails, an empty dict is returned so callers always receive flat
        metadata.

        Args:
            program_obj (Any): DSPy program/module instance.

        Returns:
            dict[str, Any]: Flattened dictionary of signature metadata.
        """
        if program_obj is None:
            return {}

        info_dict: dict[str, Any] = {}
        try:
            predictor = next(
                (
                    param
                    for _, param in program_obj.named_parameters()
                    if isinstance(param, dspy.Predict)
                ),
                None,
            )
            if predictor is None:
                logger.debug(
                    "No dspy.Predict parameter found in program of type %s",
                    type(program_obj).__name__,
                )
                return {}

            sig = predictor.signature
            if getattr(sig, "signature", None):
                info_dict["signature"] = sig.signature
            if getattr(sig, "instructions", None):
                info_dict["instructions"] = sig.instructions
            if getattr(sig, "input_fields", None):
                info_dict["input_fields"] = self._extract_fields(sig.input_fields)
            if getattr(sig, "output_fields", None):
                info_dict["output_fields"] = self._extract_fields(sig.output_fields)
        except Exception as e:
            # Metadata is optional; never let extraction break evaluation.
            logger.warning("Failed to extract program info from DSPy program: %s", e)
            return {}

        return self._flatten_dict(info_dict)

    def on_evaluate_start(
        self,
        call_id: str,
        instance: Any,
        inputs: dict[str, Any],
    ) -> None:
        """Handle start of a DSPy evaluation call.

        On the first call, logs the evaluator instance's public attributes
        (except ``devset``) to the W&B config. On every call, captures the
        program signature info for the program-signature table.

        Args:
            call_id (str): Unique identifier for the evaluation call.
            instance (Any): The evaluation instance (e.g., `dspy.Evaluate`).
            inputs (dict[str, Any]): Inputs passed to the evaluation (may
                include a `program` key with the DSPy program).
        """
        # 1) Log the evaluator configuration once.
        if not self._did_log_config:
            instance_vars = vars(instance) if hasattr(instance, "__dict__") else {}
            # The devset can be large and is not configuration, so skip it.
            serializable = {
                k: v
                for k, v in instance_vars.items()
                if not k.startswith("_") and k != "devset"
            }
            self._run.config.update(serializable)
            self._did_log_config = True

        # 2) Capture program signature info for the program-signature table.
        if program_obj := inputs.get("program"):
            self._program_info = self._extract_program_info(program_obj)

    def on_evaluate_end(
        self,
        call_id: str,
        outputs: Any | None,
        exception: Exception | None = None,
    ) -> None:
        """Handle end of a DSPy evaluation call.

        If available, logs a numeric `score` metric and (optionally) per-step
        prediction tables. Always appends a row to the program-signature
        table and advances the step counter.

        Args:
            call_id (str): Unique identifier for the evaluation call.
            outputs (Any | None): Evaluation outputs; supports
                `dspy.evaluate.evaluate.EvaluationResult`.
            exception (Exception | None): Exception raised during evaluation, if any.
        """
        # `BaseCallback` does not define the type of `outputs`. The known type
        # is `EvaluationResult` (a `dspy.Prediction` subclass); anything else
        # triggers a warning and only the program-signature row is logged.
        score: float | None = None
        if exception is not None:
            wandb.termwarn(
                f"on_evaluate_end received exception: {exception}. "
                "Skipping logging score and `log_results`."
            )
        elif isinstance(outputs, dspy.evaluate.evaluate.EvaluationResult):
            score = float(outputs.score)
            # Log to this callback's run, which need not be the global run.
            self._run.log({"score": score}, step=self._row_idx)
            if self.log_results:
                # `results` is a list of (example, prediction, metric value).
                rows = self._parse_results(outputs.results)
                if rows:
                    self._log_predictions_table(rows)
        else:
            wandb.termwarn(
                f"on_evaluate_end received unexpected outputs type: {type(outputs)}. "
                "Expected dspy.evaluate.evaluate.EvaluationResult; skipping logging score and `log_results`."
            )

        self._log_program_signature(score)
        self._row_idx += 1

    def _log_program_signature(self, score: float | None) -> None:
        """Append the current program info and score to the signature table.

        The table's columns are fixed at creation as ``step``, the program-info
        keys known at that time, and ``score``. Each row is then built by
        column name, so evaluations without a score, or programs with
        different signature fields, cannot misalign or break the table.
        Missing values are logged as None. Program-info keys that were not
        present when the table was created are not logged.

        Args:
            score (float | None): Evaluation score, or None if unavailable.
        """
        if self._program_table is None:
            self._program_info_columns = list(self._program_info)
            self._program_table = wandb.Table(
                columns=["step", *self._program_info_columns, "score"],
                log_mode="INCREMENTAL",
            )

        self._program_table.add_data(
            self._row_idx,
            *(self._program_info.get(col) for col in self._program_info_columns),
            score,
        )
        self._run.log({"program_signature": self._program_table}, step=self._row_idx)

    @staticmethod
    def _prediction_to_dict(prediction: Any) -> Any:
        """Convert a DSPy prediction into a value suitable for a table row.

        Args:
            prediction (Any): Usually a `dspy.Prediction` or `dspy.Completions`.

        Returns:
            Any: A dict of prediction fields for known types. Otherwise the
            value's string form, or None if the prediction is None.
        """
        if isinstance(prediction, dspy.Prediction):
            return prediction.toDict()
        if isinstance(prediction, dspy.Completions):
            # `items()` yields a view; materialize it so the row is flattened
            # into per-field columns like a `Prediction`.
            return dict(prediction.items())
        if isinstance(prediction, Mapping):
            return dict(prediction)
        return None if prediction is None else str(prediction)

    def _parse_results(
        self,
        results: list[tuple[dspy.Example, dspy.Prediction | dspy.Completions, bool]],
    ) -> list[dict[str, Any]]:
        """Normalize evaluation results into serializable row dicts.

        Args:
            results (list[tuple]): Sequence of `(example, prediction, is_correct)`
                tuples from DSPy evaluation.

        Returns:
            list[dict[str, Any]]: Rows with `example`, `prediction`, `is_correct`.
        """
        return [
            {
                "example": example.toDict(),
                "prediction": self._prediction_to_dict(prediction),
                "is_correct": is_correct,
            }
            for example, prediction, is_correct in results
        ]

    def _log_predictions_table(self, rows: list[dict[str, Any]]) -> None:
        """Log a W&B Table of predictions for the current evaluation step.

        The columns are the union of all flattened row keys, in first-seen
        order. Rows lacking a column get None, so heterogeneous rows stay
        aligned.

        Args:
            rows (list[dict[str, Any]]): Prediction rows to log.
        """
        flat_rows = _flatten_rows(rows)
        if not flat_rows:
            return
        columns = list(dict.fromkeys(key for row in flat_rows for key in row))
        data: list[list[Any]] = [[row.get(col) for col in columns] for row in flat_rows]
        preds_table = wandb.Table(columns=columns, data=data, log_mode="IMMUTABLE")
        self._run.log({f"predictions_{self._row_idx}": preds_table}, step=self._row_idx)

    def log_best_model(
        self,
        model: dspy.Module,
        *,
        save_program: bool = True,
        save_dir: str | None = None,
        filetype: Literal["json", "pkl"] = "json",
        aliases: Sequence[str] = ("best", "latest"),
        artifact_name: str = "dspy-program",
    ) -> None:
        """Save and log the best DSPy program as a W&B Artifact.

        You can choose to save the full program (architecture + state) or only
        the state to a single file (JSON or pickle).

        Args:
            model (dspy.Module): DSPy module to save.
            save_program (bool): Save full program directory if True; otherwise
                save only the state file. Defaults to `True`.
            save_dir (str | None): Directory to store program files before
                logging. Defaults to a subdirectory `dspy_program` within the
                active run's files directory (i.e., `wandb.run.dir`).
            filetype (Literal["json", "pkl"]): State file format when
                `save_program` is False. Defaults to `json`.
            aliases (Sequence[str]): Aliases for the logged Artifact version.
                Defaults to `("best", "latest")`.
            artifact_name (str): Base name for the Artifact. Defaults to `dspy-program`.

        Raises:
            ValueError: If `save_program` is False and `filetype` is not
                ``"json"`` or ``"pkl"``.

        Examples:
            Save the complete program and add aliases:

            ```python
            callback.log_best_model(
                optimized_program, save_program=True, aliases=("best", "production")
            )
            ```

            Save only the state as JSON:

            ```python
            callback.log_best_model(
                optimized_program, save_program=False, filetype="json"
            )
            ```
        """
        # Validate before touching the filesystem.
        if not save_program and filetype not in ("json", "pkl"):
            raise ValueError(
                f"`filetype` must be 'json' or 'pkl', got {filetype!r}."
            )

        # Derive metadata to help discoverability in the UI.
        info_dict = self._extract_program_info(model)
        metadata = {
            "dspy_version": getattr(dspy, "__version__", "unknown"),
            "module_class": model.__class__.__name__,
            **info_dict,
        }

        artifact = wandb.Artifact(
            name=f"{artifact_name}-{self._run.id}",
            type="model",
            metadata=metadata,
        )

        # Resolve and normalize the save directory in a cross-platform way.
        if save_dir is None:
            save_dir = os.path.join(self._run.dir, "dspy_program")
        save_dir = os.path.normpath(save_dir)

        try:
            os.makedirs(save_dir, exist_ok=True)
        except OSError as exc:
            wandb.termwarn(
                f"Could not create or access directory '{save_dir}': {exc}. Skipping artifact logging."
            )
            return

        # Save per requested mode.
        if save_program:
            model.save(save_dir, save_program=True)
            artifact.add_dir(save_dir)
        else:
            file_path = os.path.join(save_dir, f"program.{filetype}")
            model.save(file_path, save_program=False)
            artifact.add_file(file_path)

        self._run.log_artifact(artifact, aliases=list(aliases))
|