| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- from __future__ import annotations
- from contextvars import ContextVar
- from dataclasses import dataclass, field
- from functools import wraps
- from typing import Callable, Final, TypeVar
- from uuid import UUID, uuid4
- from typing_extensions import ParamSpec
- from wandb._strutils import nameof
- P = ParamSpec("P")
- R = TypeVar("R")
- # Header keys for tracking the calling function
- X_WANDB_PYTHON_FUNC: Final[str] = "X-Wandb-Python-Func"
- X_WANDB_PYTHON_CALL_ID: Final[str] = "X-Wandb-Python-Call-Id"
- @dataclass(frozen=True)
- class TrackedFuncInfo:
- func: str
- """The fully qualified namespace of the tracked function."""
- call_id: UUID = field(default_factory=uuid4)
- """A unique identifier assigned to each invocation."""
- def to_headers(self) -> dict[str, str]:
- return {
- X_WANDB_PYTHON_FUNC: self.func,
- X_WANDB_PYTHON_CALL_ID: str(self.call_id),
- }
- _current_func: ContextVar[TrackedFuncInfo] = ContextVar("_current_func")
- """An internal, threadsafe context variable to hold the current function being tracked."""
- def tracked(func: Callable[P, R]) -> Callable[P, R]:
- """A decorator to inject the calling function name into any GraphQL request headers.
- If a tracked function calls another tracked function, only the outermost function in
- the call stack will be tracked.
- """
- func_namespace = f"{func.__module__}.{nameof(func)}"
- @wraps(func)
- def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
- # Don't override the current tracked function if it's already set
- if tracked_func():
- return func(*args, **kwargs)
- token = _current_func.set(TrackedFuncInfo(func=func_namespace))
- try:
- return func(*args, **kwargs)
- finally:
- _current_func.reset(token)
- return wrapper
- def tracked_func() -> TrackedFuncInfo | None:
- """Returns info on the current tracked function, if any, otherwise None."""
- return _current_func.get(None)
|