resolver.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. from __future__ import annotations
  2. import logging
  3. import os
  4. from collections.abc import Sequence
  5. from datetime import datetime
  6. from typing import Any
  7. import pytz
  8. import wandb
  9. from wandb.sdk.integration_utils.auto_logging import Response
  10. from wandb.sdk.lib.runid import generate_id
  11. logger = logging.getLogger(__name__)
  12. SUPPORTED_PIPELINE_TASKS = [
  13. "text-classification",
  14. "sentiment-analysis",
  15. "question-answering",
  16. "summarization",
  17. "translation",
  18. "text2text-generation",
  19. "text-generation",
  20. # "conversational",
  21. ]
  22. PIPELINES_WITH_TOP_K = [
  23. "text-classification",
  24. "sentiment-analysis",
  25. "question-answering",
  26. ]
  27. class HuggingFacePipelineRequestResponseResolver:
  28. """Resolver for HuggingFace's pipeline request and responses, providing necessary data transformations and formatting.
  29. This is based off (from wandb.sdk.integration_utils.auto_logging import RequestResponseResolver)
  30. """
  31. autolog_id = None
  32. def __call__(
  33. self,
  34. args: Sequence[Any],
  35. kwargs: dict[str, Any],
  36. response: Response,
  37. start_time: float,
  38. time_elapsed: float,
  39. ) -> dict[str, Any] | None:
  40. """Main call method for this class.
  41. :param args: list of arguments
  42. :param kwargs: dictionary of keyword arguments
  43. :param response: the response from the request
  44. :param start_time: time when request started
  45. :param time_elapsed: time elapsed for the request
  46. :returns: packed data as a dictionary for logging to wandb, None if an exception occurred
  47. """
  48. try:
  49. pipe, input_data = args[:2]
  50. task = pipe.task
  51. # Translation tasks are in the form of `translation_x_to_y`
  52. if task in SUPPORTED_PIPELINE_TASKS or task.startswith("translation"):
  53. model = self._get_model(pipe)
  54. if model is None:
  55. return None
  56. model_alias = model.name_or_path
  57. timestamp = datetime.now(pytz.utc)
  58. input_data, response = self._transform_task_specific_data(
  59. task, input_data, response
  60. )
  61. formatted_data = self._format_data(task, input_data, response, kwargs)
  62. packed_data = self._create_table(
  63. formatted_data, model_alias, timestamp, time_elapsed
  64. )
  65. table_name = os.environ.get("WANDB_AUTOLOG_TABLE_NAME", f"{task}")
  66. # TODO: Let users decide the name in a way that does not use an environment variable
  67. return {
  68. table_name: wandb.Table(
  69. columns=packed_data[0], data=packed_data[1:]
  70. )
  71. }
  72. logger.warning(
  73. f"The task: `{task}` is not yet supported.\nPlease contact `wandb` to notify us if you would like support for this task"
  74. )
  75. except Exception as e:
  76. logger.warning(e)
  77. return None
  78. # TODO: This should have a dependency on PreTrainedModel. i.e. isinstance(PreTrainedModel)
  79. # from transformers.modeling_utils import PreTrainedModel
  80. # We do not want this dependency explicitly in our codebase so we make a very general
  81. # assumption about the structure of the pipeline which may have unintended consequences
  82. def _get_model(self, pipe) -> Any | None:
  83. """Extracts model from the pipeline.
  84. :param pipe: the HuggingFace pipeline
  85. :returns: Model if available, None otherwise
  86. """
  87. model = pipe.model
  88. try:
  89. return model.model
  90. except AttributeError:
  91. logger.info(
  92. "Model does not have a `.model` attribute. Assuming `pipe.model` is the correct model."
  93. )
  94. return model
  95. @staticmethod
  96. def _transform_task_specific_data(
  97. task: str, input_data: list[Any] | Any, response: list[Any] | Any
  98. ) -> tuple[list[Any] | Any, list[Any] | Any]:
  99. """Transform input and response data based on specific tasks.
  100. :param task: the task name
  101. :param input_data: the input data
  102. :param response: the response data
  103. :returns: tuple of transformed input_data and response
  104. """
  105. if task == "question-answering":
  106. input_data = input_data if isinstance(input_data, list) else [input_data]
  107. input_data = [data.__dict__ for data in input_data]
  108. elif task == "conversational":
  109. # We only grab the latest input/output pair from the conversation
  110. # Logging the whole conversation renders strangely.
  111. input_data = input_data if isinstance(input_data, list) else [input_data]
  112. input_data = [data.__dict__["past_user_inputs"][-1] for data in input_data]
  113. response = response if isinstance(response, list) else [response]
  114. response = [data.__dict__["generated_responses"][-1] for data in response]
  115. return input_data, response
  116. def _format_data(
  117. self,
  118. task: str,
  119. input_data: list[Any] | Any,
  120. response: list[Any] | Any,
  121. kwargs: dict[str, Any],
  122. ) -> list[dict[str, Any]]:
  123. """Formats input data, response, and kwargs into a list of dictionaries.
  124. :param task: the task name
  125. :param input_data: the input data
  126. :param response: the response data
  127. :param kwargs: dictionary of keyword arguments
  128. :returns: list of dictionaries containing formatted data
  129. """
  130. input_data = input_data if isinstance(input_data, list) else [input_data]
  131. response = response if isinstance(response, list) else [response]
  132. formatted_data = []
  133. for i_text, r_text in zip(input_data, response):
  134. # Unpack single element responses for better rendering in wandb UI when it is a task without top_k
  135. # top_k = 1 would unpack the response into a single element while top_k > 1 would be a list
  136. # this would cause the UI to not properly concatenate the tables of the same task by omitting the elements past the first
  137. if (
  138. (isinstance(r_text, list))
  139. and (len(r_text) == 1)
  140. and task not in PIPELINES_WITH_TOP_K
  141. ):
  142. r_text = r_text[0]
  143. formatted_data.append(
  144. {"input": i_text, "response": r_text, "kwargs": kwargs}
  145. )
  146. return formatted_data
  147. def _create_table(
  148. self,
  149. formatted_data: list[dict[str, Any]],
  150. model_alias: str,
  151. timestamp: float,
  152. time_elapsed: float,
  153. ) -> list[list[Any]]:
  154. """Creates a table from formatted data, model alias, timestamp, and elapsed time.
  155. :param formatted_data: list of dictionaries containing formatted data
  156. :param model_alias: alias of the model
  157. :param timestamp: timestamp of the data
  158. :param time_elapsed: time elapsed from the beginning
  159. :returns: list of lists, representing a table of data. [0]th element = columns. [1]st element = data
  160. """
  161. header = [
  162. "ID",
  163. "Model Alias",
  164. "Timestamp",
  165. "Elapsed Time",
  166. "Input",
  167. "Response",
  168. "Kwargs",
  169. ]
  170. table = [header]
  171. autolog_id = generate_id(length=16)
  172. for data in formatted_data:
  173. row = [
  174. autolog_id,
  175. model_alias,
  176. timestamp,
  177. time_elapsed,
  178. data["input"],
  179. data["response"],
  180. data["kwargs"],
  181. ]
  182. table.append(row)
  183. self.autolog_id = autolog_id
  184. return table
  185. def get_latest_id(self):
  186. return self.autolog_id