_analytics.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. from __future__ import annotations
  2. from contextvars import ContextVar
  3. from dataclasses import dataclass, field
  4. from functools import wraps
  5. from typing import Callable, Final, TypeVar
  6. from uuid import UUID, uuid4
  7. from typing_extensions import ParamSpec
  8. from wandb._strutils import nameof
  9. P = ParamSpec("P")
  10. R = TypeVar("R")
  11. # Header keys for tracking the calling function
  12. X_WANDB_PYTHON_FUNC: Final[str] = "X-Wandb-Python-Func"
  13. X_WANDB_PYTHON_CALL_ID: Final[str] = "X-Wandb-Python-Call-Id"
  14. @dataclass(frozen=True)
  15. class TrackedFuncInfo:
  16. func: str
  17. """The fully qualified namespace of the tracked function."""
  18. call_id: UUID = field(default_factory=uuid4)
  19. """A unique identifier assigned to each invocation."""
  20. def to_headers(self) -> dict[str, str]:
  21. return {
  22. X_WANDB_PYTHON_FUNC: self.func,
  23. X_WANDB_PYTHON_CALL_ID: str(self.call_id),
  24. }
  25. _current_func: ContextVar[TrackedFuncInfo] = ContextVar("_current_func")
  26. """An internal, threadsafe context variable to hold the current function being tracked."""
  27. def tracked(func: Callable[P, R]) -> Callable[P, R]:
  28. """A decorator to inject the calling function name into any GraphQL request headers.
  29. If a tracked function calls another tracked function, only the outermost function in
  30. the call stack will be tracked.
  31. """
  32. func_namespace = f"{func.__module__}.{nameof(func)}"
  33. @wraps(func)
  34. def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
  35. # Don't override the current tracked function if it's already set
  36. if tracked_func():
  37. return func(*args, **kwargs)
  38. token = _current_func.set(TrackedFuncInfo(func=func_namespace))
  39. try:
  40. return func(*args, **kwargs)
  41. finally:
  42. _current_func.reset(token)
  43. return wrapper
  44. def tracked_func() -> TrackedFuncInfo | None:
  45. """Returns info on the current tracked function, if any, otherwise None."""
  46. return _current_func.get(None)