langchain.py 42 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255
  1. import contextvars
  2. import itertools
  3. import sys
  4. import json
  5. import warnings
  6. from collections import OrderedDict
  7. from functools import wraps
  8. from typing import TYPE_CHECKING
  9. import sentry_sdk
  10. from sentry_sdk.ai.monitoring import set_ai_pipeline_name
  11. from sentry_sdk.ai.utils import (
  12. GEN_AI_ALLOWED_MESSAGE_ROLES,
  13. get_start_span_function,
  14. normalize_message_roles,
  15. set_data_normalized,
  16. truncate_and_annotate_messages,
  17. transform_content_part,
  18. )
  19. from sentry_sdk.consts import OP, SPANDATA
  20. from sentry_sdk.integrations import DidNotEnable, Integration
  21. from sentry_sdk.scope import should_send_default_pii
  22. from sentry_sdk.tracing_utils import _get_value, set_span_errored
  23. from sentry_sdk.utils import capture_internal_exceptions, logger
  24. if TYPE_CHECKING:
  25. from typing import (
  26. Any,
  27. AsyncIterator,
  28. Callable,
  29. Dict,
  30. Iterator,
  31. List,
  32. Optional,
  33. Union,
  34. )
  35. from uuid import UUID
  36. from sentry_sdk.tracing import Span
  37. from sentry_sdk._types import TextPart
  38. try:
  39. from langchain_core.agents import AgentFinish
  40. from langchain_core.callbacks import (
  41. BaseCallbackHandler,
  42. BaseCallbackManager,
  43. Callbacks,
  44. manager,
  45. )
  46. from langchain_core.messages import BaseMessage
  47. from langchain_core.outputs import LLMResult
  48. except ImportError:
  49. raise DidNotEnable("langchain not installed")
  50. try:
  51. # >=v1
  52. from langchain_classic.agents import AgentExecutor # type: ignore[import-not-found]
  53. except ImportError:
  54. try:
  55. # <v1
  56. from langchain.agents import AgentExecutor
  57. except ImportError:
  58. AgentExecutor = None
  59. # Conditional imports for embeddings providers
  60. try:
  61. from langchain_openai import OpenAIEmbeddings # type: ignore[import-not-found]
  62. except ImportError:
  63. OpenAIEmbeddings = None
  64. try:
  65. from langchain_openai import AzureOpenAIEmbeddings
  66. except ImportError:
  67. AzureOpenAIEmbeddings = None
  68. try:
  69. from langchain_google_vertexai import VertexAIEmbeddings # type: ignore[import-not-found]
  70. except ImportError:
  71. VertexAIEmbeddings = None
  72. try:
  73. from langchain_aws import BedrockEmbeddings # type: ignore[import-not-found]
  74. except ImportError:
  75. BedrockEmbeddings = None
  76. try:
  77. from langchain_cohere import CohereEmbeddings # type: ignore[import-not-found]
  78. except ImportError:
  79. CohereEmbeddings = None
  80. try:
  81. from langchain_mistralai import MistralAIEmbeddings # type: ignore[import-not-found]
  82. except ImportError:
  83. MistralAIEmbeddings = None
  84. try:
  85. from langchain_huggingface import HuggingFaceEmbeddings # type: ignore[import-not-found]
  86. except ImportError:
  87. HuggingFaceEmbeddings = None
  88. try:
  89. from langchain_ollama import OllamaEmbeddings # type: ignore[import-not-found]
  90. except ImportError:
  91. OllamaEmbeddings = None
  92. def _get_ai_system(all_params: "Dict[str, Any]") -> "Optional[str]":
  93. ai_type = all_params.get("_type")
  94. if not ai_type or not isinstance(ai_type, str):
  95. return None
  96. return ai_type
  97. DATA_FIELDS = {
  98. "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
  99. "function_call": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
  100. "max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
  101. "presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
  102. "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
  103. "tool_calls": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
  104. "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
  105. "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
  106. }
  107. def _transform_langchain_content_block(
  108. content_block: "Dict[str, Any]",
  109. ) -> "Dict[str, Any]":
  110. """
  111. Transform a LangChain content block using the shared transform_content_part function.
  112. Returns the original content block if transformation is not applicable
  113. (e.g., for text blocks or unrecognized formats).
  114. """
  115. result = transform_content_part(content_block)
  116. return result if result is not None else content_block
  117. def _transform_langchain_message_content(content: "Any") -> "Any":
  118. """
  119. Transform LangChain message content, handling both string content and
  120. list of content blocks.
  121. """
  122. if isinstance(content, str):
  123. return content
  124. if isinstance(content, (list, tuple)):
  125. transformed = []
  126. for block in content:
  127. if isinstance(block, dict):
  128. transformed.append(_transform_langchain_content_block(block))
  129. else:
  130. transformed.append(block)
  131. return transformed
  132. return content
  133. # Contextvar to track agent names in a stack for re-entrant agent support
  134. _agent_stack: "contextvars.ContextVar[Optional[List[Optional[str]]]]" = (
  135. contextvars.ContextVar("langchain_agent_stack", default=None)
  136. )
  137. def _push_agent(agent_name: "Optional[str]") -> None:
  138. """Push an agent name onto the stack."""
  139. stack = _agent_stack.get()
  140. if stack is None:
  141. stack = []
  142. else:
  143. # Copy the list to maintain contextvar isolation across async contexts
  144. stack = stack.copy()
  145. stack.append(agent_name)
  146. _agent_stack.set(stack)
  147. def _pop_agent() -> "Optional[str]":
  148. """Pop an agent name from the stack and return it."""
  149. stack = _agent_stack.get()
  150. if stack:
  151. # Copy the list to maintain contextvar isolation across async contexts
  152. stack = stack.copy()
  153. agent_name = stack.pop()
  154. _agent_stack.set(stack)
  155. return agent_name
  156. return None
  157. def _get_current_agent() -> "Optional[str]":
  158. """Get the current agent name (top of stack) without removing it."""
  159. stack = _agent_stack.get()
  160. if stack:
  161. return stack[-1]
  162. return None
  163. def _get_system_instructions(messages: "List[List[BaseMessage]]") -> "List[str]":
  164. system_instructions = []
  165. for list_ in messages:
  166. for message in list_:
  167. # type of content: str | list[str | dict] | None
  168. if message.type == "system" and isinstance(message.content, str):
  169. system_instructions.append(message.content)
  170. elif message.type == "system" and isinstance(message.content, list):
  171. for item in message.content:
  172. if isinstance(item, str):
  173. system_instructions.append(item)
  174. elif isinstance(item, dict) and item.get("type") == "text":
  175. instruction = item.get("text")
  176. if isinstance(instruction, str):
  177. system_instructions.append(instruction)
  178. return system_instructions
  179. def _transform_system_instructions(
  180. system_instructions: "List[str]",
  181. ) -> "List[TextPart]":
  182. return [
  183. {
  184. "type": "text",
  185. "content": instruction,
  186. }
  187. for instruction in system_instructions
  188. ]
  189. class LangchainIntegration(Integration):
  190. identifier = "langchain"
  191. origin = f"auto.ai.{identifier}"
  192. def __init__(
  193. self: "LangchainIntegration",
  194. include_prompts: bool = True,
  195. max_spans: "Optional[int]" = None,
  196. ) -> None:
  197. self.include_prompts = include_prompts
  198. self.max_spans = max_spans
  199. if max_spans is not None:
  200. warnings.warn(
  201. "The `max_spans` parameter of `LangchainIntegration` is "
  202. "deprecated and will be removed in version 3.0 of sentry-sdk.",
  203. DeprecationWarning,
  204. stacklevel=2,
  205. )
  206. @staticmethod
  207. def setup_once() -> None:
  208. manager._configure = _wrap_configure(manager._configure)
  209. if AgentExecutor is not None:
  210. AgentExecutor.invoke = _wrap_agent_executor_invoke(AgentExecutor.invoke)
  211. AgentExecutor.stream = _wrap_agent_executor_stream(AgentExecutor.stream)
  212. # Patch embeddings providers
  213. _patch_embeddings_provider(OpenAIEmbeddings)
  214. _patch_embeddings_provider(AzureOpenAIEmbeddings)
  215. _patch_embeddings_provider(VertexAIEmbeddings)
  216. _patch_embeddings_provider(BedrockEmbeddings)
  217. _patch_embeddings_provider(CohereEmbeddings)
  218. _patch_embeddings_provider(MistralAIEmbeddings)
  219. _patch_embeddings_provider(HuggingFaceEmbeddings)
  220. _patch_embeddings_provider(OllamaEmbeddings)
  221. class WatchedSpan:
  222. span: "Span" = None # type: ignore[assignment]
  223. children: "List[WatchedSpan]" = []
  224. is_pipeline: bool = False
  225. def __init__(self, span: "Span") -> None:
  226. self.span = span
  227. class SentryLangchainCallback(BaseCallbackHandler): # type: ignore[misc]
  228. """Callback handler that creates Sentry spans."""
  229. def __init__(
  230. self, max_span_map_size: "Optional[int]", include_prompts: bool
  231. ) -> None:
  232. self.span_map: "OrderedDict[UUID, WatchedSpan]" = OrderedDict()
  233. self.max_span_map_size = max_span_map_size
  234. self.include_prompts = include_prompts
  235. def gc_span_map(self) -> None:
  236. if self.max_span_map_size is not None:
  237. while len(self.span_map) > self.max_span_map_size:
  238. run_id, watched_span = self.span_map.popitem(last=False)
  239. self._exit_span(watched_span, run_id)
  240. def _handle_error(self, run_id: "UUID", error: "Any") -> None:
  241. with capture_internal_exceptions():
  242. if not run_id or run_id not in self.span_map:
  243. return
  244. span_data = self.span_map[run_id]
  245. span = span_data.span
  246. set_span_errored(span)
  247. sentry_sdk.capture_exception(error, span.scope)
  248. span.__exit__(None, None, None)
  249. del self.span_map[run_id]
  250. def _normalize_langchain_message(self, message: "BaseMessage") -> "Any":
  251. # Transform content to handle multimodal data (images, audio, video, files)
  252. transformed_content = _transform_langchain_message_content(message.content)
  253. parsed = {"role": message.type, "content": transformed_content}
  254. parsed.update(message.additional_kwargs)
  255. return parsed
  256. def _create_span(
  257. self: "SentryLangchainCallback",
  258. run_id: "UUID",
  259. parent_id: "Optional[Any]",
  260. **kwargs: "Any",
  261. ) -> "WatchedSpan":
  262. watched_span: "Optional[WatchedSpan]" = None
  263. if parent_id:
  264. parent_span: "Optional[WatchedSpan]" = self.span_map.get(parent_id)
  265. if parent_span:
  266. watched_span = WatchedSpan(parent_span.span.start_child(**kwargs))
  267. parent_span.children.append(watched_span)
  268. if watched_span is None:
  269. watched_span = WatchedSpan(sentry_sdk.start_span(**kwargs))
  270. watched_span.span.__enter__()
  271. self.span_map[run_id] = watched_span
  272. self.gc_span_map()
  273. return watched_span
  274. def _exit_span(
  275. self: "SentryLangchainCallback", span_data: "WatchedSpan", run_id: "UUID"
  276. ) -> None:
  277. if span_data.is_pipeline:
  278. set_ai_pipeline_name(None)
  279. span_data.span.__exit__(None, None, None)
  280. del self.span_map[run_id]
  281. def on_llm_start(
  282. self: "SentryLangchainCallback",
  283. serialized: "Dict[str, Any]",
  284. prompts: "List[str]",
  285. *,
  286. run_id: "UUID",
  287. tags: "Optional[List[str]]" = None,
  288. parent_run_id: "Optional[UUID]" = None,
  289. metadata: "Optional[Dict[str, Any]]" = None,
  290. **kwargs: "Any",
  291. ) -> "Any":
  292. with capture_internal_exceptions():
  293. if not run_id:
  294. return
  295. all_params = kwargs.get("invocation_params", {})
  296. all_params.update(serialized.get("kwargs", {}))
  297. model = (
  298. all_params.get("model")
  299. or all_params.get("model_name")
  300. or all_params.get("model_id")
  301. or ""
  302. )
  303. watched_span = self._create_span(
  304. run_id,
  305. parent_run_id,
  306. op=OP.GEN_AI_TEXT_COMPLETION,
  307. name=f"text_completion {model}".strip(),
  308. origin=LangchainIntegration.origin,
  309. )
  310. span = watched_span.span
  311. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "text_completion")
  312. pipeline_name = kwargs.get("name")
  313. if pipeline_name:
  314. span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, pipeline_name)
  315. if model:
  316. span.set_data(
  317. SPANDATA.GEN_AI_REQUEST_MODEL,
  318. model,
  319. )
  320. ai_system = _get_ai_system(all_params)
  321. if ai_system:
  322. span.set_data(SPANDATA.GEN_AI_SYSTEM, ai_system)
  323. for key, attribute in DATA_FIELDS.items():
  324. if key in all_params and all_params[key] is not None:
  325. set_data_normalized(span, attribute, all_params[key], unpack=False)
  326. _set_tools_on_span(span, all_params.get("tools"))
  327. if should_send_default_pii() and self.include_prompts:
  328. normalized_messages = [
  329. {
  330. "role": GEN_AI_ALLOWED_MESSAGE_ROLES.USER,
  331. "content": {"type": "text", "text": prompt},
  332. }
  333. for prompt in prompts
  334. ]
  335. scope = sentry_sdk.get_current_scope()
  336. messages_data = truncate_and_annotate_messages(
  337. normalized_messages, span, scope
  338. )
  339. if messages_data is not None:
  340. set_data_normalized(
  341. span,
  342. SPANDATA.GEN_AI_REQUEST_MESSAGES,
  343. messages_data,
  344. unpack=False,
  345. )
  346. def on_chat_model_start(
  347. self: "SentryLangchainCallback",
  348. serialized: "Dict[str, Any]",
  349. messages: "List[List[BaseMessage]]",
  350. *,
  351. run_id: "UUID",
  352. **kwargs: "Any",
  353. ) -> "Any":
  354. """Run when Chat Model starts running."""
  355. with capture_internal_exceptions():
  356. if not run_id:
  357. return
  358. all_params = kwargs.get("invocation_params", {})
  359. all_params.update(serialized.get("kwargs", {}))
  360. model = (
  361. all_params.get("model")
  362. or all_params.get("model_name")
  363. or all_params.get("model_id")
  364. or ""
  365. )
  366. watched_span = self._create_span(
  367. run_id,
  368. kwargs.get("parent_run_id"),
  369. op=OP.GEN_AI_CHAT,
  370. name=f"chat {model}".strip(),
  371. origin=LangchainIntegration.origin,
  372. )
  373. span = watched_span.span
  374. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
  375. if model:
  376. span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)
  377. ai_system = _get_ai_system(all_params)
  378. if ai_system:
  379. span.set_data(SPANDATA.GEN_AI_SYSTEM, ai_system)
  380. agent_name = _get_current_agent()
  381. if agent_name:
  382. span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)
  383. for key, attribute in DATA_FIELDS.items():
  384. if key in all_params and all_params[key] is not None:
  385. set_data_normalized(span, attribute, all_params[key], unpack=False)
  386. _set_tools_on_span(span, all_params.get("tools"))
  387. if should_send_default_pii() and self.include_prompts:
  388. system_instructions = _get_system_instructions(messages)
  389. if len(system_instructions) > 0:
  390. span.set_data(
  391. SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS,
  392. json.dumps(_transform_system_instructions(system_instructions)),
  393. )
  394. normalized_messages = []
  395. for list_ in messages:
  396. for message in list_:
  397. if message.type == "system":
  398. continue
  399. normalized_messages.append(
  400. self._normalize_langchain_message(message)
  401. )
  402. normalized_messages = normalize_message_roles(normalized_messages)
  403. scope = sentry_sdk.get_current_scope()
  404. messages_data = truncate_and_annotate_messages(
  405. normalized_messages, span, scope
  406. )
  407. if messages_data is not None:
  408. set_data_normalized(
  409. span,
  410. SPANDATA.GEN_AI_REQUEST_MESSAGES,
  411. messages_data,
  412. unpack=False,
  413. )
  414. def on_chat_model_end(
  415. self: "SentryLangchainCallback",
  416. response: "LLMResult",
  417. *,
  418. run_id: "UUID",
  419. **kwargs: "Any",
  420. ) -> "Any":
  421. """Run when Chat Model ends running."""
  422. with capture_internal_exceptions():
  423. if not run_id or run_id not in self.span_map:
  424. return
  425. span_data = self.span_map[run_id]
  426. span = span_data.span
  427. if should_send_default_pii() and self.include_prompts:
  428. set_data_normalized(
  429. span,
  430. SPANDATA.GEN_AI_RESPONSE_TEXT,
  431. [[x.text for x in list_] for list_ in response.generations],
  432. )
  433. _record_token_usage(span, response)
  434. self._exit_span(span_data, run_id)
  435. def on_llm_end(
  436. self: "SentryLangchainCallback",
  437. response: "LLMResult",
  438. *,
  439. run_id: "UUID",
  440. **kwargs: "Any",
  441. ) -> "Any":
  442. """Run when LLM ends running."""
  443. with capture_internal_exceptions():
  444. if not run_id or run_id not in self.span_map:
  445. return
  446. span_data = self.span_map[run_id]
  447. span = span_data.span
  448. try:
  449. generation = response.generations[0][0]
  450. except IndexError:
  451. generation = None
  452. if generation is not None:
  453. try:
  454. response_model = generation.message.response_metadata.get(
  455. "model_name"
  456. )
  457. if response_model is not None:
  458. span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)
  459. except AttributeError:
  460. pass
  461. try:
  462. finish_reason = generation.generation_info.get("finish_reason")
  463. if finish_reason is not None:
  464. span.set_data(
  465. SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
  466. [finish_reason],
  467. )
  468. except AttributeError:
  469. pass
  470. try:
  471. if should_send_default_pii() and self.include_prompts:
  472. tool_calls = getattr(generation.message, "tool_calls", None)
  473. if tool_calls is not None and tool_calls != []:
  474. set_data_normalized(
  475. span,
  476. SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
  477. tool_calls,
  478. unpack=False,
  479. )
  480. except AttributeError:
  481. pass
  482. if should_send_default_pii() and self.include_prompts:
  483. set_data_normalized(
  484. span,
  485. SPANDATA.GEN_AI_RESPONSE_TEXT,
  486. [[x.text for x in list_] for list_ in response.generations],
  487. )
  488. _record_token_usage(span, response)
  489. self._exit_span(span_data, run_id)
  490. def on_llm_error(
  491. self: "SentryLangchainCallback",
  492. error: "Union[Exception, KeyboardInterrupt]",
  493. *,
  494. run_id: "UUID",
  495. **kwargs: "Any",
  496. ) -> "Any":
  497. """Run when LLM errors."""
  498. self._handle_error(run_id, error)
  499. def on_chat_model_error(
  500. self: "SentryLangchainCallback",
  501. error: "Union[Exception, KeyboardInterrupt]",
  502. *,
  503. run_id: "UUID",
  504. **kwargs: "Any",
  505. ) -> "Any":
  506. """Run when Chat Model errors."""
  507. self._handle_error(run_id, error)
  508. def on_agent_finish(
  509. self: "SentryLangchainCallback",
  510. finish: "AgentFinish",
  511. *,
  512. run_id: "UUID",
  513. **kwargs: "Any",
  514. ) -> "Any":
  515. with capture_internal_exceptions():
  516. if not run_id or run_id not in self.span_map:
  517. return
  518. span_data = self.span_map[run_id]
  519. span = span_data.span
  520. if should_send_default_pii() and self.include_prompts:
  521. set_data_normalized(
  522. span, SPANDATA.GEN_AI_RESPONSE_TEXT, finish.return_values.items()
  523. )
  524. self._exit_span(span_data, run_id)
  525. def on_tool_start(
  526. self: "SentryLangchainCallback",
  527. serialized: "Dict[str, Any]",
  528. input_str: str,
  529. *,
  530. run_id: "UUID",
  531. **kwargs: "Any",
  532. ) -> "Any":
  533. """Run when tool starts running."""
  534. with capture_internal_exceptions():
  535. if not run_id:
  536. return
  537. tool_name = serialized.get("name") or kwargs.get("name") or ""
  538. watched_span = self._create_span(
  539. run_id,
  540. kwargs.get("parent_run_id"),
  541. op=OP.GEN_AI_EXECUTE_TOOL,
  542. name=f"execute_tool {tool_name}".strip(),
  543. origin=LangchainIntegration.origin,
  544. )
  545. span = watched_span.span
  546. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "execute_tool")
  547. span.set_data(SPANDATA.GEN_AI_TOOL_NAME, tool_name)
  548. tool_description = serialized.get("description")
  549. if tool_description is not None:
  550. span.set_data(SPANDATA.GEN_AI_TOOL_DESCRIPTION, tool_description)
  551. agent_name = _get_current_agent()
  552. if agent_name:
  553. span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)
  554. if should_send_default_pii() and self.include_prompts:
  555. set_data_normalized(
  556. span,
  557. SPANDATA.GEN_AI_TOOL_INPUT,
  558. kwargs.get("inputs", [input_str]),
  559. )
  560. def on_tool_end(
  561. self: "SentryLangchainCallback", output: str, *, run_id: "UUID", **kwargs: "Any"
  562. ) -> "Any":
  563. """Run when tool ends running."""
  564. with capture_internal_exceptions():
  565. if not run_id or run_id not in self.span_map:
  566. return
  567. span_data = self.span_map[run_id]
  568. span = span_data.span
  569. if should_send_default_pii() and self.include_prompts:
  570. set_data_normalized(span, SPANDATA.GEN_AI_TOOL_OUTPUT, output)
  571. self._exit_span(span_data, run_id)
  572. def on_tool_error(
  573. self,
  574. error: "SentryLangchainCallback",
  575. *args: "Union[Exception, KeyboardInterrupt]",
  576. run_id: "UUID",
  577. **kwargs: "Any",
  578. ) -> "Any":
  579. """Run when tool errors."""
  580. self._handle_error(run_id, error)
  581. def _extract_tokens(
  582. token_usage: "Any",
  583. ) -> "tuple[Optional[int], Optional[int], Optional[int]]":
  584. if not token_usage:
  585. return None, None, None
  586. input_tokens = _get_value(token_usage, "prompt_tokens") or _get_value(
  587. token_usage, "input_tokens"
  588. )
  589. output_tokens = _get_value(token_usage, "completion_tokens") or _get_value(
  590. token_usage, "output_tokens"
  591. )
  592. total_tokens = _get_value(token_usage, "total_tokens")
  593. return input_tokens, output_tokens, total_tokens
  594. def _extract_tokens_from_generations(
  595. generations: "Any",
  596. ) -> "tuple[Optional[int], Optional[int], Optional[int]]":
  597. """Extract token usage from response.generations structure."""
  598. if not generations:
  599. return None, None, None
  600. total_input = 0
  601. total_output = 0
  602. total_total = 0
  603. for gen_list in generations:
  604. for gen in gen_list:
  605. token_usage = _get_token_usage(gen)
  606. input_tokens, output_tokens, total_tokens = _extract_tokens(token_usage)
  607. total_input += input_tokens if input_tokens is not None else 0
  608. total_output += output_tokens if output_tokens is not None else 0
  609. total_total += total_tokens if total_tokens is not None else 0
  610. return (
  611. total_input if total_input > 0 else None,
  612. total_output if total_output > 0 else None,
  613. total_total if total_total > 0 else None,
  614. )
  615. def _get_token_usage(obj: "Any") -> "Optional[Dict[str, Any]]":
  616. """
  617. Check multiple paths to extract token usage from different objects.
  618. """
  619. possible_names = ("usage", "token_usage", "usage_metadata")
  620. message = _get_value(obj, "message")
  621. if message is not None:
  622. for name in possible_names:
  623. usage = _get_value(message, name)
  624. if usage is not None:
  625. return usage
  626. llm_output = _get_value(obj, "llm_output")
  627. if llm_output is not None:
  628. for name in possible_names:
  629. usage = _get_value(llm_output, name)
  630. if usage is not None:
  631. return usage
  632. for name in possible_names:
  633. usage = _get_value(obj, name)
  634. if usage is not None:
  635. return usage
  636. return None
  637. def _record_token_usage(span: "Span", response: "Any") -> None:
  638. token_usage = _get_token_usage(response)
  639. if token_usage:
  640. input_tokens, output_tokens, total_tokens = _extract_tokens(token_usage)
  641. else:
  642. input_tokens, output_tokens, total_tokens = _extract_tokens_from_generations(
  643. response.generations
  644. )
  645. if input_tokens is not None:
  646. span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
  647. if output_tokens is not None:
  648. span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
  649. if total_tokens is not None:
  650. span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_tokens)
  651. def _get_request_data(
  652. obj: "Any", args: "Any", kwargs: "Any"
  653. ) -> "tuple[Optional[str], Optional[List[Any]]]":
  654. """
  655. Get the agent name and available tools for the agent.
  656. """
  657. agent = getattr(obj, "agent", None)
  658. runnable = getattr(agent, "runnable", None)
  659. runnable_config = getattr(runnable, "config", {})
  660. tools = (
  661. getattr(obj, "tools", None)
  662. or getattr(agent, "tools", None)
  663. or runnable_config.get("tools")
  664. or runnable_config.get("available_tools")
  665. )
  666. tools = tools if tools and len(tools) > 0 else None
  667. try:
  668. agent_name = None
  669. if len(args) > 1:
  670. agent_name = args[1].get("run_name")
  671. if agent_name is None:
  672. agent_name = runnable_config.get("run_name")
  673. except Exception:
  674. pass
  675. return (agent_name, tools)
  676. def _simplify_langchain_tools(tools: "Any") -> "Optional[List[Any]]":
  677. """Parse and simplify tools into a cleaner format."""
  678. if not tools:
  679. return None
  680. if not isinstance(tools, (list, tuple)):
  681. return None
  682. simplified_tools = []
  683. for tool in tools:
  684. try:
  685. if isinstance(tool, dict):
  686. if "function" in tool and isinstance(tool["function"], dict):
  687. func = tool["function"]
  688. simplified_tool = {
  689. "name": func.get("name"),
  690. "description": func.get("description"),
  691. }
  692. if simplified_tool["name"]:
  693. simplified_tools.append(simplified_tool)
  694. elif "name" in tool:
  695. simplified_tool = {
  696. "name": tool.get("name"),
  697. "description": tool.get("description"),
  698. }
  699. simplified_tools.append(simplified_tool)
  700. else:
  701. name = (
  702. tool.get("name")
  703. or tool.get("tool_name")
  704. or tool.get("function_name")
  705. )
  706. if name:
  707. simplified_tools.append(
  708. {
  709. "name": name,
  710. "description": tool.get("description")
  711. or tool.get("desc"),
  712. }
  713. )
  714. elif hasattr(tool, "name"):
  715. simplified_tool = {
  716. "name": getattr(tool, "name", None),
  717. "description": getattr(tool, "description", None)
  718. or getattr(tool, "desc", None),
  719. }
  720. if simplified_tool["name"]:
  721. simplified_tools.append(simplified_tool)
  722. elif hasattr(tool, "__name__"):
  723. simplified_tools.append(
  724. {
  725. "name": tool.__name__,
  726. "description": getattr(tool, "__doc__", None),
  727. }
  728. )
  729. else:
  730. tool_str = str(tool)
  731. if tool_str and tool_str != "":
  732. simplified_tools.append({"name": tool_str, "description": None})
  733. except Exception:
  734. continue
  735. return simplified_tools if simplified_tools else None
  736. def _set_tools_on_span(span: "Span", tools: "Any") -> None:
  737. """Set available tools data on a span if tools are provided."""
  738. if tools is not None:
  739. simplified_tools = _simplify_langchain_tools(tools)
  740. if simplified_tools:
  741. set_data_normalized(
  742. span,
  743. SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
  744. simplified_tools,
  745. unpack=False,
  746. )
  747. def _wrap_configure(f: "Callable[..., Any]") -> "Callable[..., Any]":
  748. @wraps(f)
  749. def new_configure(
  750. callback_manager_cls: type,
  751. inheritable_callbacks: "Callbacks" = None,
  752. local_callbacks: "Callbacks" = None,
  753. *args: "Any",
  754. **kwargs: "Any",
  755. ) -> "Any":
  756. integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
  757. if integration is None:
  758. return f(
  759. callback_manager_cls,
  760. inheritable_callbacks,
  761. local_callbacks,
  762. *args,
  763. **kwargs,
  764. )
  765. local_callbacks = local_callbacks or []
  766. # Handle each possible type of local_callbacks. For each type, we
  767. # extract the list of callbacks to check for SentryLangchainCallback,
  768. # and define a function that would add the SentryLangchainCallback
  769. # to the existing callbacks list.
  770. if isinstance(local_callbacks, BaseCallbackManager):
  771. callbacks_list = local_callbacks.handlers
  772. elif isinstance(local_callbacks, BaseCallbackHandler):
  773. callbacks_list = [local_callbacks]
  774. elif isinstance(local_callbacks, list):
  775. callbacks_list = local_callbacks
  776. else:
  777. logger.debug("Unknown callback type: %s", local_callbacks)
  778. # Just proceed with original function call
  779. return f(
  780. callback_manager_cls,
  781. inheritable_callbacks,
  782. local_callbacks,
  783. *args,
  784. **kwargs,
  785. )
  786. # Handle each possible type of inheritable_callbacks.
  787. if isinstance(inheritable_callbacks, BaseCallbackManager):
  788. inheritable_callbacks_list = inheritable_callbacks.handlers
  789. elif isinstance(inheritable_callbacks, list):
  790. inheritable_callbacks_list = inheritable_callbacks
  791. else:
  792. inheritable_callbacks_list = []
  793. if not any(
  794. isinstance(cb, SentryLangchainCallback)
  795. for cb in itertools.chain(callbacks_list, inheritable_callbacks_list)
  796. ):
  797. sentry_handler = SentryLangchainCallback(
  798. integration.max_spans,
  799. integration.include_prompts,
  800. )
  801. if isinstance(local_callbacks, BaseCallbackManager):
  802. local_callbacks = local_callbacks.copy()
  803. local_callbacks.handlers = [
  804. *local_callbacks.handlers,
  805. sentry_handler,
  806. ]
  807. elif isinstance(local_callbacks, BaseCallbackHandler):
  808. local_callbacks = [local_callbacks, sentry_handler]
  809. else:
  810. local_callbacks = [*local_callbacks, sentry_handler]
  811. return f(
  812. callback_manager_cls,
  813. inheritable_callbacks,
  814. local_callbacks,
  815. *args,
  816. **kwargs,
  817. )
  818. return new_configure
  819. def _wrap_agent_executor_invoke(f: "Callable[..., Any]") -> "Callable[..., Any]":
  820. @wraps(f)
  821. def new_invoke(self: "Any", *args: "Any", **kwargs: "Any") -> "Any":
  822. integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
  823. if integration is None:
  824. return f(self, *args, **kwargs)
  825. agent_name, tools = _get_request_data(self, args, kwargs)
  826. start_span_function = get_start_span_function()
  827. with start_span_function(
  828. op=OP.GEN_AI_INVOKE_AGENT,
  829. name=f"invoke_agent {agent_name}" if agent_name else "invoke_agent",
  830. origin=LangchainIntegration.origin,
  831. ) as span:
  832. _push_agent(agent_name)
  833. try:
  834. if agent_name:
  835. span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)
  836. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
  837. span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, False)
  838. _set_tools_on_span(span, tools)
  839. # Run the agent
  840. result = f(self, *args, **kwargs)
  841. input = result.get("input")
  842. if (
  843. input is not None
  844. and should_send_default_pii()
  845. and integration.include_prompts
  846. ):
  847. normalized_messages = normalize_message_roles([input])
  848. scope = sentry_sdk.get_current_scope()
  849. messages_data = truncate_and_annotate_messages(
  850. normalized_messages, span, scope
  851. )
  852. if messages_data is not None:
  853. set_data_normalized(
  854. span,
  855. SPANDATA.GEN_AI_REQUEST_MESSAGES,
  856. messages_data,
  857. unpack=False,
  858. )
  859. output = result.get("output")
  860. if (
  861. output is not None
  862. and should_send_default_pii()
  863. and integration.include_prompts
  864. ):
  865. set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)
  866. return result
  867. finally:
  868. # Ensure agent is popped even if an exception occurs
  869. _pop_agent()
  870. return new_invoke
  871. def _wrap_agent_executor_stream(f: "Callable[..., Any]") -> "Callable[..., Any]":
  872. @wraps(f)
  873. def new_stream(self: "Any", *args: "Any", **kwargs: "Any") -> "Any":
  874. integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
  875. if integration is None:
  876. return f(self, *args, **kwargs)
  877. agent_name, tools = _get_request_data(self, args, kwargs)
  878. start_span_function = get_start_span_function()
  879. span = start_span_function(
  880. op=OP.GEN_AI_INVOKE_AGENT,
  881. name=f"invoke_agent {agent_name}" if agent_name else "invoke_agent",
  882. origin=LangchainIntegration.origin,
  883. )
  884. span.__enter__()
  885. _push_agent(agent_name)
  886. if agent_name:
  887. span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)
  888. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
  889. span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)
  890. _set_tools_on_span(span, tools)
  891. input = args[0].get("input") if len(args) >= 1 else None
  892. if (
  893. input is not None
  894. and should_send_default_pii()
  895. and integration.include_prompts
  896. ):
  897. normalized_messages = normalize_message_roles([input])
  898. scope = sentry_sdk.get_current_scope()
  899. messages_data = truncate_and_annotate_messages(
  900. normalized_messages, span, scope
  901. )
  902. if messages_data is not None:
  903. set_data_normalized(
  904. span,
  905. SPANDATA.GEN_AI_REQUEST_MESSAGES,
  906. messages_data,
  907. unpack=False,
  908. )
  909. # Run the agent
  910. result = f(self, *args, **kwargs)
  911. old_iterator = result
  912. def new_iterator() -> "Iterator[Any]":
  913. exc_info: "tuple[Any, Any, Any]" = (None, None, None)
  914. try:
  915. for event in old_iterator:
  916. yield event
  917. try:
  918. output = event.get("output")
  919. except Exception:
  920. output = None
  921. if (
  922. output is not None
  923. and should_send_default_pii()
  924. and integration.include_prompts
  925. ):
  926. set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)
  927. except Exception:
  928. exc_info = sys.exc_info()
  929. set_span_errored(span)
  930. raise
  931. finally:
  932. # Ensure cleanup happens even if iterator is abandoned or fails
  933. _pop_agent()
  934. span.__exit__(*exc_info)
  935. async def new_iterator_async() -> "AsyncIterator[Any]":
  936. exc_info: "tuple[Any, Any, Any]" = (None, None, None)
  937. try:
  938. async for event in old_iterator:
  939. yield event
  940. try:
  941. output = event.get("output")
  942. except Exception:
  943. output = None
  944. if (
  945. output is not None
  946. and should_send_default_pii()
  947. and integration.include_prompts
  948. ):
  949. set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)
  950. except Exception:
  951. exc_info = sys.exc_info()
  952. set_span_errored(span)
  953. raise
  954. finally:
  955. # Ensure cleanup happens even if iterator is abandoned or fails
  956. _pop_agent()
  957. span.__exit__(*exc_info)
  958. if str(type(result)) == "<class 'async_generator'>":
  959. result = new_iterator_async()
  960. else:
  961. result = new_iterator()
  962. return result
  963. return new_stream
  964. def _patch_embeddings_provider(provider_class: "Any") -> None:
  965. """Patch an embeddings provider class with monitoring wrappers."""
  966. if provider_class is None:
  967. return
  968. if hasattr(provider_class, "embed_documents"):
  969. provider_class.embed_documents = _wrap_embedding_method(
  970. provider_class.embed_documents
  971. )
  972. if hasattr(provider_class, "embed_query"):
  973. provider_class.embed_query = _wrap_embedding_method(provider_class.embed_query)
  974. if hasattr(provider_class, "aembed_documents"):
  975. provider_class.aembed_documents = _wrap_async_embedding_method(
  976. provider_class.aembed_documents
  977. )
  978. if hasattr(provider_class, "aembed_query"):
  979. provider_class.aembed_query = _wrap_async_embedding_method(
  980. provider_class.aembed_query
  981. )
  982. def _wrap_embedding_method(f: "Callable[..., Any]") -> "Callable[..., Any]":
  983. """Wrap sync embedding methods (embed_documents and embed_query)."""
  984. @wraps(f)
  985. def new_embedding_method(self: "Any", *args: "Any", **kwargs: "Any") -> "Any":
  986. integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
  987. if integration is None:
  988. return f(self, *args, **kwargs)
  989. model_name = getattr(self, "model", None) or getattr(self, "model_name", None)
  990. with sentry_sdk.start_span(
  991. op=OP.GEN_AI_EMBEDDINGS,
  992. name=f"embeddings {model_name}" if model_name else "embeddings",
  993. origin=LangchainIntegration.origin,
  994. ) as span:
  995. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
  996. if model_name:
  997. span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
  998. # Capture input if PII is allowed
  999. if (
  1000. should_send_default_pii()
  1001. and integration.include_prompts
  1002. and len(args) > 0
  1003. ):
  1004. input_data = args[0]
  1005. # Normalize to list format
  1006. texts = input_data if isinstance(input_data, list) else [input_data]
  1007. set_data_normalized(
  1008. span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, texts, unpack=False
  1009. )
  1010. result = f(self, *args, **kwargs)
  1011. return result
  1012. return new_embedding_method
  1013. def _wrap_async_embedding_method(f: "Callable[..., Any]") -> "Callable[..., Any]":
  1014. """Wrap async embedding methods (aembed_documents and aembed_query)."""
  1015. @wraps(f)
  1016. async def new_async_embedding_method(
  1017. self: "Any", *args: "Any", **kwargs: "Any"
  1018. ) -> "Any":
  1019. integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
  1020. if integration is None:
  1021. return await f(self, *args, **kwargs)
  1022. model_name = getattr(self, "model", None) or getattr(self, "model_name", None)
  1023. with sentry_sdk.start_span(
  1024. op=OP.GEN_AI_EMBEDDINGS,
  1025. name=f"embeddings {model_name}" if model_name else "embeddings",
  1026. origin=LangchainIntegration.origin,
  1027. ) as span:
  1028. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
  1029. if model_name:
  1030. span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
  1031. # Capture input if PII is allowed
  1032. if (
  1033. should_send_default_pii()
  1034. and integration.include_prompts
  1035. and len(args) > 0
  1036. ):
  1037. input_data = args[0]
  1038. # Normalize to list format
  1039. texts = input_data if isinstance(input_data, list) else [input_data]
  1040. set_data_normalized(
  1041. span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, texts, unpack=False
  1042. )
  1043. result = await f(self, *args, **kwargs)
  1044. return result
  1045. return new_async_embedding_method