monitoring.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import inspect
  2. import sys
  3. from functools import wraps
  4. from sentry_sdk.consts import SPANDATA
  5. import sentry_sdk.utils
  6. from sentry_sdk import start_span
  7. from sentry_sdk.tracing import Span
  8. from sentry_sdk.utils import ContextVar, reraise, capture_internal_exceptions
  9. from typing import TYPE_CHECKING
  10. if TYPE_CHECKING:
  11. from typing import Optional, Callable, Awaitable, Any, Union, TypeVar
  12. F = TypeVar("F", bound=Union[Callable[..., Any], Callable[..., Awaitable[Any]]])
  # Context variable holding the name of the AI pipeline that is currently
  # running; it defaults to None, which means no pipeline is active.
  13. _ai_pipeline_name = ContextVar("ai_pipeline_name", default=None)
  def set_ai_pipeline_name(name: "Optional[str]") -> None:
      """
      Store ``name`` as the active AI pipeline name.

      Passing ``None`` clears the stored name.
      """
      _ai_pipeline_name.set(name)
  def get_ai_pipeline_name() -> "Optional[str]":
      """
      Return the stored AI pipeline name, or ``None`` when none has been set.
      """
      current_name = _ai_pipeline_name.get()
      return current_name
  def ai_track(description: str, **span_kwargs: "Any") -> "Callable[[F], F]":
      """
      Decorator that runs the wrapped sync or async callable inside a span.

      If no AI pipeline name is set when the wrapped callable is invoked, the
      call is treated as the pipeline itself: ``description`` is stored as the
      pipeline name for the duration of the call, any ``Exception`` raised by
      the callable is captured as an unhandled ``ai_monitoring`` event and then
      re-raised, and the pipeline name is reset to ``None`` afterwards.

      If a pipeline name is already set, the span is annotated with that name
      and the callable is invoked without the extra exception capture.

      :param description: Name of the span; also used as the pipeline name when
          this call starts a pipeline.
      :param span_kwargs: Extra keyword arguments forwarded to ``start_span``.
          An ``op`` entry overrides the default span op (``"ai.run"`` when a
          pipeline is active, ``"ai.pipeline"`` otherwise).
      :returns: A decorator producing a wrapper of the same kind (sync or
          async) as the decorated callable.

      Callers of the decorated function may pass ``sentry_tags`` and
      ``sentry_data`` dict keyword arguments. Both are removed from the call
      before the wrapped callable runs and are applied to the span as tags and
      data respectively.
      """
      # Separate "op" from the remaining span arguments once, WITHOUT mutating
      # span_kwargs. Popping it on each call (as was done previously) removed a
      # custom op after the first invocation, so later calls silently fell back
      # to the default op.
      has_custom_op = "op" in span_kwargs
      custom_op = span_kwargs.get("op")
      extra_span_kwargs = {k: v for k, v in span_kwargs.items() if k != "op"}

      def _span_op(curr_pipeline: "Optional[str]") -> "Any":
          # An explicitly supplied op always wins, even if it is None.
          if has_custom_op:
              return custom_op
          return "ai.run" if curr_pipeline else "ai.pipeline"

      def _annotate_span(
          span: "Span", kwargs: "Any", curr_pipeline: "Optional[str]"
      ) -> None:
          # sentry_tags / sentry_data are consumed here so they never reach
          # the wrapped callable.
          for k, v in kwargs.pop("sentry_tags", {}).items():
              span.set_tag(k, v)
          for k, v in kwargs.pop("sentry_data", {}).items():
              span.set_data(k, v)
          if curr_pipeline:
              span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, curr_pipeline)

      def _capture_exception(exc: Exception) -> None:
          # Reporting must never mask the original error, hence the guard.
          with capture_internal_exceptions():
              event, hint = sentry_sdk.utils.event_from_exception(
                  exc,
                  client_options=sentry_sdk.get_client().options,
                  mechanism={"type": "ai_monitoring", "handled": False},
              )
              sentry_sdk.capture_event(event, hint=hint)

      def decorator(f: "F") -> "F":
          def sync_wrapped(*args: "Any", **kwargs: "Any") -> "Any":
              curr_pipeline = _ai_pipeline_name.get()
              op = _span_op(curr_pipeline)
              with start_span(name=description, op=op, **extra_span_kwargs) as span:
                  _annotate_span(span, kwargs, curr_pipeline)
                  if curr_pipeline:
                      return f(*args, **kwargs)

                  # No pipeline active: this call becomes the pipeline.
                  _ai_pipeline_name.set(description)
                  try:
                      res = f(*args, **kwargs)
                  except Exception as e:
                      exc_info = sys.exc_info()
                      _capture_exception(e)
                      reraise(*exc_info)
                  finally:
                      _ai_pipeline_name.set(None)
                  return res

          async def async_wrapped(*args: "Any", **kwargs: "Any") -> "Any":
              curr_pipeline = _ai_pipeline_name.get()
              op = _span_op(curr_pipeline)
              with start_span(name=description, op=op, **extra_span_kwargs) as span:
                  _annotate_span(span, kwargs, curr_pipeline)
                  if curr_pipeline:
                      return await f(*args, **kwargs)

                  # No pipeline active: this call becomes the pipeline.
                  _ai_pipeline_name.set(description)
                  try:
                      res = await f(*args, **kwargs)
                  except Exception as e:
                      exc_info = sys.exc_info()
                      _capture_exception(e)
                      reraise(*exc_info)
                  finally:
                      _ai_pipeline_name.set(None)
                  return res

          if inspect.iscoroutinefunction(f):
              return wraps(f)(async_wrapped)  # type: ignore
          else:
              return wraps(f)(sync_wrapped)  # type: ignore

      return decorator
  def record_token_usage(
      span: "Span",
      input_tokens: "Optional[int]" = None,
      input_tokens_cached: "Optional[int]" = None,
      input_tokens_cache_write: "Optional[int]" = None,
      output_tokens: "Optional[int]" = None,
      output_tokens_reasoning: "Optional[int]" = None,
      total_tokens: "Optional[int]" = None,
  ) -> None:
      """
      Write token-usage counts onto ``span`` as span data.

      Every count that is not ``None`` is stored under its corresponding
      ``SPANDATA.GEN_AI_USAGE_*`` key. When ``total_tokens`` is omitted but
      both ``input_tokens`` and ``output_tokens`` are given, the total is
      recorded as their sum. If an AI pipeline name is currently set, it is
      recorded on the span as well.
      """
      # TODO: move pipeline name elsewhere
      pipeline = get_ai_pipeline_name()
      if pipeline:
          span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, pipeline)

      # Counters that are copied through unchanged, in recording order.
      simple_counts = (
          (SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens),
          (SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED, input_tokens_cached),
          (SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE, input_tokens_cache_write),
          (SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens),
          (SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING, output_tokens_reasoning),
      )
      for data_key, count in simple_counts:
          if count is not None:
              span.set_data(data_key, count)

      # Derive the total only when it was not supplied and both parts exist.
      total = total_tokens
      if total is None and not (input_tokens is None or output_tokens is None):
          total = input_tokens + output_tokens

      if total is not None:
          span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total)