| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350 |
- from __future__ import annotations
- import logging
- from collections.abc import Sequence
- from datetime import datetime
- from typing import Any
- import wandb
- from wandb.sdk.integration_utils.auto_logging import Response
- from wandb.sdk.lib.runid import generate_id
- logger = logging.getLogger(__name__)
- def subset_dict(
- original_dict: dict[str, Any], keys_subset: Sequence[str]
- ) -> dict[str, Any]:
- """Create a subset of a dictionary using a subset of keys.
- :param original_dict: The original dictionary.
- :param keys_subset: The subset of keys to extract.
- :return: A dictionary containing only the specified keys.
- """
- return {key: original_dict[key] for key in keys_subset if key in original_dict}
- def reorder_and_convert_dict_list_to_table(
- data: list[dict[str, Any]], order: list[str]
- ) -> tuple[list[str], list[list[Any]]]:
- """Convert a list of dictionaries to a pair of column names and corresponding values, with the option to order specific dictionaries.
- :param data: A list of dictionaries.
- :param order: A list of keys specifying the desired order for specific dictionaries. The remaining dictionaries will be ordered based on their original order.
- :return: A pair of column names and corresponding values.
- """
- final_columns = []
- keys_present = set()
- # First, add all ordered keys to the final columns
- for key in order:
- if key not in keys_present:
- final_columns.append(key)
- keys_present.add(key)
- # Then, add any keys present in the dictionaries but not in the order
- for d in data:
- for key in d:
- if key not in keys_present:
- final_columns.append(key)
- keys_present.add(key)
- # Then, construct the table of values
- values = []
- for d in data:
- row = []
- for key in final_columns:
- row.append(d.get(key, None))
- values.append(row)
- return final_columns, values
- def flatten_dict(
- dictionary: dict[str, Any], parent_key: str = "", sep: str = "-"
- ) -> dict[str, Any]:
- """Flatten a nested dictionary, joining keys using a specified separator.
- :param dictionary: The dictionary to flatten.
- :param parent_key: The base key to prepend to each key.
- :param sep: The separator to use when joining keys.
- :return: A flattened dictionary.
- """
- flattened_dict = {}
- for key, value in dictionary.items():
- new_key = f"{parent_key}{sep}{key}" if parent_key else key
- if isinstance(value, dict):
- flattened_dict.update(flatten_dict(value, new_key, sep=sep))
- else:
- flattened_dict[new_key] = value
- return flattened_dict
- def collect_common_keys(list_of_dicts: list[dict[str, Any]]) -> dict[str, list[Any]]:
- """Collect the common keys of a list of dictionaries. For each common key, put its values into a list in the order they appear in the original dictionaries.
- :param list_of_dicts: The list of dictionaries to inspect.
- :return: A dictionary with each common key and its corresponding list of values.
- """
- common_keys = set.intersection(*map(set, list_of_dicts))
- common_dict = {key: [] for key in common_keys}
- for d in list_of_dicts:
- for key in common_keys:
- common_dict[key].append(d[key])
- return common_dict
- class CohereRequestResponseResolver:
- """Class to resolve the request/response from the Cohere API and convert it to a dictionary that can be logged."""
- def __call__(
- self,
- args: Sequence[Any],
- kwargs: dict[str, Any],
- response: Response,
- start_time: float,
- time_elapsed: float,
- ) -> dict[str, Any] | None:
- """Process the response from the Cohere API and convert it to a dictionary that can be logged.
- :param args: The arguments of the original function.
- :param kwargs: The keyword arguments of the original function.
- :param response: The response from the Cohere API.
- :param start_time: The start time of the request.
- :param time_elapsed: The time elapsed for the request.
- :return: A dictionary containing the parsed response and timing information.
- """
- try:
- # Each of the different endpoints map to one specific response type
- # We want to 'type check' the response without directly importing the packages type
- # It may make more sense to pass the invoked symbol from the AutologAPI instead
- response_type = str(type(response)).split("'")[1].split(".")[-1]
- # Initialize parsed_response to None to handle the case where the response type is unsupported
- parsed_response = None
- if response_type == "Generations":
- parsed_response = self._resolve_generate_response(response)
- # TODO: Remove hard-coded default model name
- table_column_order = [
- "start_time",
- "query_id",
- "model",
- "prompt",
- "text",
- "token_likelihoods",
- "likelihood",
- "time_elapsed_(seconds)",
- "end_time",
- ]
- default_model = "command"
- elif response_type == "Chat":
- parsed_response = self._resolve_chat_response(response)
- table_column_order = [
- "start_time",
- "query_id",
- "model",
- "conversation_id",
- "response_id",
- "query",
- "text",
- "prompt",
- "preamble",
- "chat_history",
- "chatlog",
- "time_elapsed_(seconds)",
- "end_time",
- ]
- default_model = "command"
- elif response_type == "Classifications":
- parsed_response = self._resolve_classify_response(response)
- kwargs = self._resolve_classify_kwargs(kwargs)
- table_column_order = [
- "start_time",
- "query_id",
- "model",
- "id",
- "input",
- "prediction",
- "confidence",
- "time_elapsed_(seconds)",
- "end_time",
- ]
- default_model = "embed-english-v2.0"
- elif response_type == "SummarizeResponse":
- parsed_response = self._resolve_summarize_response(response)
- table_column_order = [
- "start_time",
- "query_id",
- "model",
- "response_id",
- "text",
- "additional_command",
- "summary",
- "time_elapsed_(seconds)",
- "end_time",
- "length",
- "format",
- ]
- default_model = "summarize-xlarge"
- elif response_type == "Reranking":
- parsed_response = self._resolve_rerank_response(response)
- table_column_order = [
- "start_time",
- "query_id",
- "model",
- "id",
- "query",
- "top_n",
- # This is a nested dict key that got flattened
- "document-text",
- "relevance_score",
- "index",
- "time_elapsed_(seconds)",
- "end_time",
- ]
- default_model = "rerank-english-v2.0"
- else:
- logger.info(f"Unsupported Cohere response object: {response}")
- return self._resolve(
- args,
- kwargs,
- parsed_response,
- start_time,
- time_elapsed,
- response_type,
- table_column_order,
- default_model,
- )
- except Exception as e:
- logger.warning(f"Failed to resolve request/response: {e}")
- return None
- # These helper functions process the response from different endpoints of the Cohere API.
- # Since the response objects for different endpoints have different structures,
- # we need different logic to process them.
- def _resolve_generate_response(self, response: Response) -> list[dict[str, Any]]:
- return_list = []
- for _response in response:
- # Built in Cohere.*.Generations function to color token_likelihoods and return a dict of response data
- _response_dict = _response._visualize_helper()
- try:
- _response_dict["token_likelihoods"] = wandb.Html(
- _response_dict["token_likelihoods"]
- )
- except (KeyError, ValueError):
- pass
- return_list.append(_response_dict)
- return return_list
- def _resolve_chat_response(self, response: Response) -> list[dict[str, Any]]:
- return [
- subset_dict(
- response.__dict__,
- [
- "response_id",
- "generation_id",
- "query",
- "text",
- "conversation_id",
- "prompt",
- "chatlog",
- "preamble",
- ],
- )
- ]
- def _resolve_classify_response(self, response: Response) -> list[dict[str, Any]]:
- # The labels key is a dict returning the scores for the classification probability for each label provided
- # We flatten this nested dict for ease of consumption in the wandb UI
- return [flatten_dict(_response.__dict__) for _response in response]
- def _resolve_classify_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
- # Example texts look strange when rendered in Wandb UI as it is a list of text and label
- # We extract each value into its own column
- example_texts = []
- example_labels = []
- for example in kwargs["examples"]:
- example_texts.append(example.text)
- example_labels.append(example.label)
- kwargs.pop("examples")
- kwargs["example_texts"] = example_texts
- kwargs["example_labels"] = example_labels
- return kwargs
- def _resolve_summarize_response(self, response: Response) -> list[dict[str, Any]]:
- return [{"response_id": response.id, "summary": response.summary}]
- def _resolve_rerank_response(self, response: Response) -> list[dict[str, Any]]:
- # The documents key contains a dict containing the content of the document which is at least "text"
- # We flatten this nested dict for ease of consumption in the wandb UI
- flattened_response_dicts = [
- flatten_dict(_response.__dict__) for _response in response
- ]
- # ReRank returns each document provided a top_n value so we aggregate into one view so users can paginate a row
- # As opposed to each row being one of the top_n responses
- return_dict = collect_common_keys(flattened_response_dicts)
- return_dict["id"] = response.id
- return [return_dict]
- def _resolve(
- self,
- args: Sequence[Any],
- kwargs: dict[str, Any],
- parsed_response: list[dict[str, Any]],
- start_time: float,
- time_elapsed: float,
- response_type: str,
- table_column_order: list[str],
- default_model: str,
- ) -> dict[str, Any]:
- """Convert a list of dictionaries to a pair of column names and corresponding values, with the option to order specific dictionaries.
- :param args: The arguments passed to the API client.
- :param kwargs: The keyword arguments passed to the API client.
- :param parsed_response: The parsed response from the API.
- :param start_time: The start time of the API request.
- :param time_elapsed: The time elapsed during the API request.
- :param response_type: The type of the API response.
- :param table_column_order: The desired order of columns in the resulting table.
- :param default_model: The default model to use if not specified in the response.
- :return: A dictionary containing the formatted response.
- """
- # Args[0] is the client object where we can grab specific metadata about the underlying API status
- query_id = generate_id(length=16)
- parsed_args = subset_dict(
- args[0].__dict__,
- ["api_version", "batch_size", "max_retries", "num_workers", "timeout"],
- )
- start_time_dt = datetime.fromtimestamp(start_time)
- end_time_dt = datetime.fromtimestamp(start_time + time_elapsed)
- timings = {
- "start_time": start_time_dt,
- "end_time": end_time_dt,
- "time_elapsed_(seconds)": time_elapsed,
- }
- packed_data = []
- for _parsed_response in parsed_response:
- _packed_dict = {
- "query_id": query_id,
- **kwargs,
- **_parsed_response,
- **timings,
- **parsed_args,
- }
- if "model" not in _packed_dict:
- _packed_dict["model"] = default_model
- packed_data.append(_packed_dict)
- columns, data = reorder_and_convert_dict_list_to_table(
- packed_data, table_column_order
- )
- request_response_table = wandb.Table(data=data, columns=columns)
- return {f"{response_type}": request_response_table}
|