| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233 |
- from __future__ import annotations
- import asyncio
- import functools
- import inspect
- import logging
- from collections.abc import Sequence
- from typing import Any, Optional, Protocol, TypeVar
- import wandb.sdk
- import wandb.util
- from wandb.sdk.lib import telemetry as wb_telemetry
- from wandb.sdk.lib.timer import Timer
# Module-level logger; the autolog wrappers below report resolver/logging
# failures through it instead of raising into user code.
logger = logging.getLogger(__name__)

# Type variables parameterizing the `Response` protocol: string-like keys
# mapped to arbitrary values.
K = TypeVar("K", bound=str)
V = TypeVar("V")

# Type of the optional `init` argument accepted by `AutologAPI`: either
# nothing, or a dict of keyword arguments that is forwarded to `wandb.init()`.
AutologInitArgs = Optional[dict[str, Any]]
class Response(Protocol[K, V]):
    """Structural type for a mapping-like API response.

    Anything that supports ``response[key]`` and ``response.get(key, default)``
    satisfies this protocol.
    """

    def __getitem__(self, key: K) -> V:  # pragma: no cover
        """Return the value stored under ``key``."""
        ...

    def get(
        self,
        key: K,
        default: V | None = None,
    ) -> V | None:  # pragma: no cover
        """Return the value stored under ``key``, or ``default``."""
        ...
class ArgumentResponseResolver(Protocol):
    """Callable that turns one patched API call into a loggable dictionary.

    The patched wrappers invoke the resolver after the original call has
    returned and pass its result to ``run.log`` unless it is ``None``.
    """

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> dict[str, Any] | None:  # pragma: no cover
        """Build the dictionary to log for a single call.

        Args:
            args: Positional arguments the patched callable was invoked with.
            kwargs: Keyword arguments the patched callable was invoked with.
            response: Value returned by the original callable.
            start_time: ``start_time`` of the timer wrapped around the call.
            time_elapsed: ``elapsed`` value of that same timer.

        Returns:
            A dictionary to log, or ``None`` to log nothing for this call.
        """
        ...
class PatchAPI:
    """Monkey-patch selected callables of a library so their calls get logged.

    Every symbol listed in ``symbols`` is replaced by a wrapper that calls the
    original, passes the call arguments, the result and timing information to
    ``resolver`` and logs the returned dictionary (if any) to a W&B run.
    ``unpatch`` puts the saved originals back.
    """

    def __init__(
        self,
        name: str,
        symbols: Sequence[str],
        resolver: ArgumentResponseResolver,
    ) -> None:
        """Patches the API to log wandb Media or metrics.

        Args:
            name: Name of the LLM provider, e.g. "Cohere" or "OpenAI", or a
                package name like "Transformers". Its lower-cased form is
                used as the module name to load.
            symbols: Dotted attribute paths to patch, relative to the module,
                e.g. ["Client.generate", "Edit.create"] or
                ["Pipeline.__call__"].
            resolver: Callable converting args/response into a dictionary of
                wandb media objects or metrics.
        """
        # name of the LLM provider, e.g. "Cohere" or "OpenAI" or package name like "Transformers"
        self.name = name
        # api library module, resolved lazily by `set_api`
        self._api: Any = None
        # dictionary of original methods, keyed by symbol; empty while unpatched
        self.original_methods: dict[str, Any] = {}
        # list of symbols to patch, e.g. ["Client.generate", "Edit.create"] or ["Pipeline.__call__"]
        self.symbols = symbols
        # resolver callable to convert args/response into a dictionary of wandb media objects or metrics
        self.resolver = resolver

    @property
    def set_api(self) -> Any:
        """Returns the API module.

        The module is looked up through ``wandb.util.get_module`` on first
        access and cached on the instance afterwards.
        NOTE(review): a ``required`` message is passed, so a missing package
        presumably raises with that install hint -- confirm against
        ``wandb.util.get_module``.
        """
        lib_name = self.name.lower()
        if self._api is None:
            self._api = wandb.util.get_module(
                name=lib_name,
                required=f"To use the W&B {self.name} Autolog, "
                f"you need to have the `{lib_name}` python "
                f"package installed. Please install it with `pip install {lib_name}`.",
                lazy=False,
            )
        return self._api

    def _wrap(self, original_method: Any, run: wandb.Run) -> Any:
        """Return a logging wrapper around ``original_method``.

        Coroutine functions get an ``async`` wrapper, everything else a plain
        one; both carry the original's metadata via ``functools.wraps``.
        Exceptions raised by the original call propagate to the caller
        unchanged. Failures in the resolver or in ``run.log`` are only logged
        as warnings and never affect the returned result.
        """

        def log_call(args: Any, kwargs: Any, result: Any, timer: Any) -> None:
            # Best effort: logging problems must not break the user's API call.
            try:
                loggable_dict = self.resolver(
                    args, kwargs, result, timer.start_time, timer.elapsed
                )
                if loggable_dict is not None:
                    run.log(loggable_dict)
            except Exception as e:
                logger.warning(e)

        if inspect.iscoroutinefunction(original_method):

            @functools.wraps(original_method)
            async def async_method(*args, **kwargs):
                # Await inside the timer so the elapsed time covers the actual
                # call, and await directly (no side task / hand-made future) so
                # errors and cancellation reach the caller instead of leaving
                # it waiting on a future that is never resolved.
                with Timer() as timer:
                    result = await original_method(*args, **kwargs)
                log_call(args, kwargs, result, timer)
                return result

            return async_method

        @functools.wraps(original_method)
        def sync_method(*args, **kwargs):
            with Timer() as timer:
                result = original_method(*args, **kwargs)
            log_call(args, kwargs, result, timer)
            return result

        return sync_method

    def patch(self, run: wandb.Run) -> None:
        """Patches the API to log media or metrics to W&B.

        Args:
            run: Run whose ``log`` method receives the resolver output.

        If a previous patch is still active it is undone first, so wrappers
        never stack and the saved originals are always the real callables.
        """
        if self.original_methods:
            self.unpatch()

        for symbol in self.symbols:
            # split on dots, e.g. "Client.generate" -> ["Client"], "generate";
            # a bare name yields an empty owner path, i.e. the module itself
            *owner_path, attr_name = symbol.split(".")
            owner = functools.reduce(getattr, owner_path, self.set_api)
            original = getattr(owner, attr_name)
            # save original method
            self.original_methods[symbol] = original
            # monkey patch the method
            setattr(owner, attr_name, self._wrap(original, run))

    def unpatch(self) -> None:
        """Unpatches the API by restoring every saved original."""
        for symbol, original in self.original_methods.items():
            # split on dots, e.g. "Client.generate" -> ["Client"], "generate"
            *owner_path, attr_name = symbol.split(".")
            owner = functools.reduce(getattr, owner_path, self.set_api)
            setattr(owner, attr_name, original)
        # Nothing is patched any more; forget the saved originals so a later
        # `patch` starts from a clean state.
        self.original_methods.clear()
class AutologAPI:
    """Switch that turns W&B autologging for one library on and off.

    Enabling resolves a run (an existing one or a freshly initialized one) and
    patches the configured symbols through ``PatchAPI``; disabling restores
    the originals and finishes the run if autolog created it.
    """

    def __init__(
        self,
        name: str,
        symbols: Sequence[str],
        resolver: ArgumentResponseResolver,
        telemetry_feature: str | None = None,
    ) -> None:
        """Autolog API calls to W&B.

        Args:
            name: Library/provider name, forwarded to ``PatchAPI``.
            symbols: Dotted attribute paths to patch, forwarded to ``PatchAPI``.
            resolver: Callable building the dictionary to log for each call.
            telemetry_feature: Optional name of the attribute on the telemetry
                ``feature`` record that is set to ``True`` when autologging
                is enabled.
        """
        self._telemetry_feature = telemetry_feature
        self._patch_api = PatchAPI(
            name=name,
            symbols=symbols,
            resolver=resolver,
        )
        self._name = self._patch_api.name
        # run currently used for logging; None while autologging is disabled
        self._run: wandb.Run | None = None
        # True only if the run was started by autolog and must be finished by it
        self.__run_created_by_autolog: bool = False

    @property
    def _is_enabled(self) -> bool:
        """Returns whether autologging is enabled."""
        return self._run is not None

    def __call__(self, init: AutologInitArgs = None) -> None:
        """Enable autologging."""
        self.enable(init=init)

    def _run_init(self, init: AutologInitArgs = None) -> None:
        """Handle wandb run initialization."""
        # - autolog(init: dict = {...}) calls wandb.init(**{...})
        #   regardless of whether there is a wandb.run or not,
        #   we only track if the run was created by autolog
        #   - todo: autolog(init: dict | run = run) would use the user-provided run
        # - autolog() uses the wandb.run if there is one, otherwise it calls wandb.init()
        if init:
            _wandb_run = wandb.run
            # we delegate dealing with the init dict to wandb.init()
            self._run = wandb.init(**init)
            if _wandb_run != self._run:
                self.__run_created_by_autolog = True
        elif wandb.run is None:
            self._run = wandb.init()
            self.__run_created_by_autolog = True
        else:
            self._run = wandb.run

    def enable(self, init: AutologInitArgs = None) -> None:
        """Enable autologging.

        If autologging is already enabled it is disabled and re-enabled.

        Args:
            init: Optional dictionary of arguments to pass to wandb.init().

        Raises:
            Exception: Whatever run initialization or patching raises. When
                patching fails, autologging is disabled again before the
                error is re-raised.
        """
        if self._is_enabled:
            logger.info(
                f"{self._name} autologging is already enabled, disabling and re-enabling."
            )
            self.disable()

        logger.info(f"Enabling {self._name} autologging.")
        self._run_init(init=init)

        try:
            self._patch_api.patch(self._run)
        except Exception:
            # Roll back: otherwise a run created just for autolog would stay
            # open, `_is_enabled` would report True and any symbols patched
            # before the failure would remain wrapped.
            self.disable()
            raise

        if self._telemetry_feature:
            with wb_telemetry.context(self._run) as tel:
                setattr(tel.feature, self._telemetry_feature, True)

    def disable(self) -> None:
        """Disable autologging.

        Does nothing when autologging is not enabled. Otherwise finishes the
        run if autolog created it and restores the patched symbols.
        """
        if self._run is None:
            return

        logger.info(f"Disabling {self._name} autologging.")

        try:
            if self.__run_created_by_autolog:
                self._run.finish()
        finally:
            # Always reset state and unpatch, even if finishing the run
            # failed, so the library is never left monkey-patched.
            self.__run_created_by_autolog = False
            self._run = None
            self._patch_api.unpatch()
|