| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243 |
- from __future__ import annotations
- import datetime
- import io
- import logging
- from collections.abc import Sequence
- from dataclasses import asdict, dataclass
- from typing import Any
- import wandb
- from wandb.sdk.data_types import trace_tree
- from wandb.sdk.integration_utils.auto_logging import Response
- logger = logging.getLogger(__name__)
@dataclass
class UsageMetrics:
    """Usage numbers recorded for a single OpenAI request.

    Every field defaults to ``None`` so that an instance can be created
    before any value is known and filled in afterwards.
    """

    # Duration of the request, as supplied by the caller of the resolver.
    elapsed_time: float | None = None
    # Token counts copied from the ``usage`` section of the response.
    prompt_tokens: int | None = None
    completion_tokens: int | None = None
    total_tokens: int | None = None
@dataclass
class Metrics:
    """Bundle of everything the resolver produces for one request.

    Every field defaults to ``None``.
    """

    # Aggregated usage numbers for the request.
    usage: UsageMetrics | None = None
    # Table with one row per returned choice.
    stats: wandb.Table | None = None
    # Trace tree built from the request, the response and the results.
    trace: trace_tree.WBTraceTree | None = None
# Metric names of the form "usage/<field>", one per field declared on
# UsageMetrics.
usage_metric_keys = {
    f"usage/{field_name}" for field_name in asdict(UsageMetrics()).keys()
}
class OpenAIRequestResponseResolver:
    """Converts an OpenAI request/response pair into a dict of loggable values.

    Instances are callable. Responses whose ``object`` field is ``"edit"``,
    ``"text_completion"`` or ``"chat.completion"`` are supported; for any
    other value, or when resolving fails, the call returns ``None``.
    """

    def __init__(self):
        # Flipped to True once the usage metrics have been registered with
        # wandb.define_metric, so registration happens only on the first call.
        self.define_metrics_called = False

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: dict[str, Any],
        response: Response,
        start_time: float,  # pass to comply with the protocol, but use response["created"] instead
        time_elapsed: float,
    ) -> dict[str, Any] | None:
        """Resolves one request/response pair.

        params:
            args: Positional arguments of the request (unused).
            kwargs: Keyword arguments of the request; treated as the request.
            response: The response object
            start_time: Unused; ``response["created"]`` is used instead.
            time_elapsed: The time elapsed in seconds
        returns:
            A dict with the keys ``stats``, ``trace`` and one ``usage/<field>``
            entry per usage metric, or ``None`` if the response kind is not
            supported or resolving raised an exception.
        """
        request = kwargs

        if not self.define_metrics_called:
            # define metrics on first call
            for key in usage_metric_keys:
                wandb.define_metric(key, step_metric="_timestamp")
            self.define_metrics_called = True

        try:
            response_kind = response.get("object")
            if response_kind == "edit":
                return self._resolve_edit(request, response, time_elapsed)
            if response_kind == "text_completion":
                return self._resolve_completion(request, response, time_elapsed)
            if response_kind == "chat.completion":
                return self._resolve_chat_completion(request, response, time_elapsed)
            # todo: properly treat failed requests
            logger.info("Unsupported OpenAI response object: %s", response_kind)
        except Exception as e:
            # Deliberately best-effort: a failure to resolve is reported and
            # swallowed instead of being propagated to the caller.
            logger.warning("Failed to resolve request/response: %s", e)
        return None

    @staticmethod
    def results_to_trace_tree(
        request: dict[str, Any],
        response: Response,
        results: list[trace_tree.Result],
        time_elapsed: float,
    ) -> trace_tree.WBTraceTree:
        """Converts the request, response, and results into a trace tree.

        params:
            request: The request dictionary
            response: The response object
            results: A list of results object
            time_elapsed: The time elapsed in seconds
        returns:
            A wandb trace tree object.
        """
        # The span starts at the response's "created" value (scaled by 1000)
        # and lasts for the measured elapsed time.
        start_time_ms = int(round(response["created"] * 1000))
        end_time_ms = start_time_ms + int(round(time_elapsed * 1000))
        span = trace_tree.Span(
            name=f"{response.get('model', 'openai')}_{response['object']}_{response.get('created')}",
            attributes=dict(response),  # type: ignore
            start_time_ms=start_time_ms,
            end_time_ms=end_time_ms,
            span_kind=trace_tree.SpanKind.LLM,
            results=results,
        )
        model_obj = {"request": request, "response": response, "_kind": "openai"}
        return trace_tree.WBTraceTree(root_span=span, model_dict=model_obj)

    def _resolve_edit(
        self,
        request: dict[str, Any],
        response: Response,
        time_elapsed: float,
    ) -> dict[str, Any]:
        """Resolves the request and response objects for `openai.Edit`."""
        request_str = (
            f"\n\n**Instruction**: {request['instruction']}\n\n"
            f"**Input**: {request['input']}\n"
        )
        choices = [
            f"\n\n**Edited**: {choice['text']}\n" for choice in response["choices"]
        ]
        return self._resolve_metrics(
            request=request,
            response=response,
            request_str=request_str,
            choices=choices,
            time_elapsed=time_elapsed,
        )

    def _resolve_completion(
        self,
        request: dict[str, Any],
        response: Response,
        time_elapsed: float,
    ) -> dict[str, Any]:
        """Resolves the request and response objects for `openai.Completion`."""
        request_str = f"\n\n**Prompt**: {request['prompt']}\n"
        choices = [
            f"\n\n**Completion**: {choice['text']}\n" for choice in response["choices"]
        ]
        return self._resolve_metrics(
            request=request,
            response=response,
            request_str=request_str,
            choices=choices,
            time_elapsed=time_elapsed,
        )

    def _resolve_chat_completion(
        self,
        request: dict[str, Any],
        response: Response,
        time_elapsed: float,
    ) -> dict[str, Any]:
        """Resolves the request and response objects for chat completions.

        The request string is the concatenation of all request messages, each
        rendered as a role/content pair.
        """
        request_str = "".join(
            f"\n\n**{message['role']}**: {message['content']}\n"
            for message in request["messages"]
        )
        choices = [
            f"\n\n**{choice['message']['role']}**: {choice['message']['content']}\n"
            for choice in response["choices"]
        ]
        return self._resolve_metrics(
            request=request,
            response=response,
            request_str=request_str,
            choices=choices,
            time_elapsed=time_elapsed,
        )

    def _resolve_metrics(
        self,
        request: dict[str, Any],
        response: Response,
        request_str: str,
        choices: list[str],
        time_elapsed: float,
    ) -> dict[str, Any]:
        """Builds the loggable dict from a rendered request and its choices.

        params:
            request: The request dictionary
            response: The response object
            request_str: The request rendered as a single string
            choices: The returned choices, each rendered as a string
            time_elapsed: The time elapsed in seconds
        returns:
            The dict produced by `_convert_metrics_to_dict`.
        """
        # One result per choice; every result carries the same request string.
        results = [
            trace_tree.Result(
                inputs={"request": request_str},
                outputs={"response": choice},
            )
            for choice in choices
        ]
        metrics = self._get_metrics_to_log(request, response, results, time_elapsed)
        return self._convert_metrics_to_dict(metrics)

    @staticmethod
    def _get_usage_metrics(response: Response, time_elapsed: float) -> UsageMetrics:
        """Gets the usage stats from the response object.

        Only keys that are fields of `UsageMetrics` are copied from
        ``response["usage"]``; any other key is ignored so that it cannot make
        the constructor raise ``TypeError``. ``elapsed_time`` is always set
        from ``time_elapsed``.
        """
        known_fields = asdict(UsageMetrics())
        reported = response.get("usage") or {}
        # Goes through keys()/__getitem__, the same mapping protocol that
        # ``**reported`` would use.
        usage_stats = UsageMetrics(
            **{key: reported[key] for key in reported.keys() if key in known_fields}
        )
        usage_stats.elapsed_time = time_elapsed
        return usage_stats

    def _get_metrics_to_log(
        self,
        request: dict[str, Any],
        response: Response,
        results: list[Any],
        time_elapsed: float,
    ) -> Metrics:
        """Builds the usage metrics, the stats table and the trace tree.

        params:
            request: The request dictionary
            response: The response object
            results: A list of results object
            time_elapsed: The time elapsed in seconds
        returns:
            A `Metrics` instance with ``usage``, ``stats`` and ``trace`` set.
        raises:
            ValueError: If ``results`` is empty, because the table columns are
                derived from the first row.
        """
        model = response.get("model") or request.get("model")
        usage_metrics = self._get_usage_metrics(response, time_elapsed)

        if not results:
            raise ValueError("OpenAI response contains no choices to log")

        # These values are identical for every row, so compute them once.
        created = response["created"]
        shared_fields = {
            "model": model,
            "start_time": datetime.datetime.fromtimestamp(created),
            "end_time": datetime.datetime.fromtimestamp(created + time_elapsed),
            "request_id": response.get("id", None),
            "api_type": response.get("api_type", "openai"),
            "session_id": wandb.run.id,
        }
        shared_fields.update(asdict(usage_metrics))

        rows = [
            {
                "request": result.inputs["request"],
                "response": result.outputs["response"],
                **shared_fields,
            }
            for result in results
        ]
        usage_table = wandb.Table(
            columns=list(rows[0]),
            data=[list(row.values()) for row in rows],
        )

        trace = self.results_to_trace_tree(request, response, results, time_elapsed)
        return Metrics(stats=usage_table, trace=trace, usage=usage_metrics)

    @staticmethod
    def _convert_metrics_to_dict(metrics: Metrics) -> dict[str, Any]:
        """Converts metrics to a dict.

        returns:
            A dict with ``stats`` and ``trace`` plus one ``usage/<field>`` key
            for every field of ``metrics.usage``.
        """
        metrics_dict = {
            "stats": metrics.stats,
            "trace": metrics.trace,
        }
        metrics_dict.update(
            {f"usage/{k}": v for k, v in asdict(metrics.usage).items()}
        )
        return metrics_dict
|