cohere.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. import sys
  2. from functools import wraps
  3. from sentry_sdk import consts
  4. from sentry_sdk.ai.monitoring import record_token_usage
  5. from sentry_sdk.consts import SPANDATA
  6. from sentry_sdk.ai.utils import set_data_normalized
  7. from typing import TYPE_CHECKING
  8. from sentry_sdk.tracing_utils import set_span_errored
  9. if TYPE_CHECKING:
  10. from typing import Any, Callable, Iterator
  11. from sentry_sdk.tracing import Span
  12. import sentry_sdk
  13. from sentry_sdk.scope import should_send_default_pii
  14. from sentry_sdk.integrations import DidNotEnable, Integration
  15. from sentry_sdk.utils import capture_internal_exceptions, event_from_exception, reraise
  16. try:
  17. from cohere.client import Client
  18. from cohere.base_client import BaseCohere
  19. from cohere import (
  20. ChatStreamEndEvent,
  21. NonStreamedChatResponse,
  22. )
  23. if TYPE_CHECKING:
  24. from cohere import StreamedChatResponse
  25. except ImportError:
  26. raise DidNotEnable("Cohere not installed")
  27. try:
  28. # cohere 5.9.3+
  29. from cohere import StreamEndStreamedChatResponse
  30. except ImportError:
  31. from cohere import StreamedChatResponse_StreamEnd as StreamEndStreamedChatResponse
  32. COLLECTED_CHAT_PARAMS = {
  33. "model": SPANDATA.AI_MODEL_ID,
  34. "k": SPANDATA.AI_TOP_K,
  35. "p": SPANDATA.AI_TOP_P,
  36. "seed": SPANDATA.AI_SEED,
  37. "frequency_penalty": SPANDATA.AI_FREQUENCY_PENALTY,
  38. "presence_penalty": SPANDATA.AI_PRESENCE_PENALTY,
  39. "raw_prompting": SPANDATA.AI_RAW_PROMPTING,
  40. }
  41. COLLECTED_PII_CHAT_PARAMS = {
  42. "tools": SPANDATA.AI_TOOLS,
  43. "preamble": SPANDATA.AI_PREAMBLE,
  44. }
  45. COLLECTED_CHAT_RESP_ATTRS = {
  46. "generation_id": SPANDATA.AI_GENERATION_ID,
  47. "is_search_required": SPANDATA.AI_SEARCH_REQUIRED,
  48. "finish_reason": SPANDATA.AI_FINISH_REASON,
  49. }
  50. COLLECTED_PII_CHAT_RESP_ATTRS = {
  51. "citations": SPANDATA.AI_CITATIONS,
  52. "documents": SPANDATA.AI_DOCUMENTS,
  53. "search_queries": SPANDATA.AI_SEARCH_QUERIES,
  54. "search_results": SPANDATA.AI_SEARCH_RESULTS,
  55. "tool_calls": SPANDATA.AI_TOOL_CALLS,
  56. }
  57. class CohereIntegration(Integration):
  58. identifier = "cohere"
  59. origin = f"auto.ai.{identifier}"
  60. def __init__(self: "CohereIntegration", include_prompts: bool = True) -> None:
  61. self.include_prompts = include_prompts
  62. @staticmethod
  63. def setup_once() -> None:
  64. BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False)
  65. Client.embed = _wrap_embed(Client.embed)
  66. BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True)
  67. def _capture_exception(exc: "Any") -> None:
  68. set_span_errored()
  69. event, hint = event_from_exception(
  70. exc,
  71. client_options=sentry_sdk.get_client().options,
  72. mechanism={"type": "cohere", "handled": False},
  73. )
  74. sentry_sdk.capture_event(event, hint=hint)
  75. def _wrap_chat(f: "Callable[..., Any]", streaming: bool) -> "Callable[..., Any]":
  76. def collect_chat_response_fields(
  77. span: "Span", res: "NonStreamedChatResponse", include_pii: bool
  78. ) -> None:
  79. if include_pii:
  80. if hasattr(res, "text"):
  81. set_data_normalized(
  82. span,
  83. SPANDATA.AI_RESPONSES,
  84. [res.text],
  85. )
  86. for pii_attr in COLLECTED_PII_CHAT_RESP_ATTRS:
  87. if hasattr(res, pii_attr):
  88. set_data_normalized(span, "ai." + pii_attr, getattr(res, pii_attr))
  89. for attr in COLLECTED_CHAT_RESP_ATTRS:
  90. if hasattr(res, attr):
  91. set_data_normalized(span, "ai." + attr, getattr(res, attr))
  92. if hasattr(res, "meta"):
  93. if hasattr(res.meta, "billed_units"):
  94. record_token_usage(
  95. span,
  96. input_tokens=res.meta.billed_units.input_tokens,
  97. output_tokens=res.meta.billed_units.output_tokens,
  98. )
  99. elif hasattr(res.meta, "tokens"):
  100. record_token_usage(
  101. span,
  102. input_tokens=res.meta.tokens.input_tokens,
  103. output_tokens=res.meta.tokens.output_tokens,
  104. )
  105. if hasattr(res.meta, "warnings"):
  106. set_data_normalized(span, SPANDATA.AI_WARNINGS, res.meta.warnings)
  107. @wraps(f)
  108. def new_chat(*args: "Any", **kwargs: "Any") -> "Any":
  109. integration = sentry_sdk.get_client().get_integration(CohereIntegration)
  110. if (
  111. integration is None
  112. or "message" not in kwargs
  113. or not isinstance(kwargs.get("message"), str)
  114. ):
  115. return f(*args, **kwargs)
  116. message = kwargs.get("message")
  117. span = sentry_sdk.start_span(
  118. op=consts.OP.COHERE_CHAT_COMPLETIONS_CREATE,
  119. name="cohere.client.Chat",
  120. origin=CohereIntegration.origin,
  121. )
  122. span.__enter__()
  123. try:
  124. res = f(*args, **kwargs)
  125. except Exception as e:
  126. exc_info = sys.exc_info()
  127. with capture_internal_exceptions():
  128. _capture_exception(e)
  129. span.__exit__(None, None, None)
  130. reraise(*exc_info)
  131. with capture_internal_exceptions():
  132. if should_send_default_pii() and integration.include_prompts:
  133. set_data_normalized(
  134. span,
  135. SPANDATA.AI_INPUT_MESSAGES,
  136. list(
  137. map(
  138. lambda x: {
  139. "role": getattr(x, "role", "").lower(),
  140. "content": getattr(x, "message", ""),
  141. },
  142. kwargs.get("chat_history", []),
  143. )
  144. )
  145. + [{"role": "user", "content": message}],
  146. )
  147. for k, v in COLLECTED_PII_CHAT_PARAMS.items():
  148. if k in kwargs:
  149. set_data_normalized(span, v, kwargs[k])
  150. for k, v in COLLECTED_CHAT_PARAMS.items():
  151. if k in kwargs:
  152. set_data_normalized(span, v, kwargs[k])
  153. set_data_normalized(span, SPANDATA.AI_STREAMING, False)
  154. if streaming:
  155. old_iterator = res
  156. def new_iterator() -> "Iterator[StreamedChatResponse]":
  157. with capture_internal_exceptions():
  158. for x in old_iterator:
  159. if isinstance(x, ChatStreamEndEvent) or isinstance(
  160. x, StreamEndStreamedChatResponse
  161. ):
  162. collect_chat_response_fields(
  163. span,
  164. x.response,
  165. include_pii=should_send_default_pii()
  166. and integration.include_prompts,
  167. )
  168. yield x
  169. span.__exit__(None, None, None)
  170. return new_iterator()
  171. elif isinstance(res, NonStreamedChatResponse):
  172. collect_chat_response_fields(
  173. span,
  174. res,
  175. include_pii=should_send_default_pii()
  176. and integration.include_prompts,
  177. )
  178. span.__exit__(None, None, None)
  179. else:
  180. set_data_normalized(span, "unknown_response", True)
  181. span.__exit__(None, None, None)
  182. return res
  183. return new_chat
  184. def _wrap_embed(f: "Callable[..., Any]") -> "Callable[..., Any]":
  185. @wraps(f)
  186. def new_embed(*args: "Any", **kwargs: "Any") -> "Any":
  187. integration = sentry_sdk.get_client().get_integration(CohereIntegration)
  188. if integration is None:
  189. return f(*args, **kwargs)
  190. with sentry_sdk.start_span(
  191. op=consts.OP.COHERE_EMBEDDINGS_CREATE,
  192. name="Cohere Embedding Creation",
  193. origin=CohereIntegration.origin,
  194. ) as span:
  195. if "texts" in kwargs and (
  196. should_send_default_pii() and integration.include_prompts
  197. ):
  198. if isinstance(kwargs["texts"], str):
  199. set_data_normalized(span, SPANDATA.AI_TEXTS, [kwargs["texts"]])
  200. elif (
  201. isinstance(kwargs["texts"], list)
  202. and len(kwargs["texts"]) > 0
  203. and isinstance(kwargs["texts"][0], str)
  204. ):
  205. set_data_normalized(
  206. span, SPANDATA.AI_INPUT_MESSAGES, kwargs["texts"]
  207. )
  208. if "model" in kwargs:
  209. set_data_normalized(span, SPANDATA.AI_MODEL_ID, kwargs["model"])
  210. try:
  211. res = f(*args, **kwargs)
  212. except Exception as e:
  213. exc_info = sys.exc_info()
  214. with capture_internal_exceptions():
  215. _capture_exception(e)
  216. reraise(*exc_info)
  217. if (
  218. hasattr(res, "meta")
  219. and hasattr(res.meta, "billed_units")
  220. and hasattr(res.meta.billed_units, "input_tokens")
  221. ):
  222. record_token_usage(
  223. span,
  224. input_tokens=res.meta.billed_units.input_tokens,
  225. total_tokens=res.meta.billed_units.input_tokens,
  226. )
  227. return res
  228. return new_embed