resolver.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. from __future__ import annotations
  2. import datetime
  3. import io
  4. import logging
  5. from collections.abc import Sequence
  6. from dataclasses import asdict, dataclass
  7. from typing import Any
  8. import wandb
  9. from wandb.sdk.data_types import trace_tree
  10. from wandb.sdk.integration_utils.auto_logging import Response
  11. logger = logging.getLogger(__name__)
  12. @dataclass
  13. class UsageMetrics:
  14. elapsed_time: float = None
  15. prompt_tokens: int = None
  16. completion_tokens: int = None
  17. total_tokens: int = None
  18. @dataclass
  19. class Metrics:
  20. usage: UsageMetrics = None
  21. stats: wandb.Table = None
  22. trace: trace_tree.WBTraceTree = None
  23. usage_metric_keys = {f"usage/{k}" for k in asdict(UsageMetrics())}
  24. class OpenAIRequestResponseResolver:
  25. def __init__(self):
  26. self.define_metrics_called = False
  27. def __call__(
  28. self,
  29. args: Sequence[Any],
  30. kwargs: dict[str, Any],
  31. response: Response,
  32. start_time: float, # pass to comply with the protocol, but use response["created"] instead
  33. time_elapsed: float,
  34. ) -> dict[str, Any] | None:
  35. request = kwargs
  36. if not self.define_metrics_called:
  37. # define metrics on first call
  38. for key in usage_metric_keys:
  39. wandb.define_metric(key, step_metric="_timestamp")
  40. self.define_metrics_called = True
  41. try:
  42. if response.get("object") == "edit":
  43. return self._resolve_edit(request, response, time_elapsed)
  44. elif response.get("object") == "text_completion":
  45. return self._resolve_completion(request, response, time_elapsed)
  46. elif response.get("object") == "chat.completion":
  47. return self._resolve_chat_completion(request, response, time_elapsed)
  48. else:
  49. # todo: properly treat failed requests
  50. logger.info(
  51. f"Unsupported OpenAI response object: {response.get('object')}"
  52. )
  53. except Exception as e:
  54. logger.warning(f"Failed to resolve request/response: {e}")
  55. return None
  56. @staticmethod
  57. def results_to_trace_tree(
  58. request: dict[str, Any],
  59. response: Response,
  60. results: list[trace_tree.Result],
  61. time_elapsed: float,
  62. ) -> trace_tree.WBTraceTree:
  63. """Converts the request, response, and results into a trace tree.
  64. params:
  65. request: The request dictionary
  66. response: The response object
  67. results: A list of results object
  68. time_elapsed: The time elapsed in seconds
  69. returns:
  70. A wandb trace tree object.
  71. """
  72. start_time_ms = int(round(response["created"] * 1000))
  73. end_time_ms = start_time_ms + int(round(time_elapsed * 1000))
  74. span = trace_tree.Span(
  75. name=f"{response.get('model', 'openai')}_{response['object']}_{response.get('created')}",
  76. attributes=dict(response), # type: ignore
  77. start_time_ms=start_time_ms,
  78. end_time_ms=end_time_ms,
  79. span_kind=trace_tree.SpanKind.LLM,
  80. results=results,
  81. )
  82. model_obj = {"request": request, "response": response, "_kind": "openai"}
  83. return trace_tree.WBTraceTree(root_span=span, model_dict=model_obj)
  84. def _resolve_edit(
  85. self,
  86. request: dict[str, Any],
  87. response: Response,
  88. time_elapsed: float,
  89. ) -> dict[str, Any]:
  90. """Resolves the request and response objects for `openai.Edit`."""
  91. request_str = (
  92. f"\n\n**Instruction**: {request['instruction']}\n\n"
  93. f"**Input**: {request['input']}\n"
  94. )
  95. choices = [
  96. f"\n\n**Edited**: {choice['text']}\n" for choice in response["choices"]
  97. ]
  98. return self._resolve_metrics(
  99. request=request,
  100. response=response,
  101. request_str=request_str,
  102. choices=choices,
  103. time_elapsed=time_elapsed,
  104. )
  105. def _resolve_completion(
  106. self,
  107. request: dict[str, Any],
  108. response: Response,
  109. time_elapsed: float,
  110. ) -> dict[str, Any]:
  111. """Resolves the request and response objects for `openai.Completion`."""
  112. request_str = f"\n\n**Prompt**: {request['prompt']}\n"
  113. choices = [
  114. f"\n\n**Completion**: {choice['text']}\n" for choice in response["choices"]
  115. ]
  116. return self._resolve_metrics(
  117. request=request,
  118. response=response,
  119. request_str=request_str,
  120. choices=choices,
  121. time_elapsed=time_elapsed,
  122. )
  123. def _resolve_chat_completion(
  124. self,
  125. request: dict[str, Any],
  126. response: Response,
  127. time_elapsed: float,
  128. ) -> dict[str, Any]:
  129. """Resolves the request and response objects for `openai.Completion`."""
  130. prompt = io.StringIO()
  131. for message in request["messages"]:
  132. prompt.write(f"\n\n**{message['role']}**: {message['content']}\n")
  133. request_str = prompt.getvalue()
  134. choices = [
  135. f"\n\n**{choice['message']['role']}**: {choice['message']['content']}\n"
  136. for choice in response["choices"]
  137. ]
  138. return self._resolve_metrics(
  139. request=request,
  140. response=response,
  141. request_str=request_str,
  142. choices=choices,
  143. time_elapsed=time_elapsed,
  144. )
  145. def _resolve_metrics(
  146. self,
  147. request: dict[str, Any],
  148. response: Response,
  149. request_str: str,
  150. choices: list[str],
  151. time_elapsed: float,
  152. ) -> dict[str, Any]:
  153. """Resolves the request and response objects for `openai.Completion`."""
  154. results = [
  155. trace_tree.Result(
  156. inputs={"request": request_str},
  157. outputs={"response": choice},
  158. )
  159. for choice in choices
  160. ]
  161. metrics = self._get_metrics_to_log(request, response, results, time_elapsed)
  162. return self._convert_metrics_to_dict(metrics)
  163. @staticmethod
  164. def _get_usage_metrics(response: Response, time_elapsed: float) -> UsageMetrics:
  165. """Gets the usage stats from the response object."""
  166. if response.get("usage"):
  167. usage_stats = UsageMetrics(**response["usage"])
  168. else:
  169. usage_stats = UsageMetrics()
  170. usage_stats.elapsed_time = time_elapsed
  171. return usage_stats
  172. def _get_metrics_to_log(
  173. self,
  174. request: dict[str, Any],
  175. response: Response,
  176. results: list[Any],
  177. time_elapsed: float,
  178. ) -> Metrics:
  179. model = response.get("model") or request.get("model")
  180. usage_metrics = self._get_usage_metrics(response, time_elapsed)
  181. usage = []
  182. for result in results:
  183. row = {
  184. "request": result.inputs["request"],
  185. "response": result.outputs["response"],
  186. "model": model,
  187. "start_time": datetime.datetime.fromtimestamp(response["created"]),
  188. "end_time": datetime.datetime.fromtimestamp(
  189. response["created"] + time_elapsed
  190. ),
  191. "request_id": response.get("id", None),
  192. "api_type": response.get("api_type", "openai"),
  193. "session_id": wandb.run.id,
  194. }
  195. row.update(asdict(usage_metrics))
  196. usage.append(row)
  197. usage_table = wandb.Table(
  198. columns=list(usage[0].keys()),
  199. data=[(item.values()) for item in usage],
  200. )
  201. trace = self.results_to_trace_tree(request, response, results, time_elapsed)
  202. metrics = Metrics(stats=usage_table, trace=trace, usage=usage_metrics)
  203. return metrics
  204. @staticmethod
  205. def _convert_metrics_to_dict(metrics: Metrics) -> dict[str, Any]:
  206. """Converts metrics to a dict."""
  207. metrics_dict = {
  208. "stats": metrics.stats,
  209. "trace": metrics.trace,
  210. }
  211. usage_stats = {f"usage/{k}": v for k, v in asdict(metrics.usage).items()}
  212. metrics_dict.update(usage_stats)
  213. return metrics_dict