huggingface_hub.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. import inspect
  2. import sys
  3. from functools import wraps
  4. from typing import TYPE_CHECKING
  5. import sentry_sdk
  6. from sentry_sdk.ai.monitoring import record_token_usage
  7. from sentry_sdk.ai.utils import set_data_normalized
  8. from sentry_sdk.consts import OP, SPANDATA
  9. from sentry_sdk.integrations import DidNotEnable, Integration
  10. from sentry_sdk.scope import should_send_default_pii
  11. from sentry_sdk.utils import (
  12. capture_internal_exceptions,
  13. event_from_exception,
  14. reraise,
  15. )
  16. if TYPE_CHECKING:
  17. from typing import Any, Callable, Iterable
  18. try:
  19. import huggingface_hub.inference._client
  20. except ImportError:
  21. raise DidNotEnable("Huggingface not installed")
  22. class HuggingfaceHubIntegration(Integration):
  23. identifier = "huggingface_hub"
  24. origin = f"auto.ai.{identifier}"
  25. def __init__(
  26. self: "HuggingfaceHubIntegration", include_prompts: bool = True
  27. ) -> None:
  28. self.include_prompts = include_prompts
  29. @staticmethod
  30. def setup_once() -> None:
  31. # Other tasks that can be called: https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks
  32. huggingface_hub.inference._client.InferenceClient.text_generation = (
  33. _wrap_huggingface_task(
  34. huggingface_hub.inference._client.InferenceClient.text_generation,
  35. OP.GEN_AI_TEXT_COMPLETION,
  36. )
  37. )
  38. huggingface_hub.inference._client.InferenceClient.chat_completion = (
  39. _wrap_huggingface_task(
  40. huggingface_hub.inference._client.InferenceClient.chat_completion,
  41. OP.GEN_AI_CHAT,
  42. )
  43. )
  44. def _capture_exception(exc: "Any") -> None:
  45. event, hint = event_from_exception(
  46. exc,
  47. client_options=sentry_sdk.get_client().options,
  48. mechanism={"type": "huggingface_hub", "handled": False},
  49. )
  50. sentry_sdk.capture_event(event, hint=hint)
  51. def _wrap_huggingface_task(f: "Callable[..., Any]", op: str) -> "Callable[..., Any]":
  52. @wraps(f)
  53. def new_huggingface_task(*args: "Any", **kwargs: "Any") -> "Any":
  54. integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration)
  55. if integration is None:
  56. return f(*args, **kwargs)
  57. prompt = None
  58. if "prompt" in kwargs:
  59. prompt = kwargs["prompt"]
  60. elif "messages" in kwargs:
  61. prompt = kwargs["messages"]
  62. elif len(args) >= 2:
  63. if isinstance(args[1], str) or isinstance(args[1], list):
  64. prompt = args[1]
  65. if prompt is None:
  66. # invalid call, dont instrument, let it return error
  67. return f(*args, **kwargs)
  68. client = args[0]
  69. model = client.model or kwargs.get("model") or ""
  70. operation_name = op.split(".")[-1]
  71. span = sentry_sdk.start_span(
  72. op=op,
  73. name=f"{operation_name} {model}",
  74. origin=HuggingfaceHubIntegration.origin,
  75. )
  76. span.__enter__()
  77. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, operation_name)
  78. if model:
  79. span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)
  80. # Input attributes
  81. if should_send_default_pii() and integration.include_prompts:
  82. set_data_normalized(
  83. span, SPANDATA.GEN_AI_REQUEST_MESSAGES, prompt, unpack=False
  84. )
  85. attribute_mapping = {
  86. "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
  87. "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
  88. "max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
  89. "presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
  90. "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
  91. "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
  92. "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
  93. "stream": SPANDATA.GEN_AI_RESPONSE_STREAMING,
  94. }
  95. for attribute, span_attribute in attribute_mapping.items():
  96. value = kwargs.get(attribute, None)
  97. if value is not None:
  98. if isinstance(value, (int, float, bool, str)):
  99. span.set_data(span_attribute, value)
  100. else:
  101. set_data_normalized(span, span_attribute, value, unpack=False)
  102. # LLM Execution
  103. try:
  104. res = f(*args, **kwargs)
  105. except Exception as e:
  106. exc_info = sys.exc_info()
  107. with capture_internal_exceptions():
  108. _capture_exception(e)
  109. span.__exit__(*exc_info)
  110. reraise(*exc_info)
  111. # Output attributes
  112. finish_reason = None
  113. response_model = None
  114. response_text_buffer: "list[str]" = []
  115. tokens_used = 0
  116. tool_calls = None
  117. usage = None
  118. with capture_internal_exceptions():
  119. if isinstance(res, str) and res is not None:
  120. response_text_buffer.append(res)
  121. if hasattr(res, "generated_text") and res.generated_text is not None:
  122. response_text_buffer.append(res.generated_text)
  123. if hasattr(res, "model") and res.model is not None:
  124. response_model = res.model
  125. if hasattr(res, "details") and hasattr(res.details, "finish_reason"):
  126. finish_reason = res.details.finish_reason
  127. if (
  128. hasattr(res, "details")
  129. and hasattr(res.details, "generated_tokens")
  130. and res.details.generated_tokens is not None
  131. ):
  132. tokens_used = res.details.generated_tokens
  133. if hasattr(res, "usage") and res.usage is not None:
  134. usage = res.usage
  135. if hasattr(res, "choices") and res.choices is not None:
  136. for choice in res.choices:
  137. if hasattr(choice, "finish_reason"):
  138. finish_reason = choice.finish_reason
  139. if hasattr(choice, "message") and hasattr(
  140. choice.message, "tool_calls"
  141. ):
  142. tool_calls = choice.message.tool_calls
  143. if (
  144. hasattr(choice, "message")
  145. and hasattr(choice.message, "content")
  146. and choice.message.content is not None
  147. ):
  148. response_text_buffer.append(choice.message.content)
  149. if response_model is not None:
  150. span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)
  151. if finish_reason is not None:
  152. set_data_normalized(
  153. span,
  154. SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
  155. finish_reason,
  156. )
  157. if should_send_default_pii() and integration.include_prompts:
  158. if tool_calls is not None and len(tool_calls) > 0:
  159. set_data_normalized(
  160. span,
  161. SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
  162. tool_calls,
  163. unpack=False,
  164. )
  165. if len(response_text_buffer) > 0:
  166. text_response = "".join(response_text_buffer)
  167. if text_response:
  168. set_data_normalized(
  169. span,
  170. SPANDATA.GEN_AI_RESPONSE_TEXT,
  171. text_response,
  172. )
  173. if usage is not None:
  174. record_token_usage(
  175. span,
  176. input_tokens=usage.prompt_tokens,
  177. output_tokens=usage.completion_tokens,
  178. total_tokens=usage.total_tokens,
  179. )
  180. elif tokens_used > 0:
  181. record_token_usage(
  182. span,
  183. total_tokens=tokens_used,
  184. )
  185. # If the response is not a generator (meaning a streaming response)
  186. # we are done and can return the response
  187. if not inspect.isgenerator(res):
  188. span.__exit__(None, None, None)
  189. return res
  190. if kwargs.get("details", False):
  191. # text-generation stream output
  192. def new_details_iterator() -> "Iterable[Any]":
  193. finish_reason = None
  194. response_text_buffer: "list[str]" = []
  195. tokens_used = 0
  196. with capture_internal_exceptions():
  197. for chunk in res:
  198. if (
  199. hasattr(chunk, "token")
  200. and hasattr(chunk.token, "text")
  201. and chunk.token.text is not None
  202. ):
  203. response_text_buffer.append(chunk.token.text)
  204. if hasattr(chunk, "details") and hasattr(
  205. chunk.details, "finish_reason"
  206. ):
  207. finish_reason = chunk.details.finish_reason
  208. if (
  209. hasattr(chunk, "details")
  210. and hasattr(chunk.details, "generated_tokens")
  211. and chunk.details.generated_tokens is not None
  212. ):
  213. tokens_used = chunk.details.generated_tokens
  214. yield chunk
  215. if finish_reason is not None:
  216. set_data_normalized(
  217. span,
  218. SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
  219. finish_reason,
  220. )
  221. if should_send_default_pii() and integration.include_prompts:
  222. if len(response_text_buffer) > 0:
  223. text_response = "".join(response_text_buffer)
  224. if text_response:
  225. set_data_normalized(
  226. span,
  227. SPANDATA.GEN_AI_RESPONSE_TEXT,
  228. text_response,
  229. )
  230. if tokens_used > 0:
  231. record_token_usage(
  232. span,
  233. total_tokens=tokens_used,
  234. )
  235. span.__exit__(None, None, None)
  236. return new_details_iterator()
  237. else:
  238. # chat-completion stream output
  239. def new_iterator() -> "Iterable[str]":
  240. finish_reason = None
  241. response_model = None
  242. response_text_buffer: "list[str]" = []
  243. tool_calls = None
  244. usage = None
  245. with capture_internal_exceptions():
  246. for chunk in res:
  247. if hasattr(chunk, "model") and chunk.model is not None:
  248. response_model = chunk.model
  249. if hasattr(chunk, "usage") and chunk.usage is not None:
  250. usage = chunk.usage
  251. if isinstance(chunk, str):
  252. if chunk is not None:
  253. response_text_buffer.append(chunk)
  254. if hasattr(chunk, "choices") and chunk.choices is not None:
  255. for choice in chunk.choices:
  256. if (
  257. hasattr(choice, "delta")
  258. and hasattr(choice.delta, "content")
  259. and choice.delta.content is not None
  260. ):
  261. response_text_buffer.append(
  262. choice.delta.content
  263. )
  264. if (
  265. hasattr(choice, "finish_reason")
  266. and choice.finish_reason is not None
  267. ):
  268. finish_reason = choice.finish_reason
  269. if (
  270. hasattr(choice, "delta")
  271. and hasattr(choice.delta, "tool_calls")
  272. and choice.delta.tool_calls is not None
  273. ):
  274. tool_calls = choice.delta.tool_calls
  275. yield chunk
  276. if response_model is not None:
  277. span.set_data(
  278. SPANDATA.GEN_AI_RESPONSE_MODEL, response_model
  279. )
  280. if finish_reason is not None:
  281. set_data_normalized(
  282. span,
  283. SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
  284. finish_reason,
  285. )
  286. if should_send_default_pii() and integration.include_prompts:
  287. if tool_calls is not None and len(tool_calls) > 0:
  288. set_data_normalized(
  289. span,
  290. SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
  291. tool_calls,
  292. unpack=False,
  293. )
  294. if len(response_text_buffer) > 0:
  295. text_response = "".join(response_text_buffer)
  296. if text_response:
  297. set_data_normalized(
  298. span,
  299. SPANDATA.GEN_AI_RESPONSE_TEXT,
  300. text_response,
  301. )
  302. if usage is not None:
  303. record_token_usage(
  304. span,
  305. input_tokens=usage.prompt_tokens,
  306. output_tokens=usage.completion_tokens,
  307. total_tokens=usage.total_tokens,
  308. )
  309. span.__exit__(None, None, None)
  310. return new_iterator()
  311. return new_huggingface_task