resolver.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. from __future__ import annotations
  2. import logging
  3. from collections.abc import Sequence
  4. from datetime import datetime
  5. from typing import Any
  6. import wandb
  7. from wandb.sdk.integration_utils.auto_logging import Response
  8. from wandb.sdk.lib.runid import generate_id
  9. logger = logging.getLogger(__name__)
  10. def subset_dict(
  11. original_dict: dict[str, Any], keys_subset: Sequence[str]
  12. ) -> dict[str, Any]:
  13. """Create a subset of a dictionary using a subset of keys.
  14. :param original_dict: The original dictionary.
  15. :param keys_subset: The subset of keys to extract.
  16. :return: A dictionary containing only the specified keys.
  17. """
  18. return {key: original_dict[key] for key in keys_subset if key in original_dict}
  19. def reorder_and_convert_dict_list_to_table(
  20. data: list[dict[str, Any]], order: list[str]
  21. ) -> tuple[list[str], list[list[Any]]]:
  22. """Convert a list of dictionaries to a pair of column names and corresponding values, with the option to order specific dictionaries.
  23. :param data: A list of dictionaries.
  24. :param order: A list of keys specifying the desired order for specific dictionaries. The remaining dictionaries will be ordered based on their original order.
  25. :return: A pair of column names and corresponding values.
  26. """
  27. final_columns = []
  28. keys_present = set()
  29. # First, add all ordered keys to the final columns
  30. for key in order:
  31. if key not in keys_present:
  32. final_columns.append(key)
  33. keys_present.add(key)
  34. # Then, add any keys present in the dictionaries but not in the order
  35. for d in data:
  36. for key in d:
  37. if key not in keys_present:
  38. final_columns.append(key)
  39. keys_present.add(key)
  40. # Then, construct the table of values
  41. values = []
  42. for d in data:
  43. row = []
  44. for key in final_columns:
  45. row.append(d.get(key, None))
  46. values.append(row)
  47. return final_columns, values
  48. def flatten_dict(
  49. dictionary: dict[str, Any], parent_key: str = "", sep: str = "-"
  50. ) -> dict[str, Any]:
  51. """Flatten a nested dictionary, joining keys using a specified separator.
  52. :param dictionary: The dictionary to flatten.
  53. :param parent_key: The base key to prepend to each key.
  54. :param sep: The separator to use when joining keys.
  55. :return: A flattened dictionary.
  56. """
  57. flattened_dict = {}
  58. for key, value in dictionary.items():
  59. new_key = f"{parent_key}{sep}{key}" if parent_key else key
  60. if isinstance(value, dict):
  61. flattened_dict.update(flatten_dict(value, new_key, sep=sep))
  62. else:
  63. flattened_dict[new_key] = value
  64. return flattened_dict
  65. def collect_common_keys(list_of_dicts: list[dict[str, Any]]) -> dict[str, list[Any]]:
  66. """Collect the common keys of a list of dictionaries. For each common key, put its values into a list in the order they appear in the original dictionaries.
  67. :param list_of_dicts: The list of dictionaries to inspect.
  68. :return: A dictionary with each common key and its corresponding list of values.
  69. """
  70. common_keys = set.intersection(*map(set, list_of_dicts))
  71. common_dict = {key: [] for key in common_keys}
  72. for d in list_of_dicts:
  73. for key in common_keys:
  74. common_dict[key].append(d[key])
  75. return common_dict
  76. class CohereRequestResponseResolver:
  77. """Class to resolve the request/response from the Cohere API and convert it to a dictionary that can be logged."""
  78. def __call__(
  79. self,
  80. args: Sequence[Any],
  81. kwargs: dict[str, Any],
  82. response: Response,
  83. start_time: float,
  84. time_elapsed: float,
  85. ) -> dict[str, Any] | None:
  86. """Process the response from the Cohere API and convert it to a dictionary that can be logged.
  87. :param args: The arguments of the original function.
  88. :param kwargs: The keyword arguments of the original function.
  89. :param response: The response from the Cohere API.
  90. :param start_time: The start time of the request.
  91. :param time_elapsed: The time elapsed for the request.
  92. :return: A dictionary containing the parsed response and timing information.
  93. """
  94. try:
  95. # Each of the different endpoints map to one specific response type
  96. # We want to 'type check' the response without directly importing the packages type
  97. # It may make more sense to pass the invoked symbol from the AutologAPI instead
  98. response_type = str(type(response)).split("'")[1].split(".")[-1]
  99. # Initialize parsed_response to None to handle the case where the response type is unsupported
  100. parsed_response = None
  101. if response_type == "Generations":
  102. parsed_response = self._resolve_generate_response(response)
  103. # TODO: Remove hard-coded default model name
  104. table_column_order = [
  105. "start_time",
  106. "query_id",
  107. "model",
  108. "prompt",
  109. "text",
  110. "token_likelihoods",
  111. "likelihood",
  112. "time_elapsed_(seconds)",
  113. "end_time",
  114. ]
  115. default_model = "command"
  116. elif response_type == "Chat":
  117. parsed_response = self._resolve_chat_response(response)
  118. table_column_order = [
  119. "start_time",
  120. "query_id",
  121. "model",
  122. "conversation_id",
  123. "response_id",
  124. "query",
  125. "text",
  126. "prompt",
  127. "preamble",
  128. "chat_history",
  129. "chatlog",
  130. "time_elapsed_(seconds)",
  131. "end_time",
  132. ]
  133. default_model = "command"
  134. elif response_type == "Classifications":
  135. parsed_response = self._resolve_classify_response(response)
  136. kwargs = self._resolve_classify_kwargs(kwargs)
  137. table_column_order = [
  138. "start_time",
  139. "query_id",
  140. "model",
  141. "id",
  142. "input",
  143. "prediction",
  144. "confidence",
  145. "time_elapsed_(seconds)",
  146. "end_time",
  147. ]
  148. default_model = "embed-english-v2.0"
  149. elif response_type == "SummarizeResponse":
  150. parsed_response = self._resolve_summarize_response(response)
  151. table_column_order = [
  152. "start_time",
  153. "query_id",
  154. "model",
  155. "response_id",
  156. "text",
  157. "additional_command",
  158. "summary",
  159. "time_elapsed_(seconds)",
  160. "end_time",
  161. "length",
  162. "format",
  163. ]
  164. default_model = "summarize-xlarge"
  165. elif response_type == "Reranking":
  166. parsed_response = self._resolve_rerank_response(response)
  167. table_column_order = [
  168. "start_time",
  169. "query_id",
  170. "model",
  171. "id",
  172. "query",
  173. "top_n",
  174. # This is a nested dict key that got flattened
  175. "document-text",
  176. "relevance_score",
  177. "index",
  178. "time_elapsed_(seconds)",
  179. "end_time",
  180. ]
  181. default_model = "rerank-english-v2.0"
  182. else:
  183. logger.info(f"Unsupported Cohere response object: {response}")
  184. return self._resolve(
  185. args,
  186. kwargs,
  187. parsed_response,
  188. start_time,
  189. time_elapsed,
  190. response_type,
  191. table_column_order,
  192. default_model,
  193. )
  194. except Exception as e:
  195. logger.warning(f"Failed to resolve request/response: {e}")
  196. return None
  197. # These helper functions process the response from different endpoints of the Cohere API.
  198. # Since the response objects for different endpoints have different structures,
  199. # we need different logic to process them.
  200. def _resolve_generate_response(self, response: Response) -> list[dict[str, Any]]:
  201. return_list = []
  202. for _response in response:
  203. # Built in Cohere.*.Generations function to color token_likelihoods and return a dict of response data
  204. _response_dict = _response._visualize_helper()
  205. try:
  206. _response_dict["token_likelihoods"] = wandb.Html(
  207. _response_dict["token_likelihoods"]
  208. )
  209. except (KeyError, ValueError):
  210. pass
  211. return_list.append(_response_dict)
  212. return return_list
  213. def _resolve_chat_response(self, response: Response) -> list[dict[str, Any]]:
  214. return [
  215. subset_dict(
  216. response.__dict__,
  217. [
  218. "response_id",
  219. "generation_id",
  220. "query",
  221. "text",
  222. "conversation_id",
  223. "prompt",
  224. "chatlog",
  225. "preamble",
  226. ],
  227. )
  228. ]
  229. def _resolve_classify_response(self, response: Response) -> list[dict[str, Any]]:
  230. # The labels key is a dict returning the scores for the classification probability for each label provided
  231. # We flatten this nested dict for ease of consumption in the wandb UI
  232. return [flatten_dict(_response.__dict__) for _response in response]
  233. def _resolve_classify_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
  234. # Example texts look strange when rendered in Wandb UI as it is a list of text and label
  235. # We extract each value into its own column
  236. example_texts = []
  237. example_labels = []
  238. for example in kwargs["examples"]:
  239. example_texts.append(example.text)
  240. example_labels.append(example.label)
  241. kwargs.pop("examples")
  242. kwargs["example_texts"] = example_texts
  243. kwargs["example_labels"] = example_labels
  244. return kwargs
  245. def _resolve_summarize_response(self, response: Response) -> list[dict[str, Any]]:
  246. return [{"response_id": response.id, "summary": response.summary}]
  247. def _resolve_rerank_response(self, response: Response) -> list[dict[str, Any]]:
  248. # The documents key contains a dict containing the content of the document which is at least "text"
  249. # We flatten this nested dict for ease of consumption in the wandb UI
  250. flattened_response_dicts = [
  251. flatten_dict(_response.__dict__) for _response in response
  252. ]
  253. # ReRank returns each document provided a top_n value so we aggregate into one view so users can paginate a row
  254. # As opposed to each row being one of the top_n responses
  255. return_dict = collect_common_keys(flattened_response_dicts)
  256. return_dict["id"] = response.id
  257. return [return_dict]
  258. def _resolve(
  259. self,
  260. args: Sequence[Any],
  261. kwargs: dict[str, Any],
  262. parsed_response: list[dict[str, Any]],
  263. start_time: float,
  264. time_elapsed: float,
  265. response_type: str,
  266. table_column_order: list[str],
  267. default_model: str,
  268. ) -> dict[str, Any]:
  269. """Convert a list of dictionaries to a pair of column names and corresponding values, with the option to order specific dictionaries.
  270. :param args: The arguments passed to the API client.
  271. :param kwargs: The keyword arguments passed to the API client.
  272. :param parsed_response: The parsed response from the API.
  273. :param start_time: The start time of the API request.
  274. :param time_elapsed: The time elapsed during the API request.
  275. :param response_type: The type of the API response.
  276. :param table_column_order: The desired order of columns in the resulting table.
  277. :param default_model: The default model to use if not specified in the response.
  278. :return: A dictionary containing the formatted response.
  279. """
  280. # Args[0] is the client object where we can grab specific metadata about the underlying API status
  281. query_id = generate_id(length=16)
  282. parsed_args = subset_dict(
  283. args[0].__dict__,
  284. ["api_version", "batch_size", "max_retries", "num_workers", "timeout"],
  285. )
  286. start_time_dt = datetime.fromtimestamp(start_time)
  287. end_time_dt = datetime.fromtimestamp(start_time + time_elapsed)
  288. timings = {
  289. "start_time": start_time_dt,
  290. "end_time": end_time_dt,
  291. "time_elapsed_(seconds)": time_elapsed,
  292. }
  293. packed_data = []
  294. for _parsed_response in parsed_response:
  295. _packed_dict = {
  296. "query_id": query_id,
  297. **kwargs,
  298. **_parsed_response,
  299. **timings,
  300. **parsed_args,
  301. }
  302. if "model" not in _packed_dict:
  303. _packed_dict["model"] = default_model
  304. packed_data.append(_packed_dict)
  305. columns, data = reorder_and_convert_dict_list_to_table(
  306. packed_data, table_column_order
  307. )
  308. request_response_table = wandb.Table(data=data, columns=columns)
  309. return {f"{response_type}": request_response_table}