history.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. """W&B Public API for Run History.
  2. This module provides classes for efficiently scanning and sampling run
  3. history data.
  4. Note:
  5. This module is part of the W&B Public API and provides methods
  6. to access run history data. It handles pagination automatically and offers
  7. both complete and sampled access to metrics logged during training runs.
  8. """
  9. from __future__ import annotations
  10. import contextlib
  11. import json
  12. import weakref
  13. from collections.abc import Iterator
  14. from typing import TYPE_CHECKING, Any
  15. from typing_extensions import Self, TypeAlias
  16. from wandb_gql import gql
  17. from wandb.apis.normalize import normalize_exceptions
  18. from wandb.apis.public.service_api import ServiceApi
  19. from wandb.proto import wandb_api_pb2 as pb
  20. from wandb.sdk.mailbox.mailbox import MailboxClosedError
  21. if TYPE_CHECKING:
  22. from . import runs
  23. from .api import RetryingClient
  24. _RowDict: TypeAlias = dict[str, Any]
  25. """Type alias for a single history row as a dict."""
  26. class BetaHistoryScan(Iterator[_RowDict]):
  27. """Iterator for scanning complete run history.
  28. <!-- lazydoc-ignore-class: internal -->
  29. """
  30. def __init__(
  31. self,
  32. service_api: ServiceApi,
  33. run: runs.Run,
  34. min_step: int,
  35. max_step: int,
  36. keys: list[str] | None = None,
  37. page_size: int = 1_000,
  38. use_cache: bool = True,
  39. ):
  40. self.run = run
  41. self.min_step = min_step
  42. self._stop_step = max_step
  43. self.keys = keys
  44. self.page_size = page_size
  45. self._service_api = service_api
  46. # Tell wandb-core to initialize resources to scan the run's history.
  47. scan_run_history_init = pb.ScanRunHistoryInit(
  48. entity=self.run.entity,
  49. project=self.run.project,
  50. run_id=self.run.id,
  51. keys=self.keys,
  52. use_cache=use_cache,
  53. )
  54. scan_run_history_init_request = pb.ReadRunHistoryRequest(
  55. scan_run_history_init=scan_run_history_init
  56. )
  57. api_request = pb.ApiRequest(
  58. read_run_history_request=scan_run_history_init_request
  59. )
  60. response: pb.ApiResponse = self._service_api.send_api_request(api_request)
  61. self._scan_request_id = (
  62. response.read_run_history_response.scan_run_history_init.request_id
  63. )
  64. self.scan_offset = 0
  65. self.rows: list[_RowDict] = []
  66. self.keys = keys
  67. # Add cleanup hook to clean up resources in wandb-core
  68. # when this scan object is deleted.
  69. #
  70. # Using weakref.finalize ensures that references to objects needed during cleanup
  71. # are not garbage collected before being used.
  72. # see: https://docs.python.org/3/library/weakref.html#comparing-finalizers-with-del-methods
  73. weakref.finalize(
  74. self,
  75. self.cleanup,
  76. self._service_api,
  77. self._scan_request_id,
  78. )
  79. @property
  80. def max_step(self) -> int:
  81. """The highest step that can be yielded by this scan."""
  82. return self._stop_step - 1
  83. def __iter__(self) -> Self:
  84. self.scan_offset = 0
  85. self.page_offset = self.min_step
  86. self.rows = []
  87. return self
  88. def __next__(self) -> _RowDict:
  89. while True:
  90. if self.scan_offset < len(self.rows):
  91. row = self.rows[self.scan_offset]
  92. self.scan_offset += 1
  93. return row
  94. if self.page_offset >= self._stop_step:
  95. raise StopIteration()
  96. # Load the next page
  97. self._load_next()
  98. # If no rows were returned, we've reached the end of the data
  99. if len(self.rows) == 0:
  100. raise StopIteration()
  101. def _load_next(self) -> None:
  102. from wandb.proto import wandb_api_pb2 as pb
  103. max_step = min(self.page_offset + self.page_size, self._stop_step)
  104. read_run_history_request = pb.ReadRunHistoryRequest(
  105. scan_run_history=pb.ScanRunHistory(
  106. min_step=self.page_offset,
  107. max_step=max_step,
  108. request_id=self._scan_request_id,
  109. ),
  110. )
  111. api_request = pb.ApiRequest(read_run_history_request=read_run_history_request)
  112. response: pb.ApiResponse = self._service_api.send_api_request(api_request)
  113. run_history: pb.RunHistoryResponse = (
  114. response.read_run_history_response.run_history
  115. )
  116. self.rows = [
  117. self._convert_history_row_to_dict(row) for row in run_history.history_rows
  118. ]
  119. self.page_offset += self.page_size
  120. self.scan_offset = 0
  121. @staticmethod
  122. def _convert_history_row_to_dict(history_row: pb.HistoryRow) -> _RowDict:
  123. return {
  124. item.key: json.loads(item.value_json) for item in history_row.history_items
  125. }
  126. @staticmethod
  127. def cleanup(service_api: ServiceApi, request_id: int) -> None:
  128. scan_run_history_cleanup = pb.ScanRunHistoryCleanup(
  129. request_id=request_id,
  130. )
  131. scan_run_history_cleanup_request = pb.ReadRunHistoryRequest(
  132. scan_run_history_cleanup=scan_run_history_cleanup
  133. )
  134. with contextlib.suppress(ConnectionResetError, MailboxClosedError):
  135. service_api.send_api_request(
  136. pb.ApiRequest(read_run_history_request=scan_run_history_cleanup_request)
  137. )
  138. class HistoryScan(Iterator[_RowDict]):
  139. """Iterator for scanning complete run history.
  140. <!-- lazydoc-ignore-class: internal -->
  141. """
  142. QUERY = gql(
  143. """
  144. query HistoryPage($entity: String!, $project: String!, $run: String!, $minStep: Int64!, $maxStep: Int64!, $pageSize: Int!) {
  145. project(name: $project, entityName: $entity) {
  146. run(name: $run) {
  147. history(minStep: $minStep, maxStep: $maxStep, samples: $pageSize)
  148. }
  149. }
  150. }
  151. """
  152. )
  153. def __init__(
  154. self,
  155. client: RetryingClient,
  156. run: runs.Run,
  157. min_step: int,
  158. max_step: int,
  159. page_size: int = 1_000,
  160. ):
  161. """Initialize a HistoryScan instance.
  162. Args:
  163. client: The client instance to use for making API calls to the W&B backend.
  164. run: The run object whose history is to be scanned.
  165. min_step: The minimum step to start scanning from.
  166. max_step: The exclusive upper bound for scanned history rows.
  167. page_size: Number of history rows to fetch per page.
  168. Default page_size is 1000.
  169. """
  170. self.client = client
  171. self.run = run
  172. self.page_size = page_size
  173. self.min_step = min_step
  174. self._stop_step = max_step
  175. self.page_offset = min_step # minStep for next page
  176. self.scan_offset = 0 # index within current page of rows
  177. self.rows: list[_RowDict] = [] # current page of rows
  178. @property
  179. def max_step(self) -> int:
  180. """The highest step that can be yielded by this scan."""
  181. return self._stop_step - 1
  182. def __iter__(self) -> Self:
  183. self.page_offset = self.min_step
  184. self.scan_offset = 0
  185. self.rows = []
  186. return self
  187. def __next__(self) -> _RowDict:
  188. """Return the next row of history data with automatic pagination.
  189. <!-- lazydoc-ignore: internal -->
  190. """
  191. while True:
  192. if self.scan_offset < len(self.rows):
  193. row = self.rows[self.scan_offset]
  194. self.scan_offset += 1
  195. return row
  196. if self.page_offset >= self._stop_step:
  197. raise StopIteration()
  198. self._load_next()
  199. next = __next__
  200. @normalize_exceptions
  201. def _load_next(self) -> None:
  202. max_step = self.page_offset + self.page_size
  203. if max_step > self._stop_step:
  204. max_step = self._stop_step
  205. variables = {
  206. "entity": self.run.entity,
  207. "project": self.run.project,
  208. "run": self.run.id,
  209. "minStep": int(self.page_offset),
  210. "maxStep": int(max_step),
  211. "pageSize": int(self.page_size),
  212. }
  213. res = self.client.execute(self.QUERY, variable_values=variables)
  214. res = res["project"]["run"]["history"]
  215. self.rows = [json.loads(row) for row in res]
  216. self.page_offset += self.page_size
  217. self.scan_offset = 0
  218. class SampledHistoryScan(Iterator[_RowDict]):
  219. """Iterator for sampling run history data.
  220. <!-- lazydoc-ignore-class: internal -->
  221. """
  222. QUERY = gql(
  223. """
  224. query SampledHistoryPage($entity: String!, $project: String!, $run: String!, $spec: JSONString!) {
  225. project(name: $project, entityName: $entity) {
  226. run(name: $run) {
  227. sampledHistory(specs: [$spec])
  228. }
  229. }
  230. }
  231. """
  232. )
  233. def __init__(
  234. self,
  235. client: RetryingClient,
  236. run: runs.Run,
  237. keys: list[str],
  238. min_step: int,
  239. max_step: int,
  240. page_size: int = 1_000,
  241. ):
  242. """Initialize a SampledHistoryScan instance.
  243. Args:
  244. client: The client instance to use for making API calls to the W&B backend.
  245. run: The run object whose history is to be sampled.
  246. keys: List of keys to sample from the history.
  247. min_step: The minimum step to start sampling from.
  248. max_step: The exclusive upper bound for sampled history rows.
  249. page_size: Number of sampled history rows to fetch per page.
  250. Default page_size is 1000.
  251. """
  252. self.client = client
  253. self.run = run
  254. self.keys = keys
  255. self.page_size = page_size
  256. self.min_step = min_step
  257. self._stop_step = max_step
  258. self.page_offset = min_step # minStep for next page
  259. self.scan_offset = 0 # index within current page of rows
  260. self.rows: list[_RowDict] = [] # current page of rows
  261. @property
  262. def max_step(self) -> int:
  263. """The highest step that can be yielded by this scan."""
  264. return self._stop_step - 1
  265. def __iter__(self) -> Self:
  266. self.page_offset = self.min_step
  267. self.scan_offset = 0
  268. self.rows = []
  269. return self
  270. def __next__(self) -> _RowDict:
  271. """Return the next row of sampled history data with automatic pagination.
  272. <!-- lazydoc-ignore: internal -->
  273. """
  274. while True:
  275. if self.scan_offset < len(self.rows):
  276. row = self.rows[self.scan_offset]
  277. self.scan_offset += 1
  278. return row
  279. if self.page_offset >= self._stop_step:
  280. raise StopIteration()
  281. self._load_next()
  282. next = __next__
  283. @normalize_exceptions
  284. def _load_next(self) -> None:
  285. max_step = self.page_offset + self.page_size
  286. if max_step > self._stop_step:
  287. max_step = self._stop_step
  288. variables = {
  289. "entity": self.run.entity,
  290. "project": self.run.project,
  291. "run": self.run.id,
  292. "spec": json.dumps(
  293. {
  294. "keys": self.keys,
  295. "minStep": int(self.page_offset),
  296. "maxStep": int(max_step),
  297. "samples": int(self.page_size),
  298. }
  299. ),
  300. }
  301. res = self.client.execute(self.QUERY, variable_values=variables)
  302. res = res["project"]["run"]["sampledHistory"]
  303. self.rows = res[0]
  304. self.page_offset += self.page_size
  305. self.scan_offset = 0