auto_logging.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. from __future__ import annotations
  2. import asyncio
  3. import functools
  4. import inspect
  5. import logging
  6. from collections.abc import Sequence
  7. from typing import Any, Optional, Protocol, TypeVar
  8. import wandb.sdk
  9. import wandb.util
  10. from wandb.sdk.lib import telemetry as wb_telemetry
  11. from wandb.sdk.lib.timer import Timer
  # Module logger; warnings from failed autolog calls are reported here.
  logger = logging.getLogger(__name__)

  # Type parameters for the ``Response`` protocol: string keys, arbitrary values.
  K = TypeVar("K", bound=str)
  V = TypeVar("V")

  # Optional keyword arguments that ``AutologAPI`` forwards to ``wandb.init()``.
  AutologInitArgs = Optional[dict[str, Any]]
  class Response(Protocol[K, V]):
      """Mapping-like object returned by a patched API call.

      Structural protocol (not ``runtime_checkable``), intended for static type
      checking: it requires item access and a ``dict.get``-style lookup.
      """

      def get(self, key: K, default: V | None = None) -> V | None:
          """Look up ``key``, with an optional ``default`` (``dict.get``-style)."""
          ...  # pragma: no cover

      def __getitem__(self, key: K) -> V:
          """Look up ``key`` via subscription."""
          ...  # pragma: no cover
  class ArgumentResponseResolver(Protocol):
      """Callable that converts one API call into a dictionary to log to W&B.

      It receives the call's positional ``args`` and ``kwargs``, the ``response``
      the call returned, plus the timer's ``start_time`` and ``time_elapsed``.
      When it returns ``None``, nothing is logged for that call.
      """

      def __call__(
          self,
          args: Sequence[Any],
          kwargs: dict[str, Any],
          response: Response,
          start_time: float,
          time_elapsed: float,
      ) -> dict[str, Any] | None:
          ...  # pragma: no cover
  class PatchAPI:
      """Monkey-patch callables of a library module so their calls are logged to W&B.

      Each entry of ``symbols`` (e.g. ``"Client.generate"``) is resolved relative
      to the module returned by :attr:`set_api` and replaced with a wrapper. The
      wrapper calls the original, hands the call's arguments, result and timing
      to ``resolver``, and passes any non-``None`` dictionary it returns to
      ``run.log``. :meth:`unpatch` restores the saved originals.
      """

      def __init__(
          self,
          name: str,
          symbols: Sequence[str],
          resolver: ArgumentResponseResolver,
      ) -> None:
          """Patches the API to log wandb Media or metrics."""
          # name of the LLM provider, e.g. "Cohere" or "OpenAI" or package name like "Transformers"
          self.name = name
          # api library module; imported lazily on first access of ``set_api``
          self._api = None
          # maps each patched symbol to the original attribute it replaced
          self.original_methods: dict[str, Any] = {}
          # list of symbols to patch, e.g. ["Client.generate", "Edit.create"] or ["Pipeline.__call__"]
          self.symbols = symbols
          # resolver callable to convert args/response into a dictionary of wandb media objects or metrics
          self.resolver = resolver

      @property
      def set_api(self) -> Any:
          """Return the API module, importing it on first access.

          Despite its name this is a read-only property. The module imported is
          ``self.name.lower()``, and the result is cached in ``self._api``.
          """
          if self._api is None:
              lib_name = self.name.lower()
              self._api = wandb.util.get_module(
                  name=lib_name,
                  required=f"To use the W&B {self.name} Autolog, "
                  f"you need to have the `{lib_name}` python "
                  f"package installed. Please install it with `pip install {lib_name}`.",
                  lazy=False,
              )
          return self._api

      def _resolve_parent(self, symbol_parts: Sequence[str]) -> Any:
          """Return the object owning the last attribute named in ``symbol_parts``."""
          # reduce over an empty prefix yields the module itself, so top-level
          # ("generate") and nested ("Client.generate") symbols share one path
          return functools.reduce(getattr, symbol_parts[:-1], self.set_api)

      def _make_wrapper(self, original_method: Any, run: wandb.Run) -> Any:
          """Wrap ``original_method`` so each call is timed and logged to ``run``.

          Coroutine functions get an ``async`` wrapper and everything else a sync
          one. Both return the original result unchanged and let exceptions raised
          by the original propagate to the caller.
          """

          def _log_call(args, kwargs, result, timer) -> None:
              # Best effort: a failing resolver or run.log must never break the
              # user's API call, so failures are only reported as warnings.
              try:
                  loggable_dict = self.resolver(
                      args, kwargs, result, timer.start_time, timer.elapsed
                  )
                  if loggable_dict is not None:
                      run.log(loggable_dict)
              except Exception as e:
                  logger.warning(e)

          if inspect.iscoroutinefunction(original_method):

              @functools.wraps(original_method)
              async def async_method(*args, **kwargs):
                  # Await inside the timer so the elapsed time covers the whole
                  # call. Awaiting directly (instead of resolving a side Future
                  # from a detached task) also lets errors reach the caller
                  # rather than leaving it waiting forever.
                  with Timer() as timer:
                      result = await original_method(*args, **kwargs)
                  _log_call(args, kwargs, result, timer)
                  return result

              return async_method

          @functools.wraps(original_method)
          def sync_method(*args, **kwargs):
              with Timer() as timer:
                  result = original_method(*args, **kwargs)
              _log_call(args, kwargs, result, timer)
              return result

          return sync_method

      def patch(self, run: wandb.Run) -> None:
          """Patches the API to log media or metrics to W&B."""
          for symbol in self.symbols:
              # split on dots, e.g. "Client.generate" -> ["Client", "generate"]
              symbol_parts = symbol.split(".")
              parent = self._resolve_parent(symbol_parts)
              # If the symbol is already patched, wrap the saved original instead
              # of the current wrapper. Otherwise a second patch() would log
              # every call twice and unpatch() could not restore the real method.
              if symbol not in self.original_methods:
                  self.original_methods[symbol] = getattr(parent, symbol_parts[-1])
              original = self.original_methods[symbol]
              setattr(parent, symbol_parts[-1], self._make_wrapper(original, run))

      def unpatch(self) -> None:
          """Restore every original method saved by :meth:`patch`."""
          for symbol, original in self.original_methods.items():
              symbol_parts = symbol.split(".")
              setattr(self._resolve_parent(symbol_parts), symbol_parts[-1], original)
          # Forget the originals so a repeated unpatch() cannot clobber whatever
          # is installed later, and the next patch() saves fresh originals.
          self.original_methods.clear()
  class AutologAPI:
      """Turn W&B autologging for one library API on and off.

      Owns a :class:`PatchAPI` for the library and tracks the W&B run that
      patched calls log to. :meth:`disable` finishes that run only when
      autolog created it.
      """

      def __init__(
          self,
          name: str,
          symbols: Sequence[str],
          resolver: ArgumentResponseResolver,
          telemetry_feature: str | None = None,
      ) -> None:
          """Autolog API calls to W&B.

          Args:
              name: Library/provider name; ``PatchAPI`` imports ``name.lower()``.
              symbols: Dotted attribute paths within that module to patch.
              resolver: Converts each call's arguments/response into a dict to log.
              telemetry_feature: Optional telemetry feature flag set to True on enable.
          """
          self._telemetry_feature = telemetry_feature
          self._patch_api = PatchAPI(
              name=name,
              symbols=symbols,
              resolver=resolver,
          )
          self._name = self._patch_api.name
          # run that patched calls log to; None while autologging is disabled
          self._run: wandb.Run | None = None
          # True when self._run was created here and must be finished on disable
          self.__run_created_by_autolog: bool = False

      @property
      def _is_enabled(self) -> bool:
          """Returns whether autologging is enabled."""
          return self._run is not None

      def __call__(self, init: AutologInitArgs = None) -> None:
          """Enable autologging; shorthand for :meth:`enable`."""
          self.enable(init=init)

      def _run_init(self, init: AutologInitArgs = None) -> None:
          """Handle wandb run initialization.

          - A non-empty ``init`` dict always calls ``wandb.init(**init)``,
            whether or not ``wandb.run`` exists. The run is marked as created by
            autolog only if it differs from the pre-existing ``wandb.run``.
          - Otherwise the existing ``wandb.run`` is reused, or ``wandb.init()``
            is called when there is none.
          """
          # todo: autolog(init: dict | run = run) would use the user-provided run
          if init:
              _wandb_run = wandb.run
              # we delegate dealing with the init dict to wandb.init()
              self._run = wandb.init(**init)
              if _wandb_run != self._run:
                  self.__run_created_by_autolog = True
          elif wandb.run is None:
              self._run = wandb.init()
              self.__run_created_by_autolog = True
          else:
              self._run = wandb.run

      def enable(self, init: AutologInitArgs = None) -> None:
          """Enable autologging.

          If autologging is already enabled it is disabled first, then re-enabled.

          Args:
              init: Optional dictionary of arguments to pass to wandb.init().
          """
          if self._is_enabled:
              logger.info(
                  f"{self._name} autologging is already enabled, disabling and re-enabling."
              )
              self.disable()
          logger.info(f"Enabling {self._name} autologging.")
          self._run_init(init=init)
          self._patch_api.patch(self._run)
          if self._telemetry_feature:
              with wb_telemetry.context(self._run) as tel:
                  setattr(tel.feature, self._telemetry_feature, True)

      def disable(self) -> None:
          """Disable autologging.

          Finishes the run only if autolog created it, then restores the
          original API methods. Does nothing when autologging is not enabled.
          """
          if self._run is None:
              return
          logger.info(f"Disabling {self._name} autologging.")
          try:
              if self.__run_created_by_autolog:
                  self._run.finish()
          finally:
              # Reset state and unpatch even if finish() raised, so the API is
              # never left patched against a dead run (and _is_enabled is False).
              self.__run_created_by_autolog = False
              self._run = None
              self._patch_api.unpatch()