streaming.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. from typing import TYPE_CHECKING, Any, List, TypedDict, Optional
  2. from sentry_sdk.ai.utils import set_data_normalized
  3. from sentry_sdk.consts import SPANDATA
  4. from sentry_sdk.scope import should_send_default_pii
  5. from sentry_sdk.utils import (
  6. safe_serialize,
  7. )
  8. from .utils import (
  9. extract_tool_calls,
  10. extract_finish_reasons,
  11. extract_contents_text,
  12. extract_usage_data,
  13. UsageData,
  14. )
  15. if TYPE_CHECKING:
  16. from sentry_sdk.tracing import Span
  17. from google.genai.types import GenerateContentResponse
  18. class AccumulatedResponse(TypedDict):
  19. id: "Optional[str]"
  20. model: "Optional[str]"
  21. text: str
  22. finish_reasons: "List[str]"
  23. tool_calls: "List[dict[str, Any]]"
  24. usage_metadata: "Optional[UsageData]"
  25. def element_wise_usage_max(self: "UsageData", other: "UsageData") -> "UsageData":
  26. return UsageData(
  27. input_tokens=max(self["input_tokens"], other["input_tokens"]),
  28. output_tokens=max(self["output_tokens"], other["output_tokens"]),
  29. input_tokens_cached=max(
  30. self["input_tokens_cached"], other["input_tokens_cached"]
  31. ),
  32. output_tokens_reasoning=max(
  33. self["output_tokens_reasoning"], other["output_tokens_reasoning"]
  34. ),
  35. total_tokens=max(self["total_tokens"], other["total_tokens"]),
  36. )
  37. def accumulate_streaming_response(
  38. chunks: "List[GenerateContentResponse]",
  39. ) -> "AccumulatedResponse":
  40. """Accumulate streaming chunks into a single response-like object."""
  41. accumulated_text = []
  42. finish_reasons = []
  43. tool_calls = []
  44. usage_data = None
  45. response_id = None
  46. model = None
  47. for chunk in chunks:
  48. # Extract text and tool calls
  49. if getattr(chunk, "candidates", None):
  50. for candidate in getattr(chunk, "candidates", []):
  51. if hasattr(candidate, "content") and getattr(
  52. candidate.content, "parts", []
  53. ):
  54. extracted_text = extract_contents_text(candidate.content)
  55. if extracted_text:
  56. accumulated_text.append(extracted_text)
  57. extracted_finish_reasons = extract_finish_reasons(chunk)
  58. if extracted_finish_reasons:
  59. finish_reasons.extend(extracted_finish_reasons)
  60. extracted_tool_calls = extract_tool_calls(chunk)
  61. if extracted_tool_calls:
  62. tool_calls.extend(extracted_tool_calls)
  63. # Use last possible chunk, in case of interruption, and
  64. # gracefully handle missing intermediate tokens by taking maximum
  65. # with previous token reporting.
  66. chunk_usage_data = extract_usage_data(chunk)
  67. usage_data = (
  68. chunk_usage_data
  69. if usage_data is None
  70. else element_wise_usage_max(usage_data, chunk_usage_data)
  71. )
  72. accumulated_response = AccumulatedResponse(
  73. text="".join(accumulated_text),
  74. finish_reasons=finish_reasons,
  75. tool_calls=tool_calls,
  76. usage_metadata=usage_data,
  77. id=response_id,
  78. model=model,
  79. )
  80. return accumulated_response
  81. def set_span_data_for_streaming_response(
  82. span: "Span", integration: "Any", accumulated_response: "AccumulatedResponse"
  83. ) -> None:
  84. """Set span data for accumulated streaming response."""
  85. if (
  86. should_send_default_pii()
  87. and integration.include_prompts
  88. and accumulated_response.get("text")
  89. ):
  90. span.set_data(
  91. SPANDATA.GEN_AI_RESPONSE_TEXT,
  92. safe_serialize([accumulated_response["text"]]),
  93. )
  94. if accumulated_response.get("finish_reasons"):
  95. set_data_normalized(
  96. span,
  97. SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
  98. accumulated_response["finish_reasons"],
  99. )
  100. if accumulated_response.get("tool_calls"):
  101. span.set_data(
  102. SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
  103. safe_serialize(accumulated_response["tool_calls"]),
  104. )
  105. if accumulated_response.get("id"):
  106. span.set_data(SPANDATA.GEN_AI_RESPONSE_ID, accumulated_response["id"])
  107. if accumulated_response.get("model"):
  108. span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, accumulated_response["model"])
  109. if accumulated_response["usage_metadata"] is None:
  110. return
  111. if accumulated_response["usage_metadata"]["input_tokens"]:
  112. span.set_data(
  113. SPANDATA.GEN_AI_USAGE_INPUT_TOKENS,
  114. accumulated_response["usage_metadata"]["input_tokens"],
  115. )
  116. if accumulated_response["usage_metadata"]["input_tokens_cached"]:
  117. span.set_data(
  118. SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
  119. accumulated_response["usage_metadata"]["input_tokens_cached"],
  120. )
  121. if accumulated_response["usage_metadata"]["output_tokens"]:
  122. span.set_data(
  123. SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS,
  124. accumulated_response["usage_metadata"]["output_tokens"],
  125. )
  126. if accumulated_response["usage_metadata"]["output_tokens_reasoning"]:
  127. span.set_data(
  128. SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
  129. accumulated_response["usage_metadata"]["output_tokens_reasoning"],
  130. )
  131. if accumulated_response["usage_metadata"]["total_tokens"]:
  132. span.set_data(
  133. SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS,
  134. accumulated_response["usage_metadata"]["total_tokens"],
  135. )