langgraph.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. from functools import wraps
  2. from typing import Any, Callable, List, Optional
  3. import sentry_sdk
  4. from sentry_sdk.ai.utils import (
  5. set_data_normalized,
  6. normalize_message_roles,
  7. truncate_and_annotate_messages,
  8. )
  9. from sentry_sdk.consts import OP, SPANDATA
  10. from sentry_sdk.integrations import DidNotEnable, Integration
  11. from sentry_sdk.scope import should_send_default_pii
  12. from sentry_sdk.utils import safe_serialize
  13. try:
  14. from langgraph.graph import StateGraph
  15. from langgraph.pregel import Pregel
  16. except ImportError:
  17. raise DidNotEnable("langgraph not installed")
  18. class LanggraphIntegration(Integration):
  19. identifier = "langgraph"
  20. origin = f"auto.ai.{identifier}"
  21. def __init__(self: "LanggraphIntegration", include_prompts: bool = True) -> None:
  22. self.include_prompts = include_prompts
  23. @staticmethod
  24. def setup_once() -> None:
  25. # LangGraph lets users create agents using a StateGraph or the Functional API.
  26. # StateGraphs are then compiled to a CompiledStateGraph. Both CompiledStateGraph and
  27. # the functional API execute on a Pregel instance. Pregel is the runtime for the graph
  28. # and the invocation happens on Pregel, so patching the invoke methods takes care of both.
  29. # The streaming methods are not patched, because due to some internal reasons, LangGraph
  30. # will automatically patch the streaming methods to run through invoke, and by doing this
  31. # we prevent duplicate spans for invocations.
  32. StateGraph.compile = _wrap_state_graph_compile(StateGraph.compile)
  33. if hasattr(Pregel, "invoke"):
  34. Pregel.invoke = _wrap_pregel_invoke(Pregel.invoke)
  35. if hasattr(Pregel, "ainvoke"):
  36. Pregel.ainvoke = _wrap_pregel_ainvoke(Pregel.ainvoke)
  37. def _get_graph_name(graph_obj: "Any") -> "Optional[str]":
  38. for attr in ["name", "graph_name", "__name__", "_name"]:
  39. if hasattr(graph_obj, attr):
  40. name = getattr(graph_obj, attr)
  41. if name and isinstance(name, str):
  42. return name
  43. return None
  44. def _normalize_langgraph_message(message: "Any") -> "Any":
  45. if not hasattr(message, "content"):
  46. return None
  47. parsed = {"role": getattr(message, "type", None), "content": message.content}
  48. for attr in [
  49. "name",
  50. "tool_calls",
  51. "function_call",
  52. "tool_call_id",
  53. "response_metadata",
  54. ]:
  55. if hasattr(message, attr):
  56. value = getattr(message, attr)
  57. if value is not None:
  58. parsed[attr] = value
  59. return parsed
  60. def _parse_langgraph_messages(state: "Any") -> "Optional[List[Any]]":
  61. if not state:
  62. return None
  63. messages = None
  64. if isinstance(state, dict):
  65. messages = state.get("messages")
  66. elif hasattr(state, "messages"):
  67. messages = state.messages
  68. elif hasattr(state, "get") and callable(state.get):
  69. try:
  70. messages = state.get("messages")
  71. except Exception:
  72. pass
  73. if not messages or not isinstance(messages, (list, tuple)):
  74. return None
  75. normalized_messages = []
  76. for message in messages:
  77. try:
  78. normalized = _normalize_langgraph_message(message)
  79. if normalized:
  80. normalized_messages.append(normalized)
  81. except Exception:
  82. continue
  83. return normalized_messages if normalized_messages else None
  84. def _wrap_state_graph_compile(f: "Callable[..., Any]") -> "Callable[..., Any]":
  85. @wraps(f)
  86. def new_compile(self: "Any", *args: "Any", **kwargs: "Any") -> "Any":
  87. integration = sentry_sdk.get_client().get_integration(LanggraphIntegration)
  88. if integration is None:
  89. return f(self, *args, **kwargs)
  90. with sentry_sdk.start_span(
  91. op=OP.GEN_AI_CREATE_AGENT,
  92. origin=LanggraphIntegration.origin,
  93. ) as span:
  94. compiled_graph = f(self, *args, **kwargs)
  95. compiled_graph_name = getattr(compiled_graph, "name", None)
  96. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "create_agent")
  97. span.set_data(SPANDATA.GEN_AI_AGENT_NAME, compiled_graph_name)
  98. if compiled_graph_name:
  99. span.description = f"create_agent {compiled_graph_name}"
  100. else:
  101. span.description = "create_agent"
  102. if kwargs.get("model", None) is not None:
  103. span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, kwargs.get("model"))
  104. tools = None
  105. get_graph = getattr(compiled_graph, "get_graph", None)
  106. if get_graph and callable(get_graph):
  107. graph_obj = compiled_graph.get_graph()
  108. nodes = getattr(graph_obj, "nodes", None)
  109. if nodes and isinstance(nodes, dict):
  110. tools_node = nodes.get("tools")
  111. if tools_node:
  112. data = getattr(tools_node, "data", None)
  113. if data and hasattr(data, "tools_by_name"):
  114. tools = list(data.tools_by_name.keys())
  115. if tools is not None:
  116. span.set_data(SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, tools)
  117. return compiled_graph
  118. return new_compile
  119. def _wrap_pregel_invoke(f: "Callable[..., Any]") -> "Callable[..., Any]":
  120. @wraps(f)
  121. def new_invoke(self: "Any", *args: "Any", **kwargs: "Any") -> "Any":
  122. integration = sentry_sdk.get_client().get_integration(LanggraphIntegration)
  123. if integration is None:
  124. return f(self, *args, **kwargs)
  125. graph_name = _get_graph_name(self)
  126. span_name = (
  127. f"invoke_agent {graph_name}".strip() if graph_name else "invoke_agent"
  128. )
  129. with sentry_sdk.start_span(
  130. op=OP.GEN_AI_INVOKE_AGENT,
  131. name=span_name,
  132. origin=LanggraphIntegration.origin,
  133. ) as span:
  134. if graph_name:
  135. span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, graph_name)
  136. span.set_data(SPANDATA.GEN_AI_AGENT_NAME, graph_name)
  137. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
  138. # Store input messages to later compare with output
  139. input_messages = None
  140. if (
  141. len(args) > 0
  142. and should_send_default_pii()
  143. and integration.include_prompts
  144. ):
  145. input_messages = _parse_langgraph_messages(args[0])
  146. if input_messages:
  147. normalized_input_messages = normalize_message_roles(input_messages)
  148. scope = sentry_sdk.get_current_scope()
  149. messages_data = truncate_and_annotate_messages(
  150. normalized_input_messages, span, scope
  151. )
  152. if messages_data is not None:
  153. set_data_normalized(
  154. span,
  155. SPANDATA.GEN_AI_REQUEST_MESSAGES,
  156. messages_data,
  157. unpack=False,
  158. )
  159. result = f(self, *args, **kwargs)
  160. _set_response_attributes(span, input_messages, result, integration)
  161. return result
  162. return new_invoke
  163. def _wrap_pregel_ainvoke(f: "Callable[..., Any]") -> "Callable[..., Any]":
  164. @wraps(f)
  165. async def new_ainvoke(self: "Any", *args: "Any", **kwargs: "Any") -> "Any":
  166. integration = sentry_sdk.get_client().get_integration(LanggraphIntegration)
  167. if integration is None:
  168. return await f(self, *args, **kwargs)
  169. graph_name = _get_graph_name(self)
  170. span_name = (
  171. f"invoke_agent {graph_name}".strip() if graph_name else "invoke_agent"
  172. )
  173. with sentry_sdk.start_span(
  174. op=OP.GEN_AI_INVOKE_AGENT,
  175. name=span_name,
  176. origin=LanggraphIntegration.origin,
  177. ) as span:
  178. if graph_name:
  179. span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, graph_name)
  180. span.set_data(SPANDATA.GEN_AI_AGENT_NAME, graph_name)
  181. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
  182. input_messages = None
  183. if (
  184. len(args) > 0
  185. and should_send_default_pii()
  186. and integration.include_prompts
  187. ):
  188. input_messages = _parse_langgraph_messages(args[0])
  189. if input_messages:
  190. normalized_input_messages = normalize_message_roles(input_messages)
  191. scope = sentry_sdk.get_current_scope()
  192. messages_data = truncate_and_annotate_messages(
  193. normalized_input_messages, span, scope
  194. )
  195. if messages_data is not None:
  196. set_data_normalized(
  197. span,
  198. SPANDATA.GEN_AI_REQUEST_MESSAGES,
  199. messages_data,
  200. unpack=False,
  201. )
  202. result = await f(self, *args, **kwargs)
  203. _set_response_attributes(span, input_messages, result, integration)
  204. return result
  205. return new_ainvoke
  206. def _get_new_messages(
  207. input_messages: "Optional[List[Any]]", output_messages: "Optional[List[Any]]"
  208. ) -> "Optional[List[Any]]":
  209. """Extract only the new messages added during this invocation."""
  210. if not output_messages:
  211. return None
  212. if not input_messages:
  213. return output_messages
  214. # only return the new messages, aka the output messages that are not in the input messages
  215. input_count = len(input_messages)
  216. new_messages = (
  217. output_messages[input_count:] if len(output_messages) > input_count else []
  218. )
  219. return new_messages if new_messages else None
  220. def _extract_llm_response_text(messages: "Optional[List[Any]]") -> "Optional[str]":
  221. if not messages:
  222. return None
  223. for message in reversed(messages):
  224. if isinstance(message, dict):
  225. role = message.get("role")
  226. if role in ["assistant", "ai"]:
  227. content = message.get("content")
  228. if content and isinstance(content, str):
  229. return content
  230. return None
  231. def _extract_tool_calls(messages: "Optional[List[Any]]") -> "Optional[List[Any]]":
  232. if not messages:
  233. return None
  234. tool_calls = []
  235. for message in messages:
  236. if isinstance(message, dict):
  237. msg_tool_calls = message.get("tool_calls")
  238. if msg_tool_calls and isinstance(msg_tool_calls, list):
  239. tool_calls.extend(msg_tool_calls)
  240. return tool_calls if tool_calls else None
  241. def _set_usage_data(span: "sentry_sdk.tracing.Span", messages: "Any") -> None:
  242. input_tokens = 0
  243. output_tokens = 0
  244. total_tokens = 0
  245. for message in messages:
  246. response_metadata = message.get("response_metadata")
  247. if response_metadata is None:
  248. continue
  249. token_usage = response_metadata.get("token_usage")
  250. if not token_usage:
  251. continue
  252. input_tokens += int(token_usage.get("prompt_tokens", 0))
  253. output_tokens += int(token_usage.get("completion_tokens", 0))
  254. total_tokens += int(token_usage.get("total_tokens", 0))
  255. if input_tokens > 0:
  256. span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
  257. if output_tokens > 0:
  258. span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
  259. if total_tokens > 0:
  260. span.set_data(
  261. SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS,
  262. total_tokens,
  263. )
  264. def _set_response_model_name(span: "sentry_sdk.tracing.Span", messages: "Any") -> None:
  265. if len(messages) == 0:
  266. return
  267. last_message = messages[-1]
  268. response_metadata = last_message.get("response_metadata")
  269. if response_metadata is None:
  270. return
  271. model_name = response_metadata.get("model_name")
  272. if model_name is None:
  273. return
  274. set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, model_name)
  275. def _set_response_attributes(
  276. span: "Any",
  277. input_messages: "Optional[List[Any]]",
  278. result: "Any",
  279. integration: "LanggraphIntegration",
  280. ) -> None:
  281. parsed_response_messages = _parse_langgraph_messages(result)
  282. new_messages = _get_new_messages(input_messages, parsed_response_messages)
  283. if new_messages is None:
  284. return
  285. _set_usage_data(span, new_messages)
  286. _set_response_model_name(span, new_messages)
  287. if not (should_send_default_pii() and integration.include_prompts):
  288. return
  289. llm_response_text = _extract_llm_response_text(new_messages)
  290. if llm_response_text:
  291. set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, llm_response_text)
  292. elif new_messages:
  293. set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, new_messages)
  294. else:
  295. set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, result)
  296. tool_calls = _extract_tool_calls(new_messages)
  297. if tool_calls:
  298. set_data_normalized(
  299. span,
  300. SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
  301. safe_serialize(tool_calls),
  302. unpack=False,
  303. )